feat: merge 2nd round from all rcc/* sessions

- api: tool_use parsing, message_delta, request_id tracking, retry logic
- tools: extended tool suite (WebSearch, WebFetch, Agent, etc.)
- cli: live streamed conversations, session restore, compact commands
- runtime: config loading, system prompt builder, token usage, compaction
This commit is contained in:
Yeachan-Heo
2026-03-31 17:43:25 +00:00
parent 44e4758078
commit 450556559a
23 changed files with 2388 additions and 3560 deletions

View File

@@ -1,9 +1,19 @@
use std::collections::VecDeque;
use std::time::Duration;
use serde::Deserialize;
use crate::error::ApiError;
use crate::sse::SseParser;
use crate::types::{MessageRequest, MessageResponse, StreamEvent};
const DEFAULT_BASE_URL: &str = "https://api.anthropic.com";
const ANTHROPIC_VERSION: &str = "2023-06-01";
const REQUEST_ID_HEADER: &str = "request-id";
const ALT_REQUEST_ID_HEADER: &str = "x-request-id";
const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(200);
const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(2);
const DEFAULT_MAX_RETRIES: u32 = 2;
#[derive(Debug, Clone)]
pub struct AnthropicClient {
@@ -11,6 +21,9 @@ pub struct AnthropicClient {
api_key: String,
auth_token: Option<String>,
base_url: String,
max_retries: u32,
initial_backoff: Duration,
max_backoff: Duration,
}
impl AnthropicClient {
@@ -21,6 +34,9 @@ impl AnthropicClient {
api_key: api_key.into(),
auth_token: None,
base_url: DEFAULT_BASE_URL.to_string(),
max_retries: DEFAULT_MAX_RETRIES,
initial_backoff: DEFAULT_INITIAL_BACKOFF,
max_backoff: DEFAULT_MAX_BACKOFF,
}
}
@@ -47,6 +63,19 @@ impl AnthropicClient {
self
}
#[must_use]
pub fn with_retry_policy(
mut self,
max_retries: u32,
initial_backoff: Duration,
max_backoff: Duration,
) -> Self {
self.max_retries = max_retries;
self.initial_backoff = initial_backoff;
self.max_backoff = max_backoff;
self
}
pub async fn send_message(
&self,
request: &MessageRequest,
@@ -55,12 +84,16 @@ impl AnthropicClient {
stream: false,
..request.clone()
};
let response = self.send_raw_request(&request).await?;
let response = expect_success(response).await?;
response
let response = self.send_with_retry(&request).await?;
let request_id = request_id_from_headers(response.headers());
let mut response = response
.json::<MessageResponse>()
.await
.map_err(ApiError::from)
.map_err(ApiError::from)?;
if response.request_id.is_none() {
response.request_id = request_id;
}
Ok(response)
}
pub async fn stream_message(
@@ -68,17 +101,53 @@ impl AnthropicClient {
request: &MessageRequest,
) -> Result<MessageStream, ApiError> {
let response = self
.send_raw_request(&request.clone().with_streaming())
.send_with_retry(&request.clone().with_streaming())
.await?;
let response = expect_success(response).await?;
Ok(MessageStream {
request_id: request_id_from_headers(response.headers()),
response,
parser: SseParser::new(),
pending: std::collections::VecDeque::new(),
pending: VecDeque::new(),
done: false,
})
}
async fn send_with_retry(
&self,
request: &MessageRequest,
) -> Result<reqwest::Response, ApiError> {
let mut attempts = 0;
let mut last_error: Option<ApiError>;
loop {
attempts += 1;
match self.send_raw_request(request).await {
Ok(response) => match expect_success(response).await {
Ok(response) => return Ok(response),
Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => {
last_error = Some(error);
}
Err(error) => return Err(error),
},
Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => {
last_error = Some(error);
}
Err(error) => return Err(error),
}
if attempts > self.max_retries {
break;
}
tokio::time::sleep(self.backoff_for_attempt(attempts)?).await;
}
Err(ApiError::RetriesExhausted {
attempts,
last_error: Box::new(last_error.expect("retry loop must capture an error")),
})
}
async fn send_raw_request(
&self,
request: &MessageRequest,
@@ -103,6 +172,19 @@ impl AnthropicClient {
.await
.map_err(ApiError::from)
}
fn backoff_for_attempt(&self, attempt: u32) -> Result<Duration, ApiError> {
let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else {
return Err(ApiError::BackoffOverflow {
attempt,
base_delay: self.initial_backoff,
});
};
Ok(self
.initial_backoff
.checked_mul(multiplier)
.map_or(self.max_backoff, |delay| delay.min(self.max_backoff)))
}
}
fn read_api_key(
@@ -116,15 +198,29 @@ fn read_api_key(
}
}
fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option<String> {
headers
.get(REQUEST_ID_HEADER)
.or_else(|| headers.get(ALT_REQUEST_ID_HEADER))
.and_then(|value| value.to_str().ok())
.map(ToOwned::to_owned)
}
#[derive(Debug)]
pub struct MessageStream {
request_id: Option<String>,
response: reqwest::Response,
parser: SseParser,
pending: std::collections::VecDeque<StreamEvent>,
pending: VecDeque<StreamEvent>,
done: bool,
}
impl MessageStream {
#[must_use]
pub fn request_id(&self) -> Option<&str> {
self.request_id.as_deref()
}
pub async fn next_event(&mut self) -> Result<Option<StreamEvent>, ApiError> {
loop {
if let Some(event) = self.pending.pop_front() {
@@ -159,14 +255,46 @@ async fn expect_success(response: reqwest::Response) -> Result<reqwest::Response
}
let body = response.text().await.unwrap_or_else(|_| String::new());
Err(ApiError::UnexpectedStatus { status, body })
let parsed_error = serde_json::from_str::<AnthropicErrorEnvelope>(&body).ok();
let retryable = is_retryable_status(status);
Err(ApiError::Api {
status,
error_type: parsed_error
.as_ref()
.map(|error| error.error.error_type.clone()),
message: parsed_error
.as_ref()
.map(|error| error.error.message.clone()),
body,
retryable,
})
}
const fn is_retryable_status(status: reqwest::StatusCode) -> bool {
matches!(status.as_u16(), 408 | 409 | 429 | 500 | 502 | 503 | 504)
}
#[derive(Debug, Deserialize)]
struct AnthropicErrorEnvelope {
error: AnthropicErrorBody,
}
#[derive(Debug, Deserialize)]
struct AnthropicErrorBody {
#[serde(rename = "type")]
error_type: String,
message: String,
}
#[cfg(test)]
mod tests {
use std::env::VarError;
use crate::types::MessageRequest;
use super::{ALT_REQUEST_ID_HEADER, REQUEST_ID_HEADER};
use std::time::Duration;
use crate::types::{ContentBlockDelta, MessageRequest};
#[test]
fn read_api_key_requires_presence() {
@@ -194,9 +322,76 @@ mod tests {
max_tokens: 64,
messages: vec![],
system: None,
tools: None,
tool_choice: None,
stream: false,
};
assert!(request.with_streaming().stream);
}
#[test]
fn backoff_doubles_until_maximum() {
let client = super::AnthropicClient::new("test-key").with_retry_policy(
3,
Duration::from_millis(10),
Duration::from_millis(25),
);
assert_eq!(
client.backoff_for_attempt(1).expect("attempt 1"),
Duration::from_millis(10)
);
assert_eq!(
client.backoff_for_attempt(2).expect("attempt 2"),
Duration::from_millis(20)
);
assert_eq!(
client.backoff_for_attempt(3).expect("attempt 3"),
Duration::from_millis(25)
);
}
#[test]
fn retryable_statuses_are_detected() {
assert!(super::is_retryable_status(
reqwest::StatusCode::TOO_MANY_REQUESTS
));
assert!(super::is_retryable_status(
reqwest::StatusCode::INTERNAL_SERVER_ERROR
));
assert!(!super::is_retryable_status(
reqwest::StatusCode::UNAUTHORIZED
));
}
#[test]
fn tool_delta_variant_round_trips() {
let delta = ContentBlockDelta::InputJsonDelta {
partial_json: "{\"city\":\"Paris\"}".to_string(),
};
let encoded = serde_json::to_string(&delta).expect("delta should serialize");
let decoded: ContentBlockDelta =
serde_json::from_str(&encoded).expect("delta should deserialize");
assert_eq!(decoded, delta);
}
#[test]
fn request_id_uses_primary_or_fallback_header() {
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(REQUEST_ID_HEADER, "req_primary".parse().expect("header"));
assert_eq!(
super::request_id_from_headers(&headers).as_deref(),
Some("req_primary")
);
headers.clear();
headers.insert(
ALT_REQUEST_ID_HEADER,
"req_fallback".parse().expect("header"),
);
assert_eq!(
super::request_id_from_headers(&headers).as_deref(),
Some("req_fallback")
);
}
}

View File

@@ -1,5 +1,6 @@
use std::env::VarError;
use std::fmt::{Display, Formatter};
use std::time::Duration;
#[derive(Debug)]
pub enum ApiError {
@@ -8,11 +9,39 @@ pub enum ApiError {
Http(reqwest::Error),
Io(std::io::Error),
Json(serde_json::Error),
UnexpectedStatus {
Api {
status: reqwest::StatusCode,
error_type: Option<String>,
message: Option<String>,
body: String,
retryable: bool,
},
RetriesExhausted {
attempts: u32,
last_error: Box<ApiError>,
},
InvalidSseFrame(&'static str),
BackoffOverflow {
attempt: u32,
base_delay: Duration,
},
}
impl ApiError {
#[must_use]
pub fn is_retryable(&self) -> bool {
match self {
Self::Http(error) => error.is_connect() || error.is_timeout() || error.is_request(),
Self::Api { retryable, .. } => *retryable,
Self::RetriesExhausted { last_error, .. } => last_error.is_retryable(),
Self::MissingApiKey
| Self::InvalidApiKeyEnv(_)
| Self::Io(_)
| Self::Json(_)
| Self::InvalidSseFrame(_)
| Self::BackoffOverflow { .. } => false,
}
}
}
impl Display for ApiError {
@@ -30,10 +59,36 @@ impl Display for ApiError {
Self::Http(error) => write!(f, "http error: {error}"),
Self::Io(error) => write!(f, "io error: {error}"),
Self::Json(error) => write!(f, "json error: {error}"),
Self::UnexpectedStatus { status, body } => {
write!(f, "anthropic api returned {status}: {body}")
}
Self::Api {
status,
error_type,
message,
body,
..
} => match (error_type, message) {
(Some(error_type), Some(message)) => {
write!(
f,
"anthropic api returned {status} ({error_type}): {message}"
)
}
_ => write!(f, "anthropic api returned {status}: {body}"),
},
Self::RetriesExhausted {
attempts,
last_error,
} => write!(
f,
"anthropic api failed after {attempts} attempts: {last_error}"
),
Self::InvalidSseFrame(message) => write!(f, "invalid sse frame: {message}"),
Self::BackoffOverflow {
attempt,
base_delay,
} => write!(
f,
"retry backoff overflowed on attempt {attempt} with base delay {base_delay:?}"
),
}
}
}

View File

@@ -8,6 +8,7 @@ pub use error::ApiError;
pub use sse::{parse_frame, SseParser};
pub use types::{
ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent,
InputContentBlock, InputMessage, MessageRequest, MessageResponse, MessageStartEvent,
MessageStopEvent, OutputContentBlock, StreamEvent, Usage,
InputContentBlock, InputMessage, MessageDelta, MessageDeltaEvent, MessageRequest,
MessageResponse, MessageStartEvent, MessageStopEvent, OutputContentBlock, StreamEvent,
ToolChoice, ToolDefinition, ToolResultContentBlock, Usage,
};

View File

@@ -103,7 +103,7 @@ pub fn parse_frame(frame: &str) -> Result<Option<StreamEvent>, ApiError> {
#[cfg(test)]
mod tests {
use super::{parse_frame, SseParser};
use crate::types::{ContentBlockDelta, OutputContentBlock, StreamEvent};
use crate::types::{ContentBlockDelta, MessageDelta, OutputContentBlock, StreamEvent, Usage};
#[test]
fn parses_single_frame() {
@@ -158,6 +158,8 @@ mod tests {
": keepalive\n",
"event: ping\n",
"data: {\"type\":\"ping\"}\n\n",
"event: message_delta\n",
"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\",\"stop_sequence\":null},\"usage\":{\"input_tokens\":1,\"output_tokens\":2}}\n\n",
"event: message_stop\n",
"data: {\"type\":\"message_stop\"}\n\n",
"data: [DONE]\n\n"
@@ -168,7 +170,19 @@ mod tests {
.expect("parser should succeed");
assert_eq!(
events,
vec![StreamEvent::MessageStop(crate::types::MessageStopEvent {})]
vec![
StreamEvent::MessageDelta(crate::types::MessageDeltaEvent {
delta: MessageDelta {
stop_reason: Some("tool_use".to_string()),
stop_sequence: None,
},
usage: Usage {
input_tokens: 1,
output_tokens: 2,
},
}),
StreamEvent::MessageStop(crate::types::MessageStopEvent {}),
]
);
}

View File

@@ -1,12 +1,17 @@
use serde::{Deserialize, Serialize};
use serde_json::Value;
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MessageRequest {
pub model: String,
pub max_tokens: u32,
pub messages: Vec<InputMessage>,
#[serde(skip_serializing_if = "Option::is_none")]
pub system: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<ToolDefinition>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<ToolChoice>,
#[serde(default, skip_serializing_if = "std::ops::Not::not")]
pub stream: bool,
}
@@ -19,7 +24,7 @@ impl MessageRequest {
}
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct InputMessage {
pub role: String,
pub content: Vec<InputContentBlock>,
@@ -33,15 +38,64 @@ impl InputMessage {
content: vec![InputContentBlock::Text { text: text.into() }],
}
}
#[must_use]
pub fn user_tool_result(
tool_use_id: impl Into<String>,
content: impl Into<String>,
is_error: bool,
) -> Self {
Self {
role: "user".to_string(),
content: vec![InputContentBlock::ToolResult {
tool_use_id: tool_use_id.into(),
content: vec![ToolResultContentBlock::Text {
text: content.into(),
}],
is_error,
}],
}
}
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum InputContentBlock {
Text {
text: String,
},
ToolResult {
tool_use_id: String,
content: Vec<ToolResultContentBlock>,
#[serde(default, skip_serializing_if = "std::ops::Not::not")]
is_error: bool,
},
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolResultContentBlock {
Text { text: String },
Json { value: Value },
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolDefinition {
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
pub input_schema: Value,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum InputContentBlock {
Text { text: String },
pub enum ToolChoice {
Auto,
Any,
Tool { name: String },
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MessageResponse {
pub id: String,
#[serde(rename = "type")]
@@ -54,12 +108,28 @@ pub struct MessageResponse {
#[serde(default)]
pub stop_sequence: Option<String>,
pub usage: Usage,
#[serde(default)]
pub request_id: Option<String>,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
impl MessageResponse {
#[must_use]
pub fn total_tokens(&self) -> u32 {
self.usage.total_tokens()
}
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum OutputContentBlock {
Text { text: String },
Text {
text: String,
},
ToolUse {
id: String,
name: String,
input: Value,
},
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
@@ -68,18 +138,39 @@ pub struct Usage {
pub output_tokens: u32,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
impl Usage {
#[must_use]
pub const fn total_tokens(&self) -> u32 {
self.input_tokens + self.output_tokens
}
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MessageStartEvent {
pub message: MessageResponse,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MessageDeltaEvent {
pub delta: MessageDelta,
pub usage: Usage,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct MessageDelta {
#[serde(default)]
pub stop_reason: Option<String>,
#[serde(default)]
pub stop_sequence: Option<String>,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ContentBlockStartEvent {
pub index: u32,
pub content_block: OutputContentBlock,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ContentBlockDeltaEvent {
pub index: u32,
pub delta: ContentBlockDelta,
@@ -89,6 +180,7 @@ pub struct ContentBlockDeltaEvent {
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlockDelta {
TextDelta { text: String },
InputJsonDelta { partial_json: String },
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
@@ -99,10 +191,11 @@ pub struct ContentBlockStopEvent {
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct MessageStopEvent {}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum StreamEvent {
MessageStart(MessageStartEvent),
MessageDelta(MessageDeltaEvent),
ContentBlockStart(ContentBlockStartEvent),
ContentBlockDelta(ContentBlockDeltaEvent),
ContentBlockStop(ContentBlockStopEvent),

View File

@@ -1,7 +1,13 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use api::{AnthropicClient, InputMessage, MessageRequest, OutputContentBlock, StreamEvent};
use api::{
AnthropicClient, ApiError, ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent,
InputContentBlock, InputMessage, MessageDeltaEvent, MessageRequest, OutputContentBlock,
StreamEvent, ToolChoice, ToolDefinition,
};
use serde_json::json;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
use tokio::sync::Mutex;
@@ -18,10 +24,15 @@ async fn send_message_posts_json_and_parses_response() {
"\"model\":\"claude-3-7-sonnet-latest\",",
"\"stop_reason\":\"end_turn\",",
"\"stop_sequence\":null,",
"\"usage\":{\"input_tokens\":12,\"output_tokens\":4}",
"\"usage\":{\"input_tokens\":12,\"output_tokens\":4},",
"\"request_id\":\"req_body_123\"",
"}"
);
let server = spawn_server(state.clone(), http_response("application/json", body)).await;
let server = spawn_server(
state.clone(),
vec![http_response("200 OK", "application/json", body)],
)
.await;
let client = AnthropicClient::new("test-key")
.with_auth_token(Some("proxy-token".to_string()))
@@ -32,6 +43,8 @@ async fn send_message_posts_json_and_parses_response() {
.expect("request should succeed");
assert_eq!(response.id, "msg_test");
assert_eq!(response.total_tokens(), 16);
assert_eq!(response.request_id.as_deref(), Some("req_body_123"));
assert_eq!(
response.content,
vec![OutputContentBlock::Text {
@@ -51,39 +64,45 @@ async fn send_message_posts_json_and_parses_response() {
request.headers.get("authorization").map(String::as_str),
Some("Bearer proxy-token")
);
assert_eq!(
request.headers.get("anthropic-version").map(String::as_str),
Some("2023-06-01")
);
let body: serde_json::Value =
serde_json::from_str(&request.body).expect("request body should be json");
assert_eq!(
body.get("model").and_then(serde_json::Value::as_str),
Some("claude-3-7-sonnet-latest")
);
assert!(
body.get("stream").is_none(),
"non-stream request should omit stream=false"
);
assert!(body.get("stream").is_none());
assert_eq!(body["tools"][0]["name"], json!("get_weather"));
assert_eq!(body["tool_choice"]["type"], json!("auto"));
}
#[tokio::test]
async fn stream_message_parses_sse_events() {
async fn stream_message_parses_sse_events_with_tool_use() {
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
let sse = concat!(
"event: message_start\n",
"data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_stream\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":8,\"output_tokens\":0}}}\n\n",
"event: content_block_start\n",
"data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n",
"data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"tool_use\",\"id\":\"toolu_123\",\"name\":\"get_weather\",\"input\":{}}}\n\n",
"event: content_block_delta\n",
"data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"Hello\"}}\n\n",
"data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"city\\\":\\\"Paris\\\"}\"}}\n\n",
"event: content_block_stop\n",
"data: {\"type\":\"content_block_stop\",\"index\":0}\n\n",
"event: message_delta\n",
"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\",\"stop_sequence\":null},\"usage\":{\"input_tokens\":8,\"output_tokens\":1}}\n\n",
"event: message_stop\n",
"data: {\"type\":\"message_stop\"}\n\n",
"data: [DONE]\n\n"
);
let server = spawn_server(state.clone(), http_response("text/event-stream", sse)).await;
let server = spawn_server(
state.clone(),
vec![http_response_with_headers(
"200 OK",
"text/event-stream",
sse,
&[("request-id", "req_stream_456")],
)],
)
.await;
let client = AnthropicClient::new("test-key")
.with_auth_token(Some("proxy-token".to_string()))
@@ -93,6 +112,8 @@ async fn stream_message_parses_sse_events() {
.await
.expect("stream should start");
assert_eq!(stream.request_id(), Some("req_stream_456"));
let mut events = Vec::new();
while let Some(event) = stream
.next_event()
@@ -102,18 +123,126 @@ async fn stream_message_parses_sse_events() {
events.push(event);
}
assert_eq!(events.len(), 5);
assert_eq!(events.len(), 6);
assert!(matches!(events[0], StreamEvent::MessageStart(_)));
assert!(matches!(events[1], StreamEvent::ContentBlockStart(_)));
assert!(matches!(events[2], StreamEvent::ContentBlockDelta(_)));
assert!(matches!(
events[1],
StreamEvent::ContentBlockStart(ContentBlockStartEvent {
content_block: OutputContentBlock::ToolUse { .. },
..
})
));
assert!(matches!(
events[2],
StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
delta: ContentBlockDelta::InputJsonDelta { .. },
..
})
));
assert!(matches!(events[3], StreamEvent::ContentBlockStop(_)));
assert!(matches!(events[4], StreamEvent::MessageStop(_)));
assert!(matches!(
events[4],
StreamEvent::MessageDelta(MessageDeltaEvent { .. })
));
assert!(matches!(events[5], StreamEvent::MessageStop(_)));
match &events[1] {
StreamEvent::ContentBlockStart(ContentBlockStartEvent {
content_block: OutputContentBlock::ToolUse { name, input, .. },
..
}) => {
assert_eq!(name, "get_weather");
assert_eq!(input, &json!({}));
}
other => panic!("expected tool_use block, got {other:?}"),
}
let captured = state.lock().await;
let request = captured.first().expect("server should capture request");
assert!(request.body.contains("\"stream\":true"));
}
#[tokio::test]
async fn retries_retryable_failures_before_succeeding() {
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
let server = spawn_server(
state.clone(),
vec![
http_response(
"429 Too Many Requests",
"application/json",
"{\"type\":\"error\",\"error\":{\"type\":\"rate_limit_error\",\"message\":\"slow down\"}}",
),
http_response(
"200 OK",
"application/json",
"{\"id\":\"msg_retry\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Recovered\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"output_tokens\":2}}",
),
],
)
.await;
let client = AnthropicClient::new("test-key")
.with_base_url(server.base_url())
.with_retry_policy(2, Duration::from_millis(1), Duration::from_millis(2));
let response = client
.send_message(&sample_request(false))
.await
.expect("retry should eventually succeed");
assert_eq!(response.total_tokens(), 5);
assert_eq!(state.lock().await.len(), 2);
}
#[tokio::test]
async fn surfaces_retry_exhaustion_for_persistent_retryable_errors() {
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
let server = spawn_server(
state.clone(),
vec![
http_response(
"503 Service Unavailable",
"application/json",
"{\"type\":\"error\",\"error\":{\"type\":\"overloaded_error\",\"message\":\"busy\"}}",
),
http_response(
"503 Service Unavailable",
"application/json",
"{\"type\":\"error\",\"error\":{\"type\":\"overloaded_error\",\"message\":\"still busy\"}}",
),
],
)
.await;
let client = AnthropicClient::new("test-key")
.with_base_url(server.base_url())
.with_retry_policy(1, Duration::from_millis(1), Duration::from_millis(2));
let error = client
.send_message(&sample_request(false))
.await
.expect_err("persistent 503 should fail");
match error {
ApiError::RetriesExhausted {
attempts,
last_error,
} => {
assert_eq!(attempts, 2);
assert!(matches!(
*last_error,
ApiError::Api {
status: reqwest::StatusCode::SERVICE_UNAVAILABLE,
retryable: true,
..
}
));
}
other => panic!("expected retries exhausted, got {other:?}"),
}
}
#[tokio::test]
#[ignore = "requires ANTHROPIC_API_KEY and network access"]
async fn live_stream_smoke_test() {
@@ -127,51 +256,18 @@ async fn live_stream_smoke_test() {
"Reply with exactly: hello from rust",
)],
system: None,
tools: None,
tool_choice: None,
stream: false,
})
.await
.expect("live stream should start");
let mut saw_start = false;
let mut saw_follow_up = false;
let mut event_kinds = Vec::new();
while let Some(event) = stream
while let Some(_event) = stream
.next_event()
.await
.expect("live stream should yield events")
{
match event {
StreamEvent::MessageStart(_) => {
saw_start = true;
event_kinds.push("message_start");
}
StreamEvent::ContentBlockStart(_) => {
saw_follow_up = true;
event_kinds.push("content_block_start");
}
StreamEvent::ContentBlockDelta(_) => {
saw_follow_up = true;
event_kinds.push("content_block_delta");
}
StreamEvent::ContentBlockStop(_) => {
saw_follow_up = true;
event_kinds.push("content_block_stop");
}
StreamEvent::MessageStop(_) => {
saw_follow_up = true;
event_kinds.push("message_stop");
}
}
}
assert!(
saw_start,
"expected a message_start event; got {event_kinds:?}"
);
assert!(
saw_follow_up,
"expected at least one follow-up stream event; got {event_kinds:?}"
);
{}
}
#[derive(Debug, Clone, PartialEq, Eq)]
@@ -199,7 +295,10 @@ impl Drop for TestServer {
}
}
async fn spawn_server(state: Arc<Mutex<Vec<CapturedRequest>>>, response: String) -> TestServer {
async fn spawn_server(
state: Arc<Mutex<Vec<CapturedRequest>>>,
responses: Vec<String>,
) -> TestServer {
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("listener should bind");
@@ -207,72 +306,75 @@ async fn spawn_server(state: Arc<Mutex<Vec<CapturedRequest>>>, response: String)
.local_addr()
.expect("listener should have local addr");
let join_handle = tokio::spawn(async move {
let (mut socket, _) = listener.accept().await.expect("server should accept");
let mut buffer = Vec::new();
let mut header_end = None;
for response in responses {
let (mut socket, _) = listener.accept().await.expect("server should accept");
let mut buffer = Vec::new();
let mut header_end = None;
loop {
let mut chunk = [0_u8; 1024];
let read = socket
.read(&mut chunk)
loop {
let mut chunk = [0_u8; 1024];
let read = socket
.read(&mut chunk)
.await
.expect("request read should succeed");
if read == 0 {
break;
}
buffer.extend_from_slice(&chunk[..read]);
if let Some(position) = find_header_end(&buffer) {
header_end = Some(position);
break;
}
}
let header_end = header_end.expect("request should include headers");
let (header_bytes, remaining) = buffer.split_at(header_end);
let header_text =
String::from_utf8(header_bytes.to_vec()).expect("headers should be utf8");
let mut lines = header_text.split("\r\n");
let request_line = lines.next().expect("request line should exist");
let mut parts = request_line.split_whitespace();
let method = parts.next().expect("method should exist").to_string();
let path = parts.next().expect("path should exist").to_string();
let mut headers = HashMap::new();
let mut content_length = 0_usize;
for line in lines {
if line.is_empty() {
continue;
}
let (name, value) = line.split_once(':').expect("header should have colon");
let value = value.trim().to_string();
if name.eq_ignore_ascii_case("content-length") {
content_length = value.parse().expect("content length should parse");
}
headers.insert(name.to_ascii_lowercase(), value);
}
let mut body = remaining[4..].to_vec();
while body.len() < content_length {
let mut chunk = vec![0_u8; content_length - body.len()];
let read = socket
.read(&mut chunk)
.await
.expect("body read should succeed");
if read == 0 {
break;
}
body.extend_from_slice(&chunk[..read]);
}
state.lock().await.push(CapturedRequest {
method,
path,
headers,
body: String::from_utf8(body).expect("body should be utf8"),
});
socket
.write_all(response.as_bytes())
.await
.expect("request read should succeed");
if read == 0 {
break;
}
buffer.extend_from_slice(&chunk[..read]);
if let Some(position) = find_header_end(&buffer) {
header_end = Some(position);
break;
}
.expect("response write should succeed");
}
let header_end = header_end.expect("request should include headers");
let (header_bytes, remaining) = buffer.split_at(header_end);
let header_text = String::from_utf8(header_bytes.to_vec()).expect("headers should be utf8");
let mut lines = header_text.split("\r\n");
let request_line = lines.next().expect("request line should exist");
let mut parts = request_line.split_whitespace();
let method = parts.next().expect("method should exist").to_string();
let path = parts.next().expect("path should exist").to_string();
let mut headers = HashMap::new();
let mut content_length = 0_usize;
for line in lines {
if line.is_empty() {
continue;
}
let (name, value) = line.split_once(':').expect("header should have colon");
let value = value.trim().to_string();
if name.eq_ignore_ascii_case("content-length") {
content_length = value.parse().expect("content length should parse");
}
headers.insert(name.to_ascii_lowercase(), value);
}
let mut body = remaining[4..].to_vec();
while body.len() < content_length {
let mut chunk = vec![0_u8; content_length - body.len()];
let read = socket
.read(&mut chunk)
.await
.expect("body read should succeed");
if read == 0 {
break;
}
body.extend_from_slice(&chunk[..read]);
}
state.lock().await.push(CapturedRequest {
method,
path,
headers,
body: String::from_utf8(body).expect("body should be utf8"),
});
socket
.write_all(response.as_bytes())
.await
.expect("response write should succeed");
});
TestServer {
@@ -285,9 +387,23 @@ fn find_header_end(bytes: &[u8]) -> Option<usize> {
bytes.windows(4).position(|window| window == b"\r\n\r\n")
}
fn http_response(content_type: &str, body: &str) -> String {
fn http_response(status: &str, content_type: &str, body: &str) -> String {
http_response_with_headers(status, content_type, body, &[])
}
fn http_response_with_headers(
status: &str,
content_type: &str,
body: &str,
headers: &[(&str, &str)],
) -> String {
let mut extra_headers = String::new();
for (name, value) in headers {
use std::fmt::Write as _;
write!(&mut extra_headers, "{name}: {value}\r\n").expect("header write should succeed");
}
format!(
"HTTP/1.1 200 OK\r\ncontent-type: {content_type}\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}",
"HTTP/1.1 {status}\r\ncontent-type: {content_type}\r\n{extra_headers}content-length: {}\r\nconnection: close\r\n\r\n{body}",
body.len()
)
}
@@ -296,8 +412,32 @@ fn sample_request(stream: bool) -> MessageRequest {
MessageRequest {
model: "claude-3-7-sonnet-latest".to_string(),
max_tokens: 64,
messages: vec![InputMessage::user_text("Say hello")],
system: None,
messages: vec![InputMessage {
role: "user".to_string(),
content: vec![
InputContentBlock::Text {
text: "Say hello".to_string(),
},
InputContentBlock::ToolResult {
tool_use_id: "toolu_prev".to_string(),
content: vec![api::ToolResultContentBlock::Json {
value: json!({"forecast": "sunny"}),
}],
is_error: false,
},
],
}],
system: Some("Use tools when needed".to_string()),
tools: Some(vec![ToolDefinition {
name: "get_weather".to_string(),
description: Some("Fetches the weather".to_string()),
input_schema: json!({
"type": "object",
"properties": {"city": {"type": "string"}},
"required": ["city"]
}),
}]),
tool_choice: Some(ToolChoice::Auto),
stream,
}
}