wip: cache-tracking progress
This commit is contained in:
@@ -689,7 +689,6 @@ mod tests {
|
||||
use std::io::{Read, Write};
|
||||
use std::net::TcpListener;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::{Mutex, OnceLock};
|
||||
use std::thread;
|
||||
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||
|
||||
@@ -699,15 +698,9 @@ mod tests {
|
||||
now_unix_timestamp, oauth_token_is_expired, resolve_saved_oauth_token,
|
||||
resolve_startup_auth_source, AnthropicClient, AuthSource, OAuthTokenSet,
|
||||
};
|
||||
use crate::test_env_lock;
|
||||
use crate::types::{ContentBlockDelta, MessageRequest};
|
||||
|
||||
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
|
||||
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
LOCK.get_or_init(|| Mutex::new(()))
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
}
|
||||
|
||||
fn temp_config_home() -> std::path::PathBuf {
|
||||
static NEXT_ID: AtomicU64 = AtomicU64::new(0);
|
||||
std::env::temp_dir().join(format!(
|
||||
@@ -753,7 +746,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn read_api_key_requires_presence() {
|
||||
let _guard = env_lock();
|
||||
let _guard = test_env_lock();
|
||||
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
||||
std::env::remove_var("ANTHROPIC_API_KEY");
|
||||
std::env::remove_var("CLAUDE_CONFIG_HOME");
|
||||
@@ -763,7 +756,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn read_api_key_requires_non_empty_value() {
|
||||
let _guard = env_lock();
|
||||
let _guard = test_env_lock();
|
||||
std::env::set_var("ANTHROPIC_AUTH_TOKEN", "");
|
||||
std::env::remove_var("ANTHROPIC_API_KEY");
|
||||
let error = super::read_api_key().expect_err("empty key should error");
|
||||
@@ -773,7 +766,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn read_api_key_prefers_api_key_env() {
|
||||
let _guard = env_lock();
|
||||
let _guard = test_env_lock();
|
||||
std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
|
||||
std::env::set_var("ANTHROPIC_API_KEY", "legacy-key");
|
||||
assert_eq!(
|
||||
@@ -786,7 +779,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn read_auth_token_reads_auth_token_env() {
|
||||
let _guard = env_lock();
|
||||
let _guard = test_env_lock();
|
||||
std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
|
||||
assert_eq!(super::read_auth_token().as_deref(), Some("auth-token"));
|
||||
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
||||
@@ -806,7 +799,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn auth_source_from_env_combines_api_key_and_bearer_token() {
|
||||
let _guard = env_lock();
|
||||
let _guard = test_env_lock();
|
||||
std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
|
||||
std::env::set_var("ANTHROPIC_API_KEY", "legacy-key");
|
||||
let auth = AuthSource::from_env().expect("env auth");
|
||||
@@ -818,7 +811,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn auth_source_from_saved_oauth_when_env_absent() {
|
||||
let _guard = env_lock();
|
||||
let _guard = test_env_lock();
|
||||
let config_home = temp_config_home();
|
||||
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
|
||||
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
||||
@@ -857,7 +850,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn resolve_saved_oauth_token_refreshes_expired_credentials() {
|
||||
let _guard = env_lock();
|
||||
let _guard = test_env_lock();
|
||||
let config_home = temp_config_home();
|
||||
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
|
||||
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
||||
@@ -889,7 +882,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn resolve_startup_auth_source_uses_saved_oauth_without_loading_config() {
|
||||
let _guard = env_lock();
|
||||
let _guard = test_env_lock();
|
||||
let config_home = temp_config_home();
|
||||
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
|
||||
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
||||
@@ -913,7 +906,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn resolve_startup_auth_source_errors_when_refreshable_token_lacks_config() {
|
||||
let _guard = env_lock();
|
||||
let _guard = test_env_lock();
|
||||
let config_home = temp_config_home();
|
||||
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
|
||||
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
||||
@@ -945,7 +938,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn resolve_saved_oauth_token_preserves_refresh_token_when_refresh_response_omits_it() {
|
||||
let _guard = env_lock();
|
||||
let _guard = test_env_lock();
|
||||
let config_home = temp_config_home();
|
||||
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
|
||||
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
||||
|
||||
@@ -20,3 +20,11 @@ pub use types::{
|
||||
MessageResponse, MessageStartEvent, MessageStopEvent, OutputContentBlock, StreamEvent,
|
||||
ToolChoice, ToolDefinition, ToolResultContentBlock, Usage,
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn test_env_lock() -> std::sync::MutexGuard<'static, ()> {
|
||||
static LOCK: std::sync::OnceLock<std::sync::Mutex<()>> = std::sync::OnceLock::new();
|
||||
LOCK.get_or_init(|| std::sync::Mutex::new(()))
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
}
|
||||
|
||||
@@ -141,6 +141,7 @@ impl PromptCache {
|
||||
self.lock().stats.clone()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn lookup_completion(&self, request: &MessageRequest) -> Option<MessageResponse> {
|
||||
let request_hash = request_hash_hex(request);
|
||||
let (paths, ttl) = {
|
||||
@@ -191,6 +192,7 @@ impl PromptCache {
|
||||
Some(entry.response)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn record_response(
|
||||
&self,
|
||||
request: &MessageRequest,
|
||||
@@ -199,6 +201,7 @@ impl PromptCache {
|
||||
self.record_usage_internal(request, &response.usage, Some(response))
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn record_usage(&self, request: &MessageRequest, usage: &Usage) -> PromptCacheRecord {
|
||||
self.record_usage_internal(request, usage, None)
|
||||
}
|
||||
@@ -267,7 +270,6 @@ struct TrackedPromptState {
|
||||
observed_at_unix_secs: u64,
|
||||
#[serde(default = "current_fingerprint_version")]
|
||||
fingerprint_version: u32,
|
||||
request_hash: u64,
|
||||
model_hash: u64,
|
||||
system_hash: u64,
|
||||
tools_hash: u64,
|
||||
@@ -277,37 +279,34 @@ struct TrackedPromptState {
|
||||
|
||||
impl TrackedPromptState {
|
||||
fn from_usage(request: &MessageRequest, usage: &Usage) -> Self {
|
||||
let hashes = RequestHashes::from_request(request);
|
||||
let hashes = RequestFingerprints::from_request(request);
|
||||
Self {
|
||||
observed_at_unix_secs: now_unix_secs(),
|
||||
fingerprint_version: current_fingerprint_version(),
|
||||
request_hash: hashes.request_hash,
|
||||
model_hash: hashes.model_hash,
|
||||
system_hash: hashes.system_hash,
|
||||
tools_hash: hashes.tools_hash,
|
||||
messages_hash: hashes.messages_hash,
|
||||
model_hash: hashes.model,
|
||||
system_hash: hashes.system,
|
||||
tools_hash: hashes.tools,
|
||||
messages_hash: hashes.messages,
|
||||
cache_read_input_tokens: usage.cache_read_input_tokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
struct RequestHashes {
|
||||
request_hash: u64,
|
||||
model_hash: u64,
|
||||
system_hash: u64,
|
||||
tools_hash: u64,
|
||||
messages_hash: u64,
|
||||
struct RequestFingerprints {
|
||||
model: u64,
|
||||
system: u64,
|
||||
tools: u64,
|
||||
messages: u64,
|
||||
}
|
||||
|
||||
impl RequestHashes {
|
||||
impl RequestFingerprints {
|
||||
fn from_request(request: &MessageRequest) -> Self {
|
||||
Self {
|
||||
request_hash: hash_serializable(request),
|
||||
model_hash: hash_serializable(&request.model),
|
||||
system_hash: hash_serializable(&request.system),
|
||||
tools_hash: hash_serializable(&request.tools),
|
||||
messages_hash: hash_serializable(&request.messages),
|
||||
model: hash_serializable(&request.model),
|
||||
system: hash_serializable(&request.system),
|
||||
tools: hash_serializable(&request.tools),
|
||||
messages: hash_serializable(&request.messages),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -501,22 +500,15 @@ fn stable_hash_bytes(bytes: &[u8]) -> u64 {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::{Mutex, OnceLock};
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||
|
||||
use super::{
|
||||
detect_cache_break, read_json, request_hash_hex, sanitize_path_segment, PromptCache,
|
||||
PromptCacheConfig, PromptCachePaths, TrackedPromptState, REQUEST_FINGERPRINT_PREFIX,
|
||||
};
|
||||
use crate::test_env_lock;
|
||||
use crate::types::{InputMessage, MessageRequest, MessageResponse, OutputContentBlock, Usage};
|
||||
|
||||
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
|
||||
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
LOCK.get_or_init(|| Mutex::new(()))
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn path_builder_sanitizes_session_identifier() {
|
||||
let paths = PromptCachePaths::for_session("session:/with spaces");
|
||||
@@ -588,7 +580,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn completion_cache_round_trip_persists_recent_response() {
|
||||
let _guard = env_lock();
|
||||
let _guard = test_env_lock();
|
||||
let temp_root = std::env::temp_dir().join(format!(
|
||||
"prompt-cache-test-{}-{}",
|
||||
std::process::id(),
|
||||
@@ -624,6 +616,62 @@ mod tests {
|
||||
std::env::remove_var("CLAUDE_CONFIG_HOME");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn distinct_requests_do_not_collide_in_completion_cache() {
|
||||
let _guard = test_env_lock();
|
||||
let temp_root = std::env::temp_dir().join(format!(
|
||||
"prompt-cache-distinct-{}-{}",
|
||||
std::process::id(),
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("time")
|
||||
.as_nanos()
|
||||
));
|
||||
std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);
|
||||
let cache = PromptCache::new("distinct-request-session");
|
||||
let first_request = sample_request("first");
|
||||
let second_request = sample_request("second");
|
||||
|
||||
let response = sample_response(42, 12, "cached");
|
||||
let _ = cache.record_response(&first_request, &response);
|
||||
|
||||
assert!(cache.lookup_completion(&second_request).is_none());
|
||||
|
||||
std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
|
||||
std::env::remove_var("CLAUDE_CONFIG_HOME");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn expired_completion_entries_are_not_reused() {
|
||||
let _guard = test_env_lock();
|
||||
let temp_root = std::env::temp_dir().join(format!(
|
||||
"prompt-cache-expired-{}-{}",
|
||||
std::process::id(),
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("time")
|
||||
.as_nanos()
|
||||
));
|
||||
std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);
|
||||
let cache = PromptCache::with_config(PromptCacheConfig {
|
||||
session_id: "expired-session".to_string(),
|
||||
completion_ttl: Duration::ZERO,
|
||||
..PromptCacheConfig::default()
|
||||
});
|
||||
let request = sample_request("expire me");
|
||||
let response = sample_response(7, 3, "stale");
|
||||
|
||||
let _ = cache.record_response(&request, &response);
|
||||
|
||||
assert!(cache.lookup_completion(&request).is_none());
|
||||
let stats = cache.stats();
|
||||
assert_eq!(stats.completion_cache_hits, 0);
|
||||
assert_eq!(stats.completion_cache_misses, 1);
|
||||
|
||||
std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
|
||||
std::env::remove_var("CLAUDE_CONFIG_HOME");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sanitize_path_caps_long_values() {
|
||||
let long_value = "x".repeat(200);
|
||||
|
||||
@@ -84,6 +84,7 @@ async fn send_message_posts_json_and_parses_response() {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[allow(clippy::await_holding_lock)]
|
||||
async fn stream_message_parses_sse_events_with_tool_use() {
|
||||
let _guard = env_lock();
|
||||
let temp_root = std::env::temp_dir().join(format!(
|
||||
@@ -180,12 +181,15 @@ async fn stream_message_parses_sse_events_with_tool_use() {
|
||||
let request = captured.first().expect("server should capture request");
|
||||
assert!(request.body.contains("\"stream\":true"));
|
||||
|
||||
let stats = client
|
||||
let cache_stats = client
|
||||
.prompt_cache_stats()
|
||||
.expect("prompt cache stats should exist");
|
||||
assert_eq!(stats.tracked_requests, 1);
|
||||
assert_eq!(stats.last_cache_read_input_tokens, Some(0));
|
||||
assert_eq!(stats.last_cache_source.as_deref(), Some("api-response"));
|
||||
assert_eq!(cache_stats.tracked_requests, 1);
|
||||
assert_eq!(cache_stats.last_cache_read_input_tokens, Some(0));
|
||||
assert_eq!(
|
||||
cache_stats.last_cache_source.as_deref(),
|
||||
Some("api-response")
|
||||
);
|
||||
|
||||
std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
|
||||
std::env::remove_var("CLAUDE_CONFIG_HOME");
|
||||
@@ -273,6 +277,7 @@ async fn surfaces_retry_exhaustion_for_persistent_retryable_errors() {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[allow(clippy::await_holding_lock)]
|
||||
async fn send_message_reuses_recent_completion_cache_entries() {
|
||||
let _guard = env_lock();
|
||||
let temp_root = std::env::temp_dir().join(format!(
|
||||
@@ -312,18 +317,19 @@ async fn send_message_reuses_recent_completion_cache_entries() {
|
||||
assert_eq!(first.content, second.content);
|
||||
assert_eq!(state.lock().await.len(), 1);
|
||||
|
||||
let stats = client
|
||||
let cache_stats = client
|
||||
.prompt_cache_stats()
|
||||
.expect("prompt cache stats should exist");
|
||||
assert_eq!(stats.completion_cache_hits, 1);
|
||||
assert_eq!(stats.completion_cache_misses, 1);
|
||||
assert_eq!(stats.completion_cache_writes, 1);
|
||||
assert_eq!(cache_stats.completion_cache_hits, 1);
|
||||
assert_eq!(cache_stats.completion_cache_misses, 1);
|
||||
assert_eq!(cache_stats.completion_cache_writes, 1);
|
||||
|
||||
std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
|
||||
std::env::remove_var("CLAUDE_CONFIG_HOME");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[allow(clippy::await_holding_lock)]
|
||||
async fn send_message_tracks_unexpected_prompt_cache_breaks() {
|
||||
let _guard = env_lock();
|
||||
let temp_root = std::env::temp_dir().join(format!(
|
||||
@@ -372,12 +378,12 @@ async fn send_message_tracks_unexpected_prompt_cache_breaks() {
|
||||
.await
|
||||
.expect("second response should succeed");
|
||||
|
||||
let stats = client
|
||||
let cache_stats = client
|
||||
.prompt_cache_stats()
|
||||
.expect("prompt cache stats should exist");
|
||||
assert_eq!(stats.unexpected_cache_breaks, 1);
|
||||
assert_eq!(cache_stats.unexpected_cache_breaks, 1);
|
||||
assert_eq!(
|
||||
stats.last_break_reason.as_deref(),
|
||||
cache_stats.last_break_reason.as_deref(),
|
||||
Some("cache read tokens dropped while prompt fingerprint remained stable")
|
||||
);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user