feat: merge 2nd round from all rcc/* sessions
- api: tool_use parsing, message_delta, request_id tracking, retry logic - tools: extended tool suite (WebSearch, WebFetch, Agent, etc.) - cli: live streamed conversations, session restore, compact commands - runtime: config loading, system prompt builder, token usage, compaction
This commit is contained in:
@@ -1,9 +1,19 @@
|
||||
use std::collections::VecDeque;
|
||||
use std::time::Duration;
|
||||
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::error::ApiError;
|
||||
use crate::sse::SseParser;
|
||||
use crate::types::{MessageRequest, MessageResponse, StreamEvent};
|
||||
|
||||
const DEFAULT_BASE_URL: &str = "https://api.anthropic.com";
|
||||
const ANTHROPIC_VERSION: &str = "2023-06-01";
|
||||
const REQUEST_ID_HEADER: &str = "request-id";
|
||||
const ALT_REQUEST_ID_HEADER: &str = "x-request-id";
|
||||
const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(200);
|
||||
const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(2);
|
||||
const DEFAULT_MAX_RETRIES: u32 = 2;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AnthropicClient {
|
||||
@@ -11,6 +21,9 @@ pub struct AnthropicClient {
|
||||
api_key: String,
|
||||
auth_token: Option<String>,
|
||||
base_url: String,
|
||||
max_retries: u32,
|
||||
initial_backoff: Duration,
|
||||
max_backoff: Duration,
|
||||
}
|
||||
|
||||
impl AnthropicClient {
|
||||
@@ -21,6 +34,9 @@ impl AnthropicClient {
|
||||
api_key: api_key.into(),
|
||||
auth_token: None,
|
||||
base_url: DEFAULT_BASE_URL.to_string(),
|
||||
max_retries: DEFAULT_MAX_RETRIES,
|
||||
initial_backoff: DEFAULT_INITIAL_BACKOFF,
|
||||
max_backoff: DEFAULT_MAX_BACKOFF,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,6 +63,19 @@ impl AnthropicClient {
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_retry_policy(
|
||||
mut self,
|
||||
max_retries: u32,
|
||||
initial_backoff: Duration,
|
||||
max_backoff: Duration,
|
||||
) -> Self {
|
||||
self.max_retries = max_retries;
|
||||
self.initial_backoff = initial_backoff;
|
||||
self.max_backoff = max_backoff;
|
||||
self
|
||||
}
|
||||
|
||||
pub async fn send_message(
|
||||
&self,
|
||||
request: &MessageRequest,
|
||||
@@ -55,12 +84,16 @@ impl AnthropicClient {
|
||||
stream: false,
|
||||
..request.clone()
|
||||
};
|
||||
let response = self.send_raw_request(&request).await?;
|
||||
let response = expect_success(response).await?;
|
||||
response
|
||||
let response = self.send_with_retry(&request).await?;
|
||||
let request_id = request_id_from_headers(response.headers());
|
||||
let mut response = response
|
||||
.json::<MessageResponse>()
|
||||
.await
|
||||
.map_err(ApiError::from)
|
||||
.map_err(ApiError::from)?;
|
||||
if response.request_id.is_none() {
|
||||
response.request_id = request_id;
|
||||
}
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
pub async fn stream_message(
|
||||
@@ -68,17 +101,53 @@ impl AnthropicClient {
|
||||
request: &MessageRequest,
|
||||
) -> Result<MessageStream, ApiError> {
|
||||
let response = self
|
||||
.send_raw_request(&request.clone().with_streaming())
|
||||
.send_with_retry(&request.clone().with_streaming())
|
||||
.await?;
|
||||
let response = expect_success(response).await?;
|
||||
Ok(MessageStream {
|
||||
request_id: request_id_from_headers(response.headers()),
|
||||
response,
|
||||
parser: SseParser::new(),
|
||||
pending: std::collections::VecDeque::new(),
|
||||
pending: VecDeque::new(),
|
||||
done: false,
|
||||
})
|
||||
}
|
||||
|
||||
async fn send_with_retry(
|
||||
&self,
|
||||
request: &MessageRequest,
|
||||
) -> Result<reqwest::Response, ApiError> {
|
||||
let mut attempts = 0;
|
||||
let mut last_error: Option<ApiError>;
|
||||
|
||||
loop {
|
||||
attempts += 1;
|
||||
match self.send_raw_request(request).await {
|
||||
Ok(response) => match expect_success(response).await {
|
||||
Ok(response) => return Ok(response),
|
||||
Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => {
|
||||
last_error = Some(error);
|
||||
}
|
||||
Err(error) => return Err(error),
|
||||
},
|
||||
Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => {
|
||||
last_error = Some(error);
|
||||
}
|
||||
Err(error) => return Err(error),
|
||||
}
|
||||
|
||||
if attempts > self.max_retries {
|
||||
break;
|
||||
}
|
||||
|
||||
tokio::time::sleep(self.backoff_for_attempt(attempts)?).await;
|
||||
}
|
||||
|
||||
Err(ApiError::RetriesExhausted {
|
||||
attempts,
|
||||
last_error: Box::new(last_error.expect("retry loop must capture an error")),
|
||||
})
|
||||
}
|
||||
|
||||
async fn send_raw_request(
|
||||
&self,
|
||||
request: &MessageRequest,
|
||||
@@ -103,6 +172,19 @@ impl AnthropicClient {
|
||||
.await
|
||||
.map_err(ApiError::from)
|
||||
}
|
||||
|
||||
fn backoff_for_attempt(&self, attempt: u32) -> Result<Duration, ApiError> {
|
||||
let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else {
|
||||
return Err(ApiError::BackoffOverflow {
|
||||
attempt,
|
||||
base_delay: self.initial_backoff,
|
||||
});
|
||||
};
|
||||
Ok(self
|
||||
.initial_backoff
|
||||
.checked_mul(multiplier)
|
||||
.map_or(self.max_backoff, |delay| delay.min(self.max_backoff)))
|
||||
}
|
||||
}
|
||||
|
||||
fn read_api_key(
|
||||
@@ -116,15 +198,29 @@ fn read_api_key(
|
||||
}
|
||||
}
|
||||
|
||||
fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option<String> {
|
||||
headers
|
||||
.get(REQUEST_ID_HEADER)
|
||||
.or_else(|| headers.get(ALT_REQUEST_ID_HEADER))
|
||||
.and_then(|value| value.to_str().ok())
|
||||
.map(ToOwned::to_owned)
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct MessageStream {
|
||||
request_id: Option<String>,
|
||||
response: reqwest::Response,
|
||||
parser: SseParser,
|
||||
pending: std::collections::VecDeque<StreamEvent>,
|
||||
pending: VecDeque<StreamEvent>,
|
||||
done: bool,
|
||||
}
|
||||
|
||||
impl MessageStream {
|
||||
#[must_use]
|
||||
pub fn request_id(&self) -> Option<&str> {
|
||||
self.request_id.as_deref()
|
||||
}
|
||||
|
||||
pub async fn next_event(&mut self) -> Result<Option<StreamEvent>, ApiError> {
|
||||
loop {
|
||||
if let Some(event) = self.pending.pop_front() {
|
||||
@@ -159,14 +255,46 @@ async fn expect_success(response: reqwest::Response) -> Result<reqwest::Response
|
||||
}
|
||||
|
||||
let body = response.text().await.unwrap_or_else(|_| String::new());
|
||||
Err(ApiError::UnexpectedStatus { status, body })
|
||||
let parsed_error = serde_json::from_str::<AnthropicErrorEnvelope>(&body).ok();
|
||||
let retryable = is_retryable_status(status);
|
||||
|
||||
Err(ApiError::Api {
|
||||
status,
|
||||
error_type: parsed_error
|
||||
.as_ref()
|
||||
.map(|error| error.error.error_type.clone()),
|
||||
message: parsed_error
|
||||
.as_ref()
|
||||
.map(|error| error.error.message.clone()),
|
||||
body,
|
||||
retryable,
|
||||
})
|
||||
}
|
||||
|
||||
const fn is_retryable_status(status: reqwest::StatusCode) -> bool {
|
||||
matches!(status.as_u16(), 408 | 409 | 429 | 500 | 502 | 503 | 504)
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct AnthropicErrorEnvelope {
|
||||
error: AnthropicErrorBody,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct AnthropicErrorBody {
|
||||
#[serde(rename = "type")]
|
||||
error_type: String,
|
||||
message: String,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::env::VarError;
|
||||
|
||||
use crate::types::MessageRequest;
|
||||
use super::{ALT_REQUEST_ID_HEADER, REQUEST_ID_HEADER};
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::types::{ContentBlockDelta, MessageRequest};
|
||||
|
||||
#[test]
|
||||
fn read_api_key_requires_presence() {
|
||||
@@ -194,9 +322,76 @@ mod tests {
|
||||
max_tokens: 64,
|
||||
messages: vec![],
|
||||
system: None,
|
||||
tools: None,
|
||||
tool_choice: None,
|
||||
stream: false,
|
||||
};
|
||||
|
||||
assert!(request.with_streaming().stream);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn backoff_doubles_until_maximum() {
|
||||
let client = super::AnthropicClient::new("test-key").with_retry_policy(
|
||||
3,
|
||||
Duration::from_millis(10),
|
||||
Duration::from_millis(25),
|
||||
);
|
||||
assert_eq!(
|
||||
client.backoff_for_attempt(1).expect("attempt 1"),
|
||||
Duration::from_millis(10)
|
||||
);
|
||||
assert_eq!(
|
||||
client.backoff_for_attempt(2).expect("attempt 2"),
|
||||
Duration::from_millis(20)
|
||||
);
|
||||
assert_eq!(
|
||||
client.backoff_for_attempt(3).expect("attempt 3"),
|
||||
Duration::from_millis(25)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn retryable_statuses_are_detected() {
|
||||
assert!(super::is_retryable_status(
|
||||
reqwest::StatusCode::TOO_MANY_REQUESTS
|
||||
));
|
||||
assert!(super::is_retryable_status(
|
||||
reqwest::StatusCode::INTERNAL_SERVER_ERROR
|
||||
));
|
||||
assert!(!super::is_retryable_status(
|
||||
reqwest::StatusCode::UNAUTHORIZED
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_delta_variant_round_trips() {
|
||||
let delta = ContentBlockDelta::InputJsonDelta {
|
||||
partial_json: "{\"city\":\"Paris\"}".to_string(),
|
||||
};
|
||||
let encoded = serde_json::to_string(&delta).expect("delta should serialize");
|
||||
let decoded: ContentBlockDelta =
|
||||
serde_json::from_str(&encoded).expect("delta should deserialize");
|
||||
assert_eq!(decoded, delta);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_id_uses_primary_or_fallback_header() {
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
headers.insert(REQUEST_ID_HEADER, "req_primary".parse().expect("header"));
|
||||
assert_eq!(
|
||||
super::request_id_from_headers(&headers).as_deref(),
|
||||
Some("req_primary")
|
||||
);
|
||||
|
||||
headers.clear();
|
||||
headers.insert(
|
||||
ALT_REQUEST_ID_HEADER,
|
||||
"req_fallback".parse().expect("header"),
|
||||
);
|
||||
assert_eq!(
|
||||
super::request_id_from_headers(&headers).as_deref(),
|
||||
Some("req_fallback")
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user