From 0cf2204d43700cc53edf4108b735ea83c8e76f8c Mon Sep 17 00:00:00 2001 From: Yeachan-Heo Date: Wed, 1 Apr 2026 04:30:24 +0000 Subject: [PATCH] wip: cache-tracking progress --- rust/crates/api/src/client.rs | 91 ++- rust/crates/api/src/lib.rs | 5 + rust/crates/api/src/prompt_cache.rs | 679 ++++++++++++++++++++ rust/crates/api/tests/client_integration.rs | 146 ++++- 4 files changed, 916 insertions(+), 5 deletions(-) create mode 100644 rust/crates/api/src/prompt_cache.rs diff --git a/rust/crates/api/src/client.rs b/rust/crates/api/src/client.rs index 7ef7e83..4d264b5 100644 --- a/rust/crates/api/src/client.rs +++ b/rust/crates/api/src/client.rs @@ -8,8 +8,9 @@ use runtime::{ use serde::Deserialize; use crate::error::ApiError; +use crate::prompt_cache::{PromptCache, PromptCacheStats}; use crate::sse::SseParser; -use crate::types::{MessageRequest, MessageResponse, StreamEvent}; +use crate::types::{MessageRequest, MessageResponse, StreamEvent, Usage}; const DEFAULT_BASE_URL: &str = "https://api.anthropic.com"; const ANTHROPIC_VERSION: &str = "2023-06-01"; @@ -108,6 +109,7 @@ pub struct AnthropicClient { max_retries: u32, initial_backoff: Duration, max_backoff: Duration, + prompt_cache: Option, } impl AnthropicClient { @@ -120,6 +122,7 @@ impl AnthropicClient { max_retries: DEFAULT_MAX_RETRIES, initial_backoff: DEFAULT_INITIAL_BACKOFF, max_backoff: DEFAULT_MAX_BACKOFF, + prompt_cache: None, } } @@ -132,6 +135,7 @@ impl AnthropicClient { max_retries: DEFAULT_MAX_RETRIES, initial_backoff: DEFAULT_INITIAL_BACKOFF, max_backoff: DEFAULT_MAX_BACKOFF, + prompt_cache: None, } } @@ -189,6 +193,22 @@ impl AnthropicClient { self } + #[must_use] + pub fn with_prompt_cache(mut self, prompt_cache: PromptCache) -> Self { + self.prompt_cache = Some(prompt_cache); + self + } + + #[must_use] + pub fn prompt_cache(&self) -> Option<&PromptCache> { + self.prompt_cache.as_ref() + } + + #[must_use] + pub fn prompt_cache_stats(&self) -> Option { + self.prompt_cache.as_ref().map(PromptCache::stats) + } + #[must_use] pub fn auth_source(&self) -> &AuthSource { &self.auth @@ -202,6 +222,11 @@ impl AnthropicClient { stream: false, ..request.clone() }; + if let Some(prompt_cache) = &self.prompt_cache { + if let Some(response) = prompt_cache.lookup_completion(&request) { + return Ok(response); + } + } let response = self.send_with_retry(&request).await?; let request_id = request_id_from_headers(response.headers()); let mut response = response @@ -211,6 +236,9 @@ impl AnthropicClient { if response.request_id.is_none() { response.request_id = request_id; } + if let Some(prompt_cache) = &self.prompt_cache { + let _ = prompt_cache.record_response(&request, &response); + } Ok(response) } @@ -227,6 +255,15 @@ impl AnthropicClient { parser: SseParser::new(), pending: VecDeque::new(), done: false, + cache_tracking: self + .prompt_cache + .as_ref() + .map(|prompt_cache| StreamCacheTracking { + prompt_cache: prompt_cache.clone(), + request: request.clone().with_streaming(), + last_usage: None, + finalized: false, + }), }) } @@ -527,6 +564,7 @@ pub struct MessageStream { parser: SseParser, pending: VecDeque, done: bool, + cache_tracking: Option, } impl MessageStream { @@ -538,6 +576,9 @@ impl MessageStream { pub async fn next_event(&mut self) -> Result, ApiError> { loop { if let Some(event) = self.pending.pop_front() { + if let Some(cache_tracking) = &mut self.cache_tracking { + cache_tracking.observe(&event); + } return Ok(Some(event)); } @@ -545,8 +586,14 @@ impl MessageStream { let remaining = self.parser.finish()?; self.pending.extend(remaining); if let Some(event) = self.pending.pop_front() { + if let Some(cache_tracking) = &mut self.cache_tracking { + cache_tracking.observe(&event); + } return Ok(Some(event)); } + if let Some(cache_tracking) = &mut self.cache_tracking { + cache_tracking.finalize(); + } return Ok(None); } @@ -562,6 +609,41 @@ impl MessageStream { } } +#[derive(Debug, Clone)] +struct StreamCacheTracking { + prompt_cache: PromptCache, + request: MessageRequest, + last_usage: Option, + finalized: bool, +} + +impl StreamCacheTracking { + fn observe(&mut self, event: &StreamEvent) { + match event { + StreamEvent::MessageStart(event) => { + self.last_usage = Some(event.message.usage.clone()); + } + StreamEvent::MessageDelta(event) => { + self.last_usage = Some(event.usage.clone()); + } + StreamEvent::ContentBlockStart(_) + | StreamEvent::ContentBlockDelta(_) + | StreamEvent::ContentBlockStop(_) + | StreamEvent::MessageStop(_) => {} + } + } + + fn finalize(&mut self) { + if self.finalized { + return; + } + if let Some(usage) = &self.last_usage { + let _ = self.prompt_cache.record_usage(&self.request, usage); + } + self.finalized = true; + } +} + async fn expect_success(response: reqwest::Response) -> Result { let status = response.status(); if status.is_success() { @@ -606,6 +688,7 @@ mod tests { use super::{ALT_REQUEST_ID_HEADER, REQUEST_ID_HEADER}; use std::io::{Read, Write}; use std::net::TcpListener; + use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::{Mutex, OnceLock}; use std::thread; use std::time::{Duration, SystemTime, UNIX_EPOCH}; @@ -622,13 +705,15 @@ mod tests { static LOCK: OnceLock> = OnceLock::new(); LOCK.get_or_init(|| Mutex::new(())) .lock() - .expect("env lock") + .unwrap_or_else(std::sync::PoisonError::into_inner) } fn temp_config_home() -> std::path::PathBuf { + static NEXT_ID: AtomicU64 = AtomicU64::new(0); std::env::temp_dir().join(format!( - "api-oauth-test-{}-{}", + "api-oauth-test-{}-{}-{}", std::process::id(), + NEXT_ID.fetch_add(1, Ordering::Relaxed), SystemTime::now() .duration_since(UNIX_EPOCH) .expect("time") diff --git a/rust/crates/api/src/lib.rs b/rust/crates/api/src/lib.rs index 4108187..43e2ffa 100644 --- a/rust/crates/api/src/lib.rs +++ b/rust/crates/api/src/lib.rs @@ -1,5 +1,6 @@ mod client; mod error; +mod prompt_cache; mod sse; mod types; @@ -8,6 +9,10 @@ pub use client::{ AnthropicClient, AuthSource, MessageStream, OAuthTokenSet, }; pub use error::ApiError; +pub use prompt_cache::{ + CacheBreakEvent, PromptCache, PromptCacheConfig, PromptCachePaths, PromptCacheRecord, + PromptCacheStats, +}; pub use sse::{parse_frame, SseParser}; pub use types::{ ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent, diff --git a/rust/crates/api/src/prompt_cache.rs b/rust/crates/api/src/prompt_cache.rs new file mode 100644 index 0000000..5a6a7da --- /dev/null +++ b/rust/crates/api/src/prompt_cache.rs @@ -0,0 +1,679 @@ +use std::fs; +use std::path::{Path, PathBuf}; +use std::sync::{Arc, Mutex}; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +use serde::{Deserialize, Serialize}; + +use crate::types::{MessageRequest, MessageResponse, Usage}; + +const DEFAULT_COMPLETION_TTL_SECS: u64 = 30; +const DEFAULT_PROMPT_TTL_SECS: u64 = 5 * 60; +const DEFAULT_BREAK_MIN_DROP: u32 = 2_000; +const MAX_SANITIZED_LENGTH: usize = 80; +const REQUEST_FINGERPRINT_VERSION: u32 = 1; +const REQUEST_FINGERPRINT_PREFIX: &str = "v1"; +const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325; +const FNV_PRIME: u64 = 0x0000_0100_0000_01b3; + +#[derive(Debug, Clone)] +pub struct PromptCacheConfig { + pub session_id: String, + pub completion_ttl: Duration, + pub prompt_ttl: Duration, + pub cache_break_min_drop: u32, +} + +impl PromptCacheConfig { + #[must_use] + pub fn new(session_id: impl Into) -> Self { + Self { + session_id: session_id.into(), + completion_ttl: Duration::from_secs(DEFAULT_COMPLETION_TTL_SECS), + prompt_ttl: Duration::from_secs(DEFAULT_PROMPT_TTL_SECS), + cache_break_min_drop: DEFAULT_BREAK_MIN_DROP, + } + } +} + +impl Default for PromptCacheConfig { + fn default() -> Self { + Self::new("default") + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct PromptCachePaths { + pub root: PathBuf, + pub session_dir: PathBuf, + pub completion_dir: PathBuf, + pub session_state_path: PathBuf, + pub stats_path: PathBuf, +} + +impl PromptCachePaths { + #[must_use] + pub fn for_session(session_id: &str) -> Self { + let root = base_cache_root(); + let session_dir = root.join(sanitize_path_segment(session_id)); + let completion_dir = session_dir.join("completions"); + Self { + root, + session_state_path: session_dir.join("session-state.json"), + stats_path: session_dir.join("stats.json"), + session_dir, + completion_dir, + } + } + + #[must_use] + pub fn completion_entry_path(&self, request_hash: &str) -> PathBuf { + self.completion_dir.join(format!("{request_hash}.json")) + } +} + +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] +pub struct PromptCacheStats { + pub tracked_requests: u64, + pub completion_cache_hits: u64, + pub completion_cache_misses: u64, + pub completion_cache_writes: u64, + pub expected_invalidations: u64, + pub unexpected_cache_breaks: u64, + pub total_cache_creation_input_tokens: u64, + pub total_cache_read_input_tokens: u64, + pub last_cache_creation_input_tokens: Option, + pub last_cache_read_input_tokens: Option, + pub last_request_hash: Option, + pub last_completion_cache_key: Option, + pub last_break_reason: Option, + pub last_cache_source: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct CacheBreakEvent { + pub unexpected: bool, + pub reason: String, + pub previous_cache_read_input_tokens: u32, + pub current_cache_read_input_tokens: u32, + pub token_drop: u32, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PromptCacheRecord { + pub cache_break: Option, + pub stats: PromptCacheStats, +} + +#[derive(Debug, Clone)] +pub struct PromptCache { + inner: Arc>, +} + +impl PromptCache { + #[must_use] + pub fn new(session_id: impl Into) -> Self { + Self::with_config(PromptCacheConfig::new(session_id)) + } + + #[must_use] + pub fn with_config(config: PromptCacheConfig) -> Self { + let paths = PromptCachePaths::for_session(&config.session_id); + let stats = read_json::(&paths.stats_path).unwrap_or_default(); + let previous = read_json::(&paths.session_state_path); + Self { + inner: Arc::new(Mutex::new(PromptCacheInner { + config, + paths, + stats, + previous, + })), + } + } + + #[must_use] + pub fn paths(&self) -> PromptCachePaths { + self.lock().paths.clone() + } + + #[must_use] + pub fn stats(&self) -> PromptCacheStats { + self.lock().stats.clone() + } + + pub fn lookup_completion(&self, request: &MessageRequest) -> Option { + let request_hash = request_hash_hex(request); + let (paths, ttl) = { + let inner = self.lock(); + (inner.paths.clone(), inner.config.completion_ttl) + }; + let entry_path = paths.completion_entry_path(&request_hash); + let entry = read_json::(&entry_path); + let Some(entry) = entry else { + let mut inner = self.lock(); + inner.stats.completion_cache_misses += 1; + inner.stats.last_completion_cache_key = Some(request_hash); + persist_state(&inner); + return None; + }; + + if entry.fingerprint_version != current_fingerprint_version() { + let mut inner = self.lock(); + inner.stats.completion_cache_misses += 1; + inner.stats.last_completion_cache_key = Some(request_hash.clone()); + let _ = fs::remove_file(entry_path); + persist_state(&inner); + return None; + } + + let expired = now_unix_secs().saturating_sub(entry.cached_at_unix_secs) >= ttl.as_secs(); + let mut inner = self.lock(); + inner.stats.last_completion_cache_key = Some(request_hash.clone()); + if expired { + inner.stats.completion_cache_misses += 1; + let _ = fs::remove_file(entry_path); + persist_state(&inner); + return None; + } + + inner.stats.completion_cache_hits += 1; + apply_usage_to_stats( + &mut inner.stats, + &entry.response.usage, + &request_hash, + "completion-cache", + ); + inner.previous = Some(TrackedPromptState::from_usage( + request, + &entry.response.usage, + )); + persist_state(&inner); + Some(entry.response) + } + + pub fn record_response( + &self, + request: &MessageRequest, + response: &MessageResponse, + ) -> PromptCacheRecord { + self.record_usage_internal(request, &response.usage, Some(response)) + } + + pub fn record_usage(&self, request: &MessageRequest, usage: &Usage) -> PromptCacheRecord { + self.record_usage_internal(request, usage, None) + } + + fn record_usage_internal( + &self, + request: &MessageRequest, + usage: &Usage, + response: Option<&MessageResponse>, + ) -> PromptCacheRecord { + let request_hash = request_hash_hex(request); + let mut inner = self.lock(); + let previous = inner.previous.clone(); + let current = TrackedPromptState::from_usage(request, usage); + let cache_break = detect_cache_break(&inner.config, previous.as_ref(), ¤t); + + inner.stats.tracked_requests += 1; + apply_usage_to_stats(&mut inner.stats, usage, &request_hash, "api-response"); + if let Some(event) = &cache_break { + if event.unexpected { + inner.stats.unexpected_cache_breaks += 1; + } else { + inner.stats.expected_invalidations += 1; + } + inner.stats.last_break_reason = Some(event.reason.clone()); + } + + inner.previous = Some(current); + if let Some(response) = response { + write_completion_entry(&inner.paths, &request_hash, response); + inner.stats.completion_cache_writes += 1; + } + persist_state(&inner); + + PromptCacheRecord { + cache_break, + stats: inner.stats.clone(), + } + } + + fn lock(&self) -> std::sync::MutexGuard<'_, PromptCacheInner> { + self.inner + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + } +} + +#[derive(Debug)] +struct PromptCacheInner { + config: PromptCacheConfig, + paths: PromptCachePaths, + stats: PromptCacheStats, + previous: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct CompletionCacheEntry { + cached_at_unix_secs: u64, + #[serde(default = "current_fingerprint_version")] + fingerprint_version: u32, + response: MessageResponse, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +struct TrackedPromptState { + observed_at_unix_secs: u64, + #[serde(default = "current_fingerprint_version")] + fingerprint_version: u32, + request_hash: u64, + model_hash: u64, + system_hash: u64, + tools_hash: u64, + messages_hash: u64, + cache_read_input_tokens: u32, +} + +impl TrackedPromptState { + fn from_usage(request: &MessageRequest, usage: &Usage) -> Self { + let hashes = RequestHashes::from_request(request); + Self { + observed_at_unix_secs: now_unix_secs(), + fingerprint_version: current_fingerprint_version(), + request_hash: hashes.request_hash, + model_hash: hashes.model_hash, + system_hash: hashes.system_hash, + tools_hash: hashes.tools_hash, + messages_hash: hashes.messages_hash, + cache_read_input_tokens: usage.cache_read_input_tokens, + } + } +} + +#[derive(Debug, Clone, Copy)] +struct RequestHashes { + request_hash: u64, + model_hash: u64, + system_hash: u64, + tools_hash: u64, + messages_hash: u64, +} + +impl RequestHashes { + fn from_request(request: &MessageRequest) -> Self { + Self { + request_hash: hash_serializable(request), + model_hash: hash_serializable(&request.model), + system_hash: hash_serializable(&request.system), + tools_hash: hash_serializable(&request.tools), + messages_hash: hash_serializable(&request.messages), + } + } +} + +fn detect_cache_break( + config: &PromptCacheConfig, + previous: Option<&TrackedPromptState>, + current: &TrackedPromptState, +) -> Option { + let previous = previous?; + if previous.fingerprint_version != current.fingerprint_version { + return Some(CacheBreakEvent { + unexpected: false, + reason: format!( + "fingerprint version changed (v{} -> v{})", + previous.fingerprint_version, current.fingerprint_version + ), + previous_cache_read_input_tokens: previous.cache_read_input_tokens, + current_cache_read_input_tokens: current.cache_read_input_tokens, + token_drop: previous + .cache_read_input_tokens + .saturating_sub(current.cache_read_input_tokens), + }); + } + let token_drop = previous + .cache_read_input_tokens + .saturating_sub(current.cache_read_input_tokens); + if token_drop < config.cache_break_min_drop { + return None; + } + + let mut reasons = Vec::new(); + if previous.model_hash != current.model_hash { + reasons.push("model changed"); + } + if previous.system_hash != current.system_hash { + reasons.push("system prompt changed"); + } + if previous.tools_hash != current.tools_hash { + reasons.push("tool definitions changed"); + } + if previous.messages_hash != current.messages_hash { + reasons.push("message payload changed"); + } + + let elapsed = current + .observed_at_unix_secs + .saturating_sub(previous.observed_at_unix_secs); + + let (unexpected, reason) = if reasons.is_empty() { + if elapsed > config.prompt_ttl.as_secs() { + ( + false, + format!("possible prompt cache TTL expiry after {elapsed}s"), + ) + } else { + ( + true, + "cache read tokens dropped while prompt fingerprint remained stable".to_string(), + ) + } + } else { + (false, reasons.join(", ")) + }; + + Some(CacheBreakEvent { + unexpected, + reason, + previous_cache_read_input_tokens: previous.cache_read_input_tokens, + current_cache_read_input_tokens: current.cache_read_input_tokens, + token_drop, + }) +} + +fn apply_usage_to_stats( + stats: &mut PromptCacheStats, + usage: &Usage, + request_hash: &str, + source: &str, +) { + stats.total_cache_creation_input_tokens += u64::from(usage.cache_creation_input_tokens); + stats.total_cache_read_input_tokens += u64::from(usage.cache_read_input_tokens); + stats.last_cache_creation_input_tokens = Some(usage.cache_creation_input_tokens); + stats.last_cache_read_input_tokens = Some(usage.cache_read_input_tokens); + stats.last_request_hash = Some(request_hash.to_string()); + stats.last_cache_source = Some(source.to_string()); +} + +fn persist_state(inner: &PromptCacheInner) { + let _ = ensure_cache_dirs(&inner.paths); + let _ = write_json(&inner.paths.stats_path, &inner.stats); + if let Some(previous) = &inner.previous { + let _ = write_json(&inner.paths.session_state_path, previous); + } +} + +fn write_completion_entry( + paths: &PromptCachePaths, + request_hash: &str, + response: &MessageResponse, +) { + let _ = ensure_cache_dirs(paths); + let entry = CompletionCacheEntry { + cached_at_unix_secs: now_unix_secs(), + fingerprint_version: current_fingerprint_version(), + response: response.clone(), + }; + let _ = write_json(&paths.completion_entry_path(request_hash), &entry); +} + +fn ensure_cache_dirs(paths: &PromptCachePaths) -> std::io::Result<()> { + fs::create_dir_all(&paths.completion_dir) +} + +fn write_json(path: &Path, value: &T) -> std::io::Result<()> { + let json = serde_json::to_vec_pretty(value) + .map_err(|error| std::io::Error::new(std::io::ErrorKind::InvalidData, error))?; + fs::write(path, json) +} + +fn read_json Deserialize<'de>>(path: &Path) -> Option { + let bytes = fs::read(path).ok()?; + serde_json::from_slice(&bytes).ok() +} + +fn request_hash_hex(request: &MessageRequest) -> String { + format!( + "{REQUEST_FINGERPRINT_PREFIX}-{:016x}", + hash_serializable(request) + ) +} + +fn hash_serializable(value: &T) -> u64 { + let json = serde_json::to_vec(value).unwrap_or_default(); + stable_hash_bytes(&json) +} + +fn sanitize_path_segment(value: &str) -> String { + let sanitized: String = value + .chars() + .map(|ch| if ch.is_ascii_alphanumeric() { ch } else { '-' }) + .collect(); + if sanitized.len() <= MAX_SANITIZED_LENGTH { + return sanitized; + } + let suffix = format!("-{:x}", hash_string(value)); + format!( + "{}{}", + &sanitized[..MAX_SANITIZED_LENGTH.saturating_sub(suffix.len())], + suffix + ) +} + +fn hash_string(value: &str) -> u64 { + stable_hash_bytes(value.as_bytes()) +} + +fn base_cache_root() -> PathBuf { + if let Some(config_home) = std::env::var_os("CLAUDE_CONFIG_HOME") { + return PathBuf::from(config_home) + .join("cache") + .join("prompt-cache"); + } + if let Some(home) = std::env::var_os("HOME") { + return PathBuf::from(home) + .join(".claude") + .join("cache") + .join("prompt-cache"); + } + std::env::temp_dir().join("claude-prompt-cache") +} + +fn now_unix_secs() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map_or(0, |duration| duration.as_secs()) +} + +const fn current_fingerprint_version() -> u32 { + REQUEST_FINGERPRINT_VERSION +} + +fn stable_hash_bytes(bytes: &[u8]) -> u64 { + let mut hash = FNV_OFFSET_BASIS; + for byte in bytes { + hash ^= u64::from(*byte); + hash = hash.wrapping_mul(FNV_PRIME); + } + hash +} + +#[cfg(test)] +mod tests { + use std::sync::{Mutex, OnceLock}; + use std::time::{SystemTime, UNIX_EPOCH}; + + use super::{ + detect_cache_break, read_json, request_hash_hex, sanitize_path_segment, PromptCache, + PromptCacheConfig, PromptCachePaths, TrackedPromptState, REQUEST_FINGERPRINT_PREFIX, + }; + use crate::types::{InputMessage, MessageRequest, MessageResponse, OutputContentBlock, Usage}; + + fn env_lock() -> std::sync::MutexGuard<'static, ()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + } + + #[test] + fn path_builder_sanitizes_session_identifier() { + let paths = PromptCachePaths::for_session("session:/with spaces"); + let session_dir = paths + .session_dir + .file_name() + .and_then(|value| value.to_str()) + .expect("session dir name"); + assert_eq!(session_dir, "session--with-spaces"); + assert!(paths.completion_dir.ends_with("completions")); + assert!(paths.stats_path.ends_with("stats.json")); + assert!(paths.session_state_path.ends_with("session-state.json")); + } + + #[test] + fn request_fingerprint_drives_unexpected_break_detection() { + let request = sample_request("same"); + let previous = TrackedPromptState::from_usage( + &request, + &Usage { + input_tokens: 0, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 6_000, + output_tokens: 0, + }, + ); + let current = TrackedPromptState::from_usage( + &request, + &Usage { + input_tokens: 0, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 1_000, + output_tokens: 0, + }, + ); + let event = detect_cache_break(&PromptCacheConfig::default(), Some(&previous), ¤t) + .expect("break should be detected"); + assert!(event.unexpected); + assert!(event.reason.contains("stable")); + } + + #[test] + fn changed_prompt_marks_break_as_expected() { + let previous_request = sample_request("first"); + let current_request = sample_request("second"); + let previous = TrackedPromptState::from_usage( + &previous_request, + &Usage { + input_tokens: 0, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 6_000, + output_tokens: 0, + }, + ); + let current = TrackedPromptState::from_usage( + ¤t_request, + &Usage { + input_tokens: 0, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 1_000, + output_tokens: 0, + }, + ); + let event = detect_cache_break(&PromptCacheConfig::default(), Some(&previous), ¤t) + .expect("break should be detected"); + assert!(!event.unexpected); + assert!(event.reason.contains("message payload changed")); + } + + #[test] + fn completion_cache_round_trip_persists_recent_response() { + let _guard = env_lock(); + let temp_root = std::env::temp_dir().join(format!( + "prompt-cache-test-{}-{}", + std::process::id(), + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time") + .as_nanos() + )); + std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root); + let cache = PromptCache::new("unit-test-session"); + let request = sample_request("cache me"); + let response = sample_response(42, 12, "cached"); + + assert!(cache.lookup_completion(&request).is_none()); + let record = cache.record_response(&request, &response); + assert!(record.cache_break.is_none()); + + let cached = cache + .lookup_completion(&request) + .expect("cached response should load"); + assert_eq!(cached.content, response.content); + + let stats = cache.stats(); + assert_eq!(stats.completion_cache_hits, 1); + assert_eq!(stats.completion_cache_misses, 1); + assert_eq!(stats.completion_cache_writes, 1); + + let persisted = read_json::(&cache.paths().stats_path) + .expect("stats should persist"); + assert_eq!(persisted.completion_cache_hits, 1); + + std::fs::remove_dir_all(temp_root).expect("cleanup temp root"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); + } + + #[test] + fn sanitize_path_caps_long_values() { + let long_value = "x".repeat(200); + let sanitized = sanitize_path_segment(&long_value); + assert!(sanitized.len() <= 80); + } + + #[test] + fn request_hashes_are_versioned_and_stable() { + let request = sample_request("stable"); + let first = request_hash_hex(&request); + let second = request_hash_hex(&request); + assert_eq!(first, second); + assert!(first.starts_with(REQUEST_FINGERPRINT_PREFIX)); + } + + fn sample_request(text: &str) -> MessageRequest { + MessageRequest { + model: "claude-3-7-sonnet-latest".to_string(), + max_tokens: 64, + messages: vec![InputMessage::user_text(text)], + system: Some("system".to_string()), + tools: None, + tool_choice: None, + stream: false, + } + } + + fn sample_response( + cache_read_input_tokens: u32, + output_tokens: u32, + text: &str, + ) -> MessageResponse { + MessageResponse { + id: "msg_test".to_string(), + kind: "message".to_string(), + role: "assistant".to_string(), + content: vec![OutputContentBlock::Text { + text: text.to_string(), + }], + model: "claude-3-7-sonnet-latest".to_string(), + stop_reason: Some("end_turn".to_string()), + stop_sequence: None, + usage: Usage { + input_tokens: 10, + cache_creation_input_tokens: 5, + cache_read_input_tokens, + output_tokens, + }, + request_id: Some("req_test".to_string()), + } + } +} diff --git a/rust/crates/api/tests/client_integration.rs b/rust/crates/api/tests/client_integration.rs index c37fa99..9f59710 100644 --- a/rust/crates/api/tests/client_integration.rs +++ b/rust/crates/api/tests/client_integration.rs @@ -1,17 +1,25 @@ use std::collections::HashMap; use std::sync::Arc; +use std::sync::{Mutex as StdMutex, OnceLock}; use std::time::Duration; use api::{ AnthropicClient, ApiError, ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, InputContentBlock, InputMessage, MessageDeltaEvent, MessageRequest, OutputContentBlock, - StreamEvent, ToolChoice, ToolDefinition, + PromptCache, StreamEvent, ToolChoice, ToolDefinition, }; use serde_json::json; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpListener; use tokio::sync::Mutex; +fn env_lock() -> std::sync::MutexGuard<'static, ()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| StdMutex::new(())) + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) +} + #[tokio::test] async fn send_message_posts_json_and_parses_response() { let state = Arc::new(Mutex::new(Vec::::new())); @@ -77,6 +85,16 @@ async fn send_message_posts_json_and_parses_response() { #[tokio::test] async fn stream_message_parses_sse_events_with_tool_use() { + let _guard = env_lock(); + let temp_root = std::env::temp_dir().join(format!( + "api-stream-cache-{}-{}", + std::process::id(), + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("time") + .as_nanos() + )); + std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root); let state = Arc::new(Mutex::new(Vec::::new())); let sse = concat!( "event: message_start\n", @@ -106,7 +124,8 @@ async fn stream_message_parses_sse_events_with_tool_use() { let client = AnthropicClient::new("test-key") .with_auth_token(Some("proxy-token".to_string())) - .with_base_url(server.base_url()); + .with_base_url(server.base_url()) + .with_prompt_cache(PromptCache::new("stream-session")); let mut stream = client .stream_message(&sample_request(false)) .await @@ -160,6 +179,16 @@ async fn stream_message_parses_sse_events_with_tool_use() { let captured = state.lock().await; let request = captured.first().expect("server should capture request"); assert!(request.body.contains("\"stream\":true")); + + let stats = client + .prompt_cache_stats() + .expect("prompt cache stats should exist"); + assert_eq!(stats.tracked_requests, 1); + assert_eq!(stats.last_cache_read_input_tokens, Some(0)); + assert_eq!(stats.last_cache_source.as_deref(), Some("api-response")); + + std::fs::remove_dir_all(temp_root).expect("cleanup temp root"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); } #[tokio::test] @@ -243,6 +272,119 @@ async fn surfaces_retry_exhaustion_for_persistent_retryable_errors() { } } +#[tokio::test] +async fn send_message_reuses_recent_completion_cache_entries() { + let _guard = env_lock(); + let temp_root = std::env::temp_dir().join(format!( + "api-prompt-cache-{}-{}", + std::process::id(), + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("time") + .as_nanos() + )); + std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root); + + let state = Arc::new(Mutex::new(Vec::::new())); + let server = spawn_server( + state.clone(), + vec![http_response( + "200 OK", + "application/json", + "{\"id\":\"msg_cached\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Cached once\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"cache_creation_input_tokens\":5,\"cache_read_input_tokens\":4000,\"output_tokens\":2}}", + )], + ) + .await; + + let client = AnthropicClient::new("test-key") + .with_base_url(server.base_url()) + .with_prompt_cache(PromptCache::new("integration-session")); + + let first = client + .send_message(&sample_request(false)) + .await + .expect("first request should succeed"); + let second = client + .send_message(&sample_request(false)) + .await + .expect("second request should reuse cache"); + + assert_eq!(first.content, second.content); + assert_eq!(state.lock().await.len(), 1); + + let stats = client + .prompt_cache_stats() + .expect("prompt cache stats should exist"); + assert_eq!(stats.completion_cache_hits, 1); + assert_eq!(stats.completion_cache_misses, 1); + assert_eq!(stats.completion_cache_writes, 1); + + std::fs::remove_dir_all(temp_root).expect("cleanup temp root"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); +} + +#[tokio::test] +async fn send_message_tracks_unexpected_prompt_cache_breaks() { + let _guard = env_lock(); + let temp_root = std::env::temp_dir().join(format!( + "api-prompt-break-{}-{}", + std::process::id(), + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("time") + .as_nanos() + )); + std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root); + + let state = Arc::new(Mutex::new(Vec::::new())); + let server = spawn_server( + state, + vec![ + http_response( + "200 OK", + "application/json", + "{\"id\":\"msg_one\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"One\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"cache_creation_input_tokens\":5,\"cache_read_input_tokens\":6000,\"output_tokens\":2}}", + ), + http_response( + "200 OK", + "application/json", + "{\"id\":\"msg_two\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Two\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"cache_creation_input_tokens\":0,\"cache_read_input_tokens\":1000,\"output_tokens\":2}}", + ), + ], + ) + .await; + + let request = sample_request(false); + let client = AnthropicClient::new("test-key") + .with_base_url(server.base_url()) + .with_prompt_cache(PromptCache::with_config(api::PromptCacheConfig { + session_id: "break-session".to_string(), + completion_ttl: Duration::from_secs(0), + ..api::PromptCacheConfig::default() + })); + + client + .send_message(&request) + .await + .expect("first response should succeed"); + client + .send_message(&request) + .await + .expect("second response should succeed"); + + let stats = client + .prompt_cache_stats() + .expect("prompt cache stats should exist"); + assert_eq!(stats.unexpected_cache_breaks, 1); + assert_eq!( + stats.last_break_reason.as_deref(), + Some("cache read tokens dropped while prompt fingerprint remained stable") + ); + + std::fs::remove_dir_all(temp_root).expect("cleanup temp root"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); +} + #[tokio::test] #[ignore = "requires ANTHROPIC_API_KEY and network access"] async fn live_stream_smoke_test() {