6 Commits

Author SHA1 Message Date
Yeachan-Heo
c38eac7a90 feat: hook-pipeline progress — tests passing 2026-04-01 05:58:00 +00:00
Yeachan-Heo
197065bfc8 feat: hook abort signal + Ctrl-C cancellation pipeline 2026-04-01 05:55:24 +00:00
Yeachan-Heo
555a245456 wip: hook progress UI + documentation 2026-04-01 04:50:26 +00:00
Yeachan-Heo
9efd029e26 wip: hook-pipeline progress 2026-04-01 04:40:18 +00:00
Yeachan-Heo
eb89fc95e7 wip: hook-pipeline progress 2026-04-01 04:30:25 +00:00
Yeachan-Heo
94199beabb wip: hook pipeline progress 2026-04-01 04:20:16 +00:00
12 changed files with 1531 additions and 1389 deletions

View File

@@ -1,5 +1,4 @@
use std::collections::VecDeque; use std::collections::VecDeque;
use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime, UNIX_EPOCH}; use std::time::{Duration, SystemTime, UNIX_EPOCH};
use runtime::{ use runtime::{
@@ -9,9 +8,8 @@ use runtime::{
use serde::Deserialize; use serde::Deserialize;
use crate::error::ApiError; use crate::error::ApiError;
use crate::prompt_cache::{PromptCache, PromptCacheRecord, PromptCacheStats};
use crate::sse::SseParser; use crate::sse::SseParser;
use crate::types::{MessageRequest, MessageResponse, StreamEvent, Usage}; use crate::types::{MessageRequest, MessageResponse, StreamEvent};
const DEFAULT_BASE_URL: &str = "https://api.anthropic.com"; const DEFAULT_BASE_URL: &str = "https://api.anthropic.com";
const ANTHROPIC_VERSION: &str = "2023-06-01"; const ANTHROPIC_VERSION: &str = "2023-06-01";
@@ -110,8 +108,6 @@ pub struct AnthropicClient {
max_retries: u32, max_retries: u32,
initial_backoff: Duration, initial_backoff: Duration,
max_backoff: Duration, max_backoff: Duration,
prompt_cache: Option<PromptCache>,
last_prompt_cache_record: Arc<Mutex<Option<PromptCacheRecord>>>,
} }
impl AnthropicClient { impl AnthropicClient {
@@ -124,8 +120,6 @@ impl AnthropicClient {
max_retries: DEFAULT_MAX_RETRIES, max_retries: DEFAULT_MAX_RETRIES,
initial_backoff: DEFAULT_INITIAL_BACKOFF, initial_backoff: DEFAULT_INITIAL_BACKOFF,
max_backoff: DEFAULT_MAX_BACKOFF, max_backoff: DEFAULT_MAX_BACKOFF,
prompt_cache: None,
last_prompt_cache_record: Arc::new(Mutex::new(None)),
} }
} }
@@ -138,8 +132,6 @@ impl AnthropicClient {
max_retries: DEFAULT_MAX_RETRIES, max_retries: DEFAULT_MAX_RETRIES,
initial_backoff: DEFAULT_INITIAL_BACKOFF, initial_backoff: DEFAULT_INITIAL_BACKOFF,
max_backoff: DEFAULT_MAX_BACKOFF, max_backoff: DEFAULT_MAX_BACKOFF,
prompt_cache: None,
last_prompt_cache_record: Arc::new(Mutex::new(None)),
} }
} }
@@ -197,30 +189,6 @@ impl AnthropicClient {
self self
} }
#[must_use]
pub fn with_prompt_cache(mut self, prompt_cache: PromptCache) -> Self {
self.prompt_cache = Some(prompt_cache);
self
}
#[must_use]
pub fn prompt_cache(&self) -> Option<&PromptCache> {
self.prompt_cache.as_ref()
}
#[must_use]
pub fn prompt_cache_stats(&self) -> Option<PromptCacheStats> {
self.prompt_cache.as_ref().map(PromptCache::stats)
}
#[must_use]
pub fn take_last_prompt_cache_record(&self) -> Option<PromptCacheRecord> {
self.last_prompt_cache_record()
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.take()
}
#[must_use] #[must_use]
pub fn auth_source(&self) -> &AuthSource { pub fn auth_source(&self) -> &AuthSource {
&self.auth &self.auth
@@ -230,19 +198,10 @@ impl AnthropicClient {
&self, &self,
request: &MessageRequest, request: &MessageRequest,
) -> Result<MessageResponse, ApiError> { ) -> Result<MessageResponse, ApiError> {
self.store_last_prompt_cache_record(None);
let request = MessageRequest { let request = MessageRequest {
stream: false, stream: false,
..request.clone() ..request.clone()
}; };
if let Some(prompt_cache) = &self.prompt_cache {
if let Some(response) = prompt_cache.lookup_completion(&request) {
self.store_last_prompt_cache_record(Some(prompt_cache_record_from_stats(
prompt_cache.stats(),
)));
return Ok(response);
}
}
let response = self.send_with_retry(&request).await?; let response = self.send_with_retry(&request).await?;
let request_id = request_id_from_headers(response.headers()); let request_id = request_id_from_headers(response.headers());
let mut response = response let mut response = response
@@ -252,10 +211,6 @@ impl AnthropicClient {
if response.request_id.is_none() { if response.request_id.is_none() {
response.request_id = request_id; response.request_id = request_id;
} }
if let Some(prompt_cache) = &self.prompt_cache {
let record = prompt_cache.record_response(&request, &response);
self.store_last_prompt_cache_record(Some(record));
}
Ok(response) Ok(response)
} }
@@ -263,7 +218,6 @@ impl AnthropicClient {
&self, &self,
request: &MessageRequest, request: &MessageRequest,
) -> Result<MessageStream, ApiError> { ) -> Result<MessageStream, ApiError> {
self.store_last_prompt_cache_record(None);
let response = self let response = self
.send_with_retry(&request.clone().with_streaming()) .send_with_retry(&request.clone().with_streaming())
.await?; .await?;
@@ -273,30 +227,9 @@ impl AnthropicClient {
parser: SseParser::new(), parser: SseParser::new(),
pending: VecDeque::new(), pending: VecDeque::new(),
done: false, done: false,
cache_tracking: self
.prompt_cache
.as_ref()
.map(|prompt_cache| StreamCacheTracking {
prompt_cache: prompt_cache.clone(),
request: request.clone().with_streaming(),
last_usage: None,
finalized: false,
last_record: self.last_prompt_cache_record.clone(),
}),
}) })
} }
fn store_last_prompt_cache_record(&self, record: Option<PromptCacheRecord>) {
*self
.last_prompt_cache_record()
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner) = record;
}
fn last_prompt_cache_record(&self) -> &Arc<Mutex<Option<PromptCacheRecord>>> {
&self.last_prompt_cache_record
}
pub async fn exchange_oauth_code( pub async fn exchange_oauth_code(
&self, &self,
config: &OAuthConfig, config: &OAuthConfig,
@@ -594,7 +527,6 @@ pub struct MessageStream {
parser: SseParser, parser: SseParser,
pending: VecDeque<StreamEvent>, pending: VecDeque<StreamEvent>,
done: bool, done: bool,
cache_tracking: Option<StreamCacheTracking>,
} }
impl MessageStream { impl MessageStream {
@@ -606,9 +538,6 @@ impl MessageStream {
pub async fn next_event(&mut self) -> Result<Option<StreamEvent>, ApiError> { pub async fn next_event(&mut self) -> Result<Option<StreamEvent>, ApiError> {
loop { loop {
if let Some(event) = self.pending.pop_front() { if let Some(event) = self.pending.pop_front() {
if let Some(cache_tracking) = &mut self.cache_tracking {
cache_tracking.observe(&event);
}
return Ok(Some(event)); return Ok(Some(event));
} }
@@ -616,14 +545,8 @@ impl MessageStream {
let remaining = self.parser.finish()?; let remaining = self.parser.finish()?;
self.pending.extend(remaining); self.pending.extend(remaining);
if let Some(event) = self.pending.pop_front() { if let Some(event) = self.pending.pop_front() {
if let Some(cache_tracking) = &mut self.cache_tracking {
cache_tracking.observe(&event);
}
return Ok(Some(event)); return Ok(Some(event));
} }
if let Some(cache_tracking) = &mut self.cache_tracking {
cache_tracking.finalize();
}
return Ok(None); return Ok(None);
} }
@@ -639,53 +562,6 @@ impl MessageStream {
} }
} }
#[derive(Debug, Clone)]
struct StreamCacheTracking {
prompt_cache: PromptCache,
request: MessageRequest,
last_usage: Option<Usage>,
finalized: bool,
last_record: Arc<Mutex<Option<PromptCacheRecord>>>,
}
impl StreamCacheTracking {
fn observe(&mut self, event: &StreamEvent) {
match event {
StreamEvent::MessageStart(event) => {
self.last_usage = Some(event.message.usage.clone());
}
StreamEvent::MessageDelta(event) => {
self.last_usage = Some(event.usage.clone());
}
StreamEvent::ContentBlockStart(_)
| StreamEvent::ContentBlockDelta(_)
| StreamEvent::ContentBlockStop(_)
| StreamEvent::MessageStop(_) => {}
}
}
fn finalize(&mut self) {
if self.finalized {
return;
}
if let Some(usage) = &self.last_usage {
let record = self.prompt_cache.record_usage(&self.request, usage);
*self
.last_record
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner) = Some(record);
}
self.finalized = true;
}
}
fn prompt_cache_record_from_stats(stats: PromptCacheStats) -> PromptCacheRecord {
PromptCacheRecord {
cache_break: None,
stats,
}
}
async fn expect_success(response: reqwest::Response) -> Result<reqwest::Response, ApiError> { async fn expect_success(response: reqwest::Response) -> Result<reqwest::Response, ApiError> {
let status = response.status(); let status = response.status();
if status.is_success() { if status.is_success() {
@@ -730,7 +606,7 @@ mod tests {
use super::{ALT_REQUEST_ID_HEADER, REQUEST_ID_HEADER}; use super::{ALT_REQUEST_ID_HEADER, REQUEST_ID_HEADER};
use std::io::{Read, Write}; use std::io::{Read, Write};
use std::net::TcpListener; use std::net::TcpListener;
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::{Mutex, OnceLock};
use std::thread; use std::thread;
use std::time::{Duration, SystemTime, UNIX_EPOCH}; use std::time::{Duration, SystemTime, UNIX_EPOCH};
@@ -740,15 +616,19 @@ mod tests {
now_unix_timestamp, oauth_token_is_expired, resolve_saved_oauth_token, now_unix_timestamp, oauth_token_is_expired, resolve_saved_oauth_token,
resolve_startup_auth_source, AnthropicClient, AuthSource, OAuthTokenSet, resolve_startup_auth_source, AnthropicClient, AuthSource, OAuthTokenSet,
}; };
use crate::test_env_lock;
use crate::types::{ContentBlockDelta, MessageRequest}; use crate::types::{ContentBlockDelta, MessageRequest};
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
LOCK.get_or_init(|| Mutex::new(()))
.lock()
.expect("env lock")
}
fn temp_config_home() -> std::path::PathBuf { fn temp_config_home() -> std::path::PathBuf {
static NEXT_ID: AtomicU64 = AtomicU64::new(0);
std::env::temp_dir().join(format!( std::env::temp_dir().join(format!(
"api-oauth-test-{}-{}-{}", "api-oauth-test-{}-{}",
std::process::id(), std::process::id(),
NEXT_ID.fetch_add(1, Ordering::Relaxed),
SystemTime::now() SystemTime::now()
.duration_since(UNIX_EPOCH) .duration_since(UNIX_EPOCH)
.expect("time") .expect("time")
@@ -788,7 +668,7 @@ mod tests {
#[test] #[test]
fn read_api_key_requires_presence() { fn read_api_key_requires_presence() {
let _guard = test_env_lock(); let _guard = env_lock();
std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
std::env::remove_var("ANTHROPIC_API_KEY"); std::env::remove_var("ANTHROPIC_API_KEY");
std::env::remove_var("CLAUDE_CONFIG_HOME"); std::env::remove_var("CLAUDE_CONFIG_HOME");
@@ -798,7 +678,7 @@ mod tests {
#[test] #[test]
fn read_api_key_requires_non_empty_value() { fn read_api_key_requires_non_empty_value() {
let _guard = test_env_lock(); let _guard = env_lock();
std::env::set_var("ANTHROPIC_AUTH_TOKEN", ""); std::env::set_var("ANTHROPIC_AUTH_TOKEN", "");
std::env::remove_var("ANTHROPIC_API_KEY"); std::env::remove_var("ANTHROPIC_API_KEY");
let error = super::read_api_key().expect_err("empty key should error"); let error = super::read_api_key().expect_err("empty key should error");
@@ -808,7 +688,7 @@ mod tests {
#[test] #[test]
fn read_api_key_prefers_api_key_env() { fn read_api_key_prefers_api_key_env() {
let _guard = test_env_lock(); let _guard = env_lock();
std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
std::env::set_var("ANTHROPIC_API_KEY", "legacy-key"); std::env::set_var("ANTHROPIC_API_KEY", "legacy-key");
assert_eq!( assert_eq!(
@@ -821,7 +701,7 @@ mod tests {
#[test] #[test]
fn read_auth_token_reads_auth_token_env() { fn read_auth_token_reads_auth_token_env() {
let _guard = test_env_lock(); let _guard = env_lock();
std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
assert_eq!(super::read_auth_token().as_deref(), Some("auth-token")); assert_eq!(super::read_auth_token().as_deref(), Some("auth-token"));
std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
@@ -841,7 +721,7 @@ mod tests {
#[test] #[test]
fn auth_source_from_env_combines_api_key_and_bearer_token() { fn auth_source_from_env_combines_api_key_and_bearer_token() {
let _guard = test_env_lock(); let _guard = env_lock();
std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
std::env::set_var("ANTHROPIC_API_KEY", "legacy-key"); std::env::set_var("ANTHROPIC_API_KEY", "legacy-key");
let auth = AuthSource::from_env().expect("env auth"); let auth = AuthSource::from_env().expect("env auth");
@@ -853,7 +733,7 @@ mod tests {
#[test] #[test]
fn auth_source_from_saved_oauth_when_env_absent() { fn auth_source_from_saved_oauth_when_env_absent() {
let _guard = test_env_lock(); let _guard = env_lock();
let config_home = temp_config_home(); let config_home = temp_config_home();
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
@@ -892,7 +772,7 @@ mod tests {
#[test] #[test]
fn resolve_saved_oauth_token_refreshes_expired_credentials() { fn resolve_saved_oauth_token_refreshes_expired_credentials() {
let _guard = test_env_lock(); let _guard = env_lock();
let config_home = temp_config_home(); let config_home = temp_config_home();
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
@@ -924,7 +804,7 @@ mod tests {
#[test] #[test]
fn resolve_startup_auth_source_uses_saved_oauth_without_loading_config() { fn resolve_startup_auth_source_uses_saved_oauth_without_loading_config() {
let _guard = test_env_lock(); let _guard = env_lock();
let config_home = temp_config_home(); let config_home = temp_config_home();
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
@@ -948,7 +828,7 @@ mod tests {
#[test] #[test]
fn resolve_startup_auth_source_errors_when_refreshable_token_lacks_config() { fn resolve_startup_auth_source_errors_when_refreshable_token_lacks_config() {
let _guard = test_env_lock(); let _guard = env_lock();
let config_home = temp_config_home(); let config_home = temp_config_home();
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
@@ -980,7 +860,7 @@ mod tests {
#[test] #[test]
fn resolve_saved_oauth_token_preserves_refresh_token_when_refresh_response_omits_it() { fn resolve_saved_oauth_token_preserves_refresh_token_when_refresh_response_omits_it() {
let _guard = test_env_lock(); let _guard = env_lock();
let config_home = temp_config_home(); let config_home = temp_config_home();
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); std::env::remove_var("ANTHROPIC_AUTH_TOKEN");

View File

@@ -1,6 +1,5 @@
mod client; mod client;
mod error; mod error;
mod prompt_cache;
mod sse; mod sse;
mod types; mod types;
@@ -9,10 +8,6 @@ pub use client::{
AnthropicClient, AuthSource, MessageStream, OAuthTokenSet, AnthropicClient, AuthSource, MessageStream, OAuthTokenSet,
}; };
pub use error::ApiError; pub use error::ApiError;
pub use prompt_cache::{
CacheBreakEvent, PromptCache, PromptCacheConfig, PromptCachePaths, PromptCacheRecord,
PromptCacheStats,
};
pub use sse::{parse_frame, SseParser}; pub use sse::{parse_frame, SseParser};
pub use types::{ pub use types::{
ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent, ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent,
@@ -20,11 +15,3 @@ pub use types::{
MessageResponse, MessageStartEvent, MessageStopEvent, OutputContentBlock, StreamEvent, MessageResponse, MessageStartEvent, MessageStopEvent, OutputContentBlock, StreamEvent,
ToolChoice, ToolDefinition, ToolResultContentBlock, Usage, ToolChoice, ToolDefinition, ToolResultContentBlock, Usage,
}; };
#[cfg(test)]
pub(crate) fn test_env_lock() -> std::sync::MutexGuard<'static, ()> {
static LOCK: std::sync::OnceLock<std::sync::Mutex<()>> = std::sync::OnceLock::new();
LOCK.get_or_init(|| std::sync::Mutex::new(()))
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
}

View File

@@ -1,727 +0,0 @@
use std::fs;
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use serde::{Deserialize, Serialize};
use crate::types::{MessageRequest, MessageResponse, Usage};
const DEFAULT_COMPLETION_TTL_SECS: u64 = 30;
const DEFAULT_PROMPT_TTL_SECS: u64 = 5 * 60;
const DEFAULT_BREAK_MIN_DROP: u32 = 2_000;
const MAX_SANITIZED_LENGTH: usize = 80;
const REQUEST_FINGERPRINT_VERSION: u32 = 1;
const REQUEST_FINGERPRINT_PREFIX: &str = "v1";
const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
#[derive(Debug, Clone)]
pub struct PromptCacheConfig {
pub session_id: String,
pub completion_ttl: Duration,
pub prompt_ttl: Duration,
pub cache_break_min_drop: u32,
}
impl PromptCacheConfig {
#[must_use]
pub fn new(session_id: impl Into<String>) -> Self {
Self {
session_id: session_id.into(),
completion_ttl: Duration::from_secs(DEFAULT_COMPLETION_TTL_SECS),
prompt_ttl: Duration::from_secs(DEFAULT_PROMPT_TTL_SECS),
cache_break_min_drop: DEFAULT_BREAK_MIN_DROP,
}
}
}
impl Default for PromptCacheConfig {
fn default() -> Self {
Self::new("default")
}
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PromptCachePaths {
pub root: PathBuf,
pub session_dir: PathBuf,
pub completion_dir: PathBuf,
pub session_state_path: PathBuf,
pub stats_path: PathBuf,
}
impl PromptCachePaths {
#[must_use]
pub fn for_session(session_id: &str) -> Self {
let root = base_cache_root();
let session_dir = root.join(sanitize_path_segment(session_id));
let completion_dir = session_dir.join("completions");
Self {
root,
session_state_path: session_dir.join("session-state.json"),
stats_path: session_dir.join("stats.json"),
session_dir,
completion_dir,
}
}
#[must_use]
pub fn completion_entry_path(&self, request_hash: &str) -> PathBuf {
self.completion_dir.join(format!("{request_hash}.json"))
}
}
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct PromptCacheStats {
pub tracked_requests: u64,
pub completion_cache_hits: u64,
pub completion_cache_misses: u64,
pub completion_cache_writes: u64,
pub expected_invalidations: u64,
pub unexpected_cache_breaks: u64,
pub total_cache_creation_input_tokens: u64,
pub total_cache_read_input_tokens: u64,
pub last_cache_creation_input_tokens: Option<u32>,
pub last_cache_read_input_tokens: Option<u32>,
pub last_request_hash: Option<String>,
pub last_completion_cache_key: Option<String>,
pub last_break_reason: Option<String>,
pub last_cache_source: Option<String>,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CacheBreakEvent {
pub unexpected: bool,
pub reason: String,
pub previous_cache_read_input_tokens: u32,
pub current_cache_read_input_tokens: u32,
pub token_drop: u32,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PromptCacheRecord {
pub cache_break: Option<CacheBreakEvent>,
pub stats: PromptCacheStats,
}
#[derive(Debug, Clone)]
pub struct PromptCache {
inner: Arc<Mutex<PromptCacheInner>>,
}
impl PromptCache {
#[must_use]
pub fn new(session_id: impl Into<String>) -> Self {
Self::with_config(PromptCacheConfig::new(session_id))
}
#[must_use]
pub fn with_config(config: PromptCacheConfig) -> Self {
let paths = PromptCachePaths::for_session(&config.session_id);
let stats = read_json::<PromptCacheStats>(&paths.stats_path).unwrap_or_default();
let previous = read_json::<TrackedPromptState>(&paths.session_state_path);
Self {
inner: Arc::new(Mutex::new(PromptCacheInner {
config,
paths,
stats,
previous,
})),
}
}
#[must_use]
pub fn paths(&self) -> PromptCachePaths {
self.lock().paths.clone()
}
#[must_use]
pub fn stats(&self) -> PromptCacheStats {
self.lock().stats.clone()
}
#[must_use]
pub fn lookup_completion(&self, request: &MessageRequest) -> Option<MessageResponse> {
let request_hash = request_hash_hex(request);
let (paths, ttl) = {
let inner = self.lock();
(inner.paths.clone(), inner.config.completion_ttl)
};
let entry_path = paths.completion_entry_path(&request_hash);
let entry = read_json::<CompletionCacheEntry>(&entry_path);
let Some(entry) = entry else {
let mut inner = self.lock();
inner.stats.completion_cache_misses += 1;
inner.stats.last_completion_cache_key = Some(request_hash);
persist_state(&inner);
return None;
};
if entry.fingerprint_version != current_fingerprint_version() {
let mut inner = self.lock();
inner.stats.completion_cache_misses += 1;
inner.stats.last_completion_cache_key = Some(request_hash.clone());
let _ = fs::remove_file(entry_path);
persist_state(&inner);
return None;
}
let expired = now_unix_secs().saturating_sub(entry.cached_at_unix_secs) >= ttl.as_secs();
let mut inner = self.lock();
inner.stats.last_completion_cache_key = Some(request_hash.clone());
if expired {
inner.stats.completion_cache_misses += 1;
let _ = fs::remove_file(entry_path);
persist_state(&inner);
return None;
}
inner.stats.completion_cache_hits += 1;
apply_usage_to_stats(
&mut inner.stats,
&entry.response.usage,
&request_hash,
"completion-cache",
);
inner.previous = Some(TrackedPromptState::from_usage(
request,
&entry.response.usage,
));
persist_state(&inner);
Some(entry.response)
}
#[must_use]
pub fn record_response(
&self,
request: &MessageRequest,
response: &MessageResponse,
) -> PromptCacheRecord {
self.record_usage_internal(request, &response.usage, Some(response))
}
#[must_use]
pub fn record_usage(&self, request: &MessageRequest, usage: &Usage) -> PromptCacheRecord {
self.record_usage_internal(request, usage, None)
}
fn record_usage_internal(
&self,
request: &MessageRequest,
usage: &Usage,
response: Option<&MessageResponse>,
) -> PromptCacheRecord {
let request_hash = request_hash_hex(request);
let mut inner = self.lock();
let previous = inner.previous.clone();
let current = TrackedPromptState::from_usage(request, usage);
let cache_break = detect_cache_break(&inner.config, previous.as_ref(), &current);
inner.stats.tracked_requests += 1;
apply_usage_to_stats(&mut inner.stats, usage, &request_hash, "api-response");
if let Some(event) = &cache_break {
if event.unexpected {
inner.stats.unexpected_cache_breaks += 1;
} else {
inner.stats.expected_invalidations += 1;
}
inner.stats.last_break_reason = Some(event.reason.clone());
}
inner.previous = Some(current);
if let Some(response) = response {
write_completion_entry(&inner.paths, &request_hash, response);
inner.stats.completion_cache_writes += 1;
}
persist_state(&inner);
PromptCacheRecord {
cache_break,
stats: inner.stats.clone(),
}
}
fn lock(&self) -> std::sync::MutexGuard<'_, PromptCacheInner> {
self.inner
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
}
}
#[derive(Debug)]
struct PromptCacheInner {
config: PromptCacheConfig,
paths: PromptCachePaths,
stats: PromptCacheStats,
previous: Option<TrackedPromptState>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct CompletionCacheEntry {
cached_at_unix_secs: u64,
#[serde(default = "current_fingerprint_version")]
fingerprint_version: u32,
response: MessageResponse,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
struct TrackedPromptState {
observed_at_unix_secs: u64,
#[serde(default = "current_fingerprint_version")]
fingerprint_version: u32,
model_hash: u64,
system_hash: u64,
tools_hash: u64,
messages_hash: u64,
cache_read_input_tokens: u32,
}
impl TrackedPromptState {
fn from_usage(request: &MessageRequest, usage: &Usage) -> Self {
let hashes = RequestFingerprints::from_request(request);
Self {
observed_at_unix_secs: now_unix_secs(),
fingerprint_version: current_fingerprint_version(),
model_hash: hashes.model,
system_hash: hashes.system,
tools_hash: hashes.tools,
messages_hash: hashes.messages,
cache_read_input_tokens: usage.cache_read_input_tokens,
}
}
}
#[derive(Debug, Clone, Copy)]
struct RequestFingerprints {
model: u64,
system: u64,
tools: u64,
messages: u64,
}
impl RequestFingerprints {
fn from_request(request: &MessageRequest) -> Self {
Self {
model: hash_serializable(&request.model),
system: hash_serializable(&request.system),
tools: hash_serializable(&request.tools),
messages: hash_serializable(&request.messages),
}
}
}
fn detect_cache_break(
config: &PromptCacheConfig,
previous: Option<&TrackedPromptState>,
current: &TrackedPromptState,
) -> Option<CacheBreakEvent> {
let previous = previous?;
if previous.fingerprint_version != current.fingerprint_version {
return Some(CacheBreakEvent {
unexpected: false,
reason: format!(
"fingerprint version changed (v{} -> v{})",
previous.fingerprint_version, current.fingerprint_version
),
previous_cache_read_input_tokens: previous.cache_read_input_tokens,
current_cache_read_input_tokens: current.cache_read_input_tokens,
token_drop: previous
.cache_read_input_tokens
.saturating_sub(current.cache_read_input_tokens),
});
}
let token_drop = previous
.cache_read_input_tokens
.saturating_sub(current.cache_read_input_tokens);
if token_drop < config.cache_break_min_drop {
return None;
}
let mut reasons = Vec::new();
if previous.model_hash != current.model_hash {
reasons.push("model changed");
}
if previous.system_hash != current.system_hash {
reasons.push("system prompt changed");
}
if previous.tools_hash != current.tools_hash {
reasons.push("tool definitions changed");
}
if previous.messages_hash != current.messages_hash {
reasons.push("message payload changed");
}
let elapsed = current
.observed_at_unix_secs
.saturating_sub(previous.observed_at_unix_secs);
let (unexpected, reason) = if reasons.is_empty() {
if elapsed > config.prompt_ttl.as_secs() {
(
false,
format!("possible prompt cache TTL expiry after {elapsed}s"),
)
} else {
(
true,
"cache read tokens dropped while prompt fingerprint remained stable".to_string(),
)
}
} else {
(false, reasons.join(", "))
};
Some(CacheBreakEvent {
unexpected,
reason,
previous_cache_read_input_tokens: previous.cache_read_input_tokens,
current_cache_read_input_tokens: current.cache_read_input_tokens,
token_drop,
})
}
fn apply_usage_to_stats(
stats: &mut PromptCacheStats,
usage: &Usage,
request_hash: &str,
source: &str,
) {
stats.total_cache_creation_input_tokens += u64::from(usage.cache_creation_input_tokens);
stats.total_cache_read_input_tokens += u64::from(usage.cache_read_input_tokens);
stats.last_cache_creation_input_tokens = Some(usage.cache_creation_input_tokens);
stats.last_cache_read_input_tokens = Some(usage.cache_read_input_tokens);
stats.last_request_hash = Some(request_hash.to_string());
stats.last_cache_source = Some(source.to_string());
}
fn persist_state(inner: &PromptCacheInner) {
let _ = ensure_cache_dirs(&inner.paths);
let _ = write_json(&inner.paths.stats_path, &inner.stats);
if let Some(previous) = &inner.previous {
let _ = write_json(&inner.paths.session_state_path, previous);
}
}
fn write_completion_entry(
paths: &PromptCachePaths,
request_hash: &str,
response: &MessageResponse,
) {
let _ = ensure_cache_dirs(paths);
let entry = CompletionCacheEntry {
cached_at_unix_secs: now_unix_secs(),
fingerprint_version: current_fingerprint_version(),
response: response.clone(),
};
let _ = write_json(&paths.completion_entry_path(request_hash), &entry);
}
fn ensure_cache_dirs(paths: &PromptCachePaths) -> std::io::Result<()> {
fs::create_dir_all(&paths.completion_dir)
}
fn write_json<T: Serialize>(path: &Path, value: &T) -> std::io::Result<()> {
let json = serde_json::to_vec_pretty(value)
.map_err(|error| std::io::Error::new(std::io::ErrorKind::InvalidData, error))?;
fs::write(path, json)
}
fn read_json<T: for<'de> Deserialize<'de>>(path: &Path) -> Option<T> {
let bytes = fs::read(path).ok()?;
serde_json::from_slice(&bytes).ok()
}
fn request_hash_hex(request: &MessageRequest) -> String {
format!(
"{REQUEST_FINGERPRINT_PREFIX}-{:016x}",
hash_serializable(request)
)
}
fn hash_serializable<T: Serialize>(value: &T) -> u64 {
let json = serde_json::to_vec(value).unwrap_or_default();
stable_hash_bytes(&json)
}
fn sanitize_path_segment(value: &str) -> String {
let sanitized: String = value
.chars()
.map(|ch| if ch.is_ascii_alphanumeric() { ch } else { '-' })
.collect();
if sanitized.len() <= MAX_SANITIZED_LENGTH {
return sanitized;
}
let suffix = format!("-{:x}", hash_string(value));
format!(
"{}{}",
&sanitized[..MAX_SANITIZED_LENGTH.saturating_sub(suffix.len())],
suffix
)
}
fn hash_string(value: &str) -> u64 {
stable_hash_bytes(value.as_bytes())
}
fn base_cache_root() -> PathBuf {
if let Some(config_home) = std::env::var_os("CLAUDE_CONFIG_HOME") {
return PathBuf::from(config_home)
.join("cache")
.join("prompt-cache");
}
if let Some(home) = std::env::var_os("HOME") {
return PathBuf::from(home)
.join(".claude")
.join("cache")
.join("prompt-cache");
}
std::env::temp_dir().join("claude-prompt-cache")
}
fn now_unix_secs() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.map_or(0, |duration| duration.as_secs())
}
const fn current_fingerprint_version() -> u32 {
REQUEST_FINGERPRINT_VERSION
}
fn stable_hash_bytes(bytes: &[u8]) -> u64 {
let mut hash = FNV_OFFSET_BASIS;
for byte in bytes {
hash ^= u64::from(*byte);
hash = hash.wrapping_mul(FNV_PRIME);
}
hash
}
#[cfg(test)]
mod tests {
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use super::{
detect_cache_break, read_json, request_hash_hex, sanitize_path_segment, PromptCache,
PromptCacheConfig, PromptCachePaths, TrackedPromptState, REQUEST_FINGERPRINT_PREFIX,
};
use crate::test_env_lock;
use crate::types::{InputMessage, MessageRequest, MessageResponse, OutputContentBlock, Usage};
#[test]
fn path_builder_sanitizes_session_identifier() {
let paths = PromptCachePaths::for_session("session:/with spaces");
let session_dir = paths
.session_dir
.file_name()
.and_then(|value| value.to_str())
.expect("session dir name");
assert_eq!(session_dir, "session--with-spaces");
assert!(paths.completion_dir.ends_with("completions"));
assert!(paths.stats_path.ends_with("stats.json"));
assert!(paths.session_state_path.ends_with("session-state.json"));
}
#[test]
fn request_fingerprint_drives_unexpected_break_detection() {
let request = sample_request("same");
let previous = TrackedPromptState::from_usage(
&request,
&Usage {
input_tokens: 0,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 6_000,
output_tokens: 0,
},
);
let current = TrackedPromptState::from_usage(
&request,
&Usage {
input_tokens: 0,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 1_000,
output_tokens: 0,
},
);
let event = detect_cache_break(&PromptCacheConfig::default(), Some(&previous), &current)
.expect("break should be detected");
assert!(event.unexpected);
assert!(event.reason.contains("stable"));
}
#[test]
fn changed_prompt_marks_break_as_expected() {
let previous_request = sample_request("first");
let current_request = sample_request("second");
let previous = TrackedPromptState::from_usage(
&previous_request,
&Usage {
input_tokens: 0,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 6_000,
output_tokens: 0,
},
);
let current = TrackedPromptState::from_usage(
&current_request,
&Usage {
input_tokens: 0,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 1_000,
output_tokens: 0,
},
);
let event = detect_cache_break(&PromptCacheConfig::default(), Some(&previous), &current)
.expect("break should be detected");
assert!(!event.unexpected);
assert!(event.reason.contains("message payload changed"));
}
#[test]
fn completion_cache_round_trip_persists_recent_response() {
let _guard = test_env_lock();
let temp_root = std::env::temp_dir().join(format!(
"prompt-cache-test-{}-{}",
std::process::id(),
SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("time")
.as_nanos()
));
std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);
let cache = PromptCache::new("unit-test-session");
let request = sample_request("cache me");
let response = sample_response(42, 12, "cached");
assert!(cache.lookup_completion(&request).is_none());
let record = cache.record_response(&request, &response);
assert!(record.cache_break.is_none());
let cached = cache
.lookup_completion(&request)
.expect("cached response should load");
assert_eq!(cached.content, response.content);
let stats = cache.stats();
assert_eq!(stats.completion_cache_hits, 1);
assert_eq!(stats.completion_cache_misses, 1);
assert_eq!(stats.completion_cache_writes, 1);
let persisted = read_json::<super::PromptCacheStats>(&cache.paths().stats_path)
.expect("stats should persist");
assert_eq!(persisted.completion_cache_hits, 1);
std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
std::env::remove_var("CLAUDE_CONFIG_HOME");
}
#[test]
fn distinct_requests_do_not_collide_in_completion_cache() {
let _guard = test_env_lock();
let temp_root = std::env::temp_dir().join(format!(
"prompt-cache-distinct-{}-{}",
std::process::id(),
SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("time")
.as_nanos()
));
std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);
let cache = PromptCache::new("distinct-request-session");
let first_request = sample_request("first");
let second_request = sample_request("second");
let response = sample_response(42, 12, "cached");
let _ = cache.record_response(&first_request, &response);
assert!(cache.lookup_completion(&second_request).is_none());
std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
std::env::remove_var("CLAUDE_CONFIG_HOME");
}
#[test]
fn expired_completion_entries_are_not_reused() {
let _guard = test_env_lock();
let temp_root = std::env::temp_dir().join(format!(
"prompt-cache-expired-{}-{}",
std::process::id(),
SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("time")
.as_nanos()
));
std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);
let cache = PromptCache::with_config(PromptCacheConfig {
session_id: "expired-session".to_string(),
completion_ttl: Duration::ZERO,
..PromptCacheConfig::default()
});
let request = sample_request("expire me");
let response = sample_response(7, 3, "stale");
let _ = cache.record_response(&request, &response);
assert!(cache.lookup_completion(&request).is_none());
let stats = cache.stats();
assert_eq!(stats.completion_cache_hits, 0);
assert_eq!(stats.completion_cache_misses, 1);
std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
std::env::remove_var("CLAUDE_CONFIG_HOME");
}
#[test]
fn sanitize_path_caps_long_values() {
let long_value = "x".repeat(200);
let sanitized = sanitize_path_segment(&long_value);
assert!(sanitized.len() <= 80);
}
#[test]
fn request_hashes_are_versioned_and_stable() {
let request = sample_request("stable");
let first = request_hash_hex(&request);
let second = request_hash_hex(&request);
assert_eq!(first, second);
assert!(first.starts_with(REQUEST_FINGERPRINT_PREFIX));
}
fn sample_request(text: &str) -> MessageRequest {
MessageRequest {
model: "claude-3-7-sonnet-latest".to_string(),
max_tokens: 64,
messages: vec![InputMessage::user_text(text)],
system: Some("system".to_string()),
tools: None,
tool_choice: None,
stream: false,
}
}
fn sample_response(
cache_read_input_tokens: u32,
output_tokens: u32,
text: &str,
) -> MessageResponse {
MessageResponse {
id: "msg_test".to_string(),
kind: "message".to_string(),
role: "assistant".to_string(),
content: vec![OutputContentBlock::Text {
text: text.to_string(),
}],
model: "claude-3-7-sonnet-latest".to_string(),
stop_reason: Some("end_turn".to_string()),
stop_sequence: None,
usage: Usage {
input_tokens: 10,
cache_creation_input_tokens: 5,
cache_read_input_tokens,
output_tokens,
},
request_id: Some("req_test".to_string()),
}
}
}

View File

@@ -1,25 +1,17 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use std::sync::{Mutex as StdMutex, OnceLock};
use std::time::Duration; use std::time::Duration;
use api::{ use api::{
AnthropicClient, ApiError, ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, AnthropicClient, ApiError, ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent,
InputContentBlock, InputMessage, MessageDeltaEvent, MessageRequest, OutputContentBlock, InputContentBlock, InputMessage, MessageDeltaEvent, MessageRequest, OutputContentBlock,
PromptCache, StreamEvent, ToolChoice, ToolDefinition, StreamEvent, ToolChoice, ToolDefinition,
}; };
use serde_json::json; use serde_json::json;
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::sync::Mutex; use tokio::sync::Mutex;
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
static LOCK: OnceLock<StdMutex<()>> = OnceLock::new();
LOCK.get_or_init(|| StdMutex::new(()))
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
}
#[tokio::test] #[tokio::test]
async fn send_message_posts_json_and_parses_response() { async fn send_message_posts_json_and_parses_response() {
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new())); let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
@@ -53,8 +45,6 @@ async fn send_message_posts_json_and_parses_response() {
assert_eq!(response.id, "msg_test"); assert_eq!(response.id, "msg_test");
assert_eq!(response.total_tokens(), 16); assert_eq!(response.total_tokens(), 16);
assert_eq!(response.request_id.as_deref(), Some("req_body_123")); assert_eq!(response.request_id.as_deref(), Some("req_body_123"));
assert_eq!(response.usage.cache_creation_input_tokens, 0);
assert_eq!(response.usage.cache_read_input_tokens, 0);
assert_eq!( assert_eq!(
response.content, response.content,
vec![OutputContentBlock::Text { vec![OutputContentBlock::Text {
@@ -86,55 +76,11 @@ async fn send_message_posts_json_and_parses_response() {
} }
#[tokio::test] #[tokio::test]
async fn send_message_parses_prompt_cache_token_usage_from_response() {
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
let body = concat!(
"{",
"\"id\":\"msg_cache_tokens\",",
"\"type\":\"message\",",
"\"role\":\"assistant\",",
"\"content\":[{\"type\":\"text\",\"text\":\"Cache tokens\"}],",
"\"model\":\"claude-3-7-sonnet-latest\",",
"\"stop_reason\":\"end_turn\",",
"\"stop_sequence\":null,",
"\"usage\":{\"input_tokens\":12,\"cache_creation_input_tokens\":321,\"cache_read_input_tokens\":654,\"output_tokens\":4}",
"}"
);
let server = spawn_server(
state,
vec![http_response("200 OK", "application/json", body)],
)
.await;
let client = AnthropicClient::new("test-key").with_base_url(server.base_url());
let response = client
.send_message(&sample_request(false))
.await
.expect("request should succeed");
assert_eq!(response.usage.input_tokens, 12);
assert_eq!(response.usage.cache_creation_input_tokens, 321);
assert_eq!(response.usage.cache_read_input_tokens, 654);
assert_eq!(response.usage.output_tokens, 4);
}
#[tokio::test]
#[allow(clippy::await_holding_lock)]
async fn stream_message_parses_sse_events_with_tool_use() { async fn stream_message_parses_sse_events_with_tool_use() {
let _guard = env_lock();
let temp_root = std::env::temp_dir().join(format!(
"api-stream-cache-{}-{}",
std::process::id(),
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("time")
.as_nanos()
));
std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new())); let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
let sse = concat!( let sse = concat!(
"event: message_start\n", "event: message_start\n",
"data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_stream\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":8,\"cache_creation_input_tokens\":13,\"cache_read_input_tokens\":21,\"output_tokens\":0}}}\n\n", "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_stream\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":8,\"output_tokens\":0}}}\n\n",
"event: content_block_start\n", "event: content_block_start\n",
"data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"tool_use\",\"id\":\"toolu_123\",\"name\":\"get_weather\",\"input\":{}}}\n\n", "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"tool_use\",\"id\":\"toolu_123\",\"name\":\"get_weather\",\"input\":{}}}\n\n",
"event: content_block_delta\n", "event: content_block_delta\n",
@@ -142,7 +88,7 @@ async fn stream_message_parses_sse_events_with_tool_use() {
"event: content_block_stop\n", "event: content_block_stop\n",
"data: {\"type\":\"content_block_stop\",\"index\":0}\n\n", "data: {\"type\":\"content_block_stop\",\"index\":0}\n\n",
"event: message_delta\n", "event: message_delta\n",
"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\",\"stop_sequence\":null},\"usage\":{\"input_tokens\":8,\"cache_creation_input_tokens\":34,\"cache_read_input_tokens\":55,\"output_tokens\":1}}\n\n", "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\",\"stop_sequence\":null},\"usage\":{\"input_tokens\":8,\"output_tokens\":1}}\n\n",
"event: message_stop\n", "event: message_stop\n",
"data: {\"type\":\"message_stop\"}\n\n", "data: {\"type\":\"message_stop\"}\n\n",
"data: [DONE]\n\n" "data: [DONE]\n\n"
@@ -160,8 +106,7 @@ async fn stream_message_parses_sse_events_with_tool_use() {
let client = AnthropicClient::new("test-key") let client = AnthropicClient::new("test-key")
.with_auth_token(Some("proxy-token".to_string())) .with_auth_token(Some("proxy-token".to_string()))
.with_base_url(server.base_url()) .with_base_url(server.base_url());
.with_prompt_cache(PromptCache::new("stream-session"));
let mut stream = client let mut stream = client
.stream_message(&sample_request(false)) .stream_message(&sample_request(false))
.await .await
@@ -215,20 +160,6 @@ async fn stream_message_parses_sse_events_with_tool_use() {
let captured = state.lock().await; let captured = state.lock().await;
let request = captured.first().expect("server should capture request"); let request = captured.first().expect("server should capture request");
assert!(request.body.contains("\"stream\":true")); assert!(request.body.contains("\"stream\":true"));
let cache_stats = client
.prompt_cache_stats()
.expect("prompt cache stats should exist");
assert_eq!(cache_stats.tracked_requests, 1);
assert_eq!(cache_stats.last_cache_creation_input_tokens, Some(34));
assert_eq!(cache_stats.last_cache_read_input_tokens, Some(55));
assert_eq!(
cache_stats.last_cache_source.as_deref(),
Some("api-response")
);
std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
std::env::remove_var("CLAUDE_CONFIG_HOME");
} }
#[tokio::test] #[tokio::test]
@@ -312,121 +243,6 @@ async fn surfaces_retry_exhaustion_for_persistent_retryable_errors() {
} }
} }
#[tokio::test]
#[allow(clippy::await_holding_lock)]
async fn send_message_reuses_recent_completion_cache_entries() {
let _guard = env_lock();
let temp_root = std::env::temp_dir().join(format!(
"api-prompt-cache-{}-{}",
std::process::id(),
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("time")
.as_nanos()
));
std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
let server = spawn_server(
state.clone(),
vec![http_response(
"200 OK",
"application/json",
"{\"id\":\"msg_cached\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Cached once\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"cache_creation_input_tokens\":5,\"cache_read_input_tokens\":4000,\"output_tokens\":2}}",
)],
)
.await;
let client = AnthropicClient::new("test-key")
.with_base_url(server.base_url())
.with_prompt_cache(PromptCache::new("integration-session"));
let first = client
.send_message(&sample_request(false))
.await
.expect("first request should succeed");
let second = client
.send_message(&sample_request(false))
.await
.expect("second request should reuse cache");
assert_eq!(first.content, second.content);
assert_eq!(state.lock().await.len(), 1);
let cache_stats = client
.prompt_cache_stats()
.expect("prompt cache stats should exist");
assert_eq!(cache_stats.completion_cache_hits, 1);
assert_eq!(cache_stats.completion_cache_misses, 1);
assert_eq!(cache_stats.completion_cache_writes, 1);
std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
std::env::remove_var("CLAUDE_CONFIG_HOME");
}
#[tokio::test]
#[allow(clippy::await_holding_lock)]
async fn send_message_tracks_unexpected_prompt_cache_breaks() {
let _guard = env_lock();
let temp_root = std::env::temp_dir().join(format!(
"api-prompt-break-{}-{}",
std::process::id(),
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("time")
.as_nanos()
));
std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
let server = spawn_server(
state,
vec![
http_response(
"200 OK",
"application/json",
"{\"id\":\"msg_one\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"One\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"cache_creation_input_tokens\":5,\"cache_read_input_tokens\":6000,\"output_tokens\":2}}",
),
http_response(
"200 OK",
"application/json",
"{\"id\":\"msg_two\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Two\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"cache_creation_input_tokens\":0,\"cache_read_input_tokens\":1000,\"output_tokens\":2}}",
),
],
)
.await;
let request = sample_request(false);
let client = AnthropicClient::new("test-key")
.with_base_url(server.base_url())
.with_prompt_cache(PromptCache::with_config(api::PromptCacheConfig {
session_id: "break-session".to_string(),
completion_ttl: Duration::from_secs(0),
..api::PromptCacheConfig::default()
}));
client
.send_message(&request)
.await
.expect("first response should succeed");
client
.send_message(&request)
.await
.expect("second response should succeed");
let cache_stats = client
.prompt_cache_stats()
.expect("prompt cache stats should exist");
assert_eq!(cache_stats.unexpected_cache_breaks, 1);
assert_eq!(
cache_stats.last_break_reason.as_deref(),
Some("cache read tokens dropped while prompt fingerprint remained stable")
);
std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
std::env::remove_var("CLAUDE_CONFIG_HOME");
}
#[tokio::test] #[tokio::test]
#[ignore = "requires ANTHROPIC_API_KEY and network access"] #[ignore = "requires ANTHROPIC_API_KEY and network access"]
async fn live_stream_smoke_test() { async fn live_stream_smoke_test() {

View File

@@ -42,6 +42,7 @@ pub struct RuntimeFeatureConfig {
oauth: Option<OAuthConfig>, oauth: Option<OAuthConfig>,
model: Option<String>, model: Option<String>,
permission_mode: Option<ResolvedPermissionMode>, permission_mode: Option<ResolvedPermissionMode>,
permission_rules: RuntimePermissionRuleConfig,
sandbox: SandboxConfig, sandbox: SandboxConfig,
} }
@@ -49,6 +50,14 @@ pub struct RuntimeFeatureConfig {
pub struct RuntimeHookConfig { pub struct RuntimeHookConfig {
pre_tool_use: Vec<String>, pre_tool_use: Vec<String>,
post_tool_use: Vec<String>, post_tool_use: Vec<String>,
post_tool_use_failure: Vec<String>,
}
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct RuntimePermissionRuleConfig {
allow: Vec<String>,
deny: Vec<String>,
ask: Vec<String>,
} }
#[derive(Debug, Clone, PartialEq, Eq, Default)] #[derive(Debug, Clone, PartialEq, Eq, Default)]
@@ -235,6 +244,7 @@ impl ConfigLoader {
oauth: parse_optional_oauth_config(&merged_value, "merged settings.oauth")?, oauth: parse_optional_oauth_config(&merged_value, "merged settings.oauth")?,
model: parse_optional_model(&merged_value), model: parse_optional_model(&merged_value),
permission_mode: parse_optional_permission_mode(&merged_value)?, permission_mode: parse_optional_permission_mode(&merged_value)?,
permission_rules: parse_optional_permission_rules(&merged_value)?,
sandbox: parse_optional_sandbox_config(&merged_value)?, sandbox: parse_optional_sandbox_config(&merged_value)?,
}; };
@@ -306,6 +316,11 @@ impl RuntimeConfig {
self.feature_config.permission_mode self.feature_config.permission_mode
} }
#[must_use]
pub fn permission_rules(&self) -> &RuntimePermissionRuleConfig {
&self.feature_config.permission_rules
}
#[must_use] #[must_use]
pub fn sandbox(&self) -> &SandboxConfig { pub fn sandbox(&self) -> &SandboxConfig {
&self.feature_config.sandbox &self.feature_config.sandbox
@@ -344,6 +359,11 @@ impl RuntimeFeatureConfig {
self.permission_mode self.permission_mode
} }
#[must_use]
pub fn permission_rules(&self) -> &RuntimePermissionRuleConfig {
&self.permission_rules
}
#[must_use] #[must_use]
pub fn sandbox(&self) -> &SandboxConfig { pub fn sandbox(&self) -> &SandboxConfig {
&self.sandbox &self.sandbox
@@ -352,10 +372,15 @@ impl RuntimeFeatureConfig {
impl RuntimeHookConfig { impl RuntimeHookConfig {
#[must_use] #[must_use]
pub fn new(pre_tool_use: Vec<String>, post_tool_use: Vec<String>) -> Self { pub fn new(
pre_tool_use: Vec<String>,
post_tool_use: Vec<String>,
post_tool_use_failure: Vec<String>,
) -> Self {
Self { Self {
pre_tool_use, pre_tool_use,
post_tool_use, post_tool_use,
post_tool_use_failure,
} }
} }
@@ -368,6 +393,33 @@ impl RuntimeHookConfig {
pub fn post_tool_use(&self) -> &[String] { pub fn post_tool_use(&self) -> &[String] {
&self.post_tool_use &self.post_tool_use
} }
#[must_use]
pub fn post_tool_use_failure(&self) -> &[String] {
&self.post_tool_use_failure
}
}
impl RuntimePermissionRuleConfig {
#[must_use]
pub fn new(allow: Vec<String>, deny: Vec<String>, ask: Vec<String>) -> Self {
Self { allow, deny, ask }
}
#[must_use]
pub fn allow(&self) -> &[String] {
&self.allow
}
#[must_use]
pub fn deny(&self) -> &[String] {
&self.deny
}
#[must_use]
pub fn ask(&self) -> &[String] {
&self.ask
}
} }
impl McpConfigCollection { impl McpConfigCollection {
@@ -481,6 +533,32 @@ fn parse_optional_hooks_config(root: &JsonValue) -> Result<RuntimeHookConfig, Co
.unwrap_or_default(), .unwrap_or_default(),
post_tool_use: optional_string_array(hooks, "PostToolUse", "merged settings.hooks")? post_tool_use: optional_string_array(hooks, "PostToolUse", "merged settings.hooks")?
.unwrap_or_default(), .unwrap_or_default(),
post_tool_use_failure: optional_string_array(
hooks,
"PostToolUseFailure",
"merged settings.hooks",
)?
.unwrap_or_default(),
})
}
fn parse_optional_permission_rules(
root: &JsonValue,
) -> Result<RuntimePermissionRuleConfig, ConfigError> {
let Some(object) = root.as_object() else {
return Ok(RuntimePermissionRuleConfig::default());
};
let Some(permissions) = object.get("permissions").and_then(JsonValue::as_object) else {
return Ok(RuntimePermissionRuleConfig::default());
};
Ok(RuntimePermissionRuleConfig {
allow: optional_string_array(permissions, "allow", "merged settings.permissions")?
.unwrap_or_default(),
deny: optional_string_array(permissions, "deny", "merged settings.permissions")?
.unwrap_or_default(),
ask: optional_string_array(permissions, "ask", "merged settings.permissions")?
.unwrap_or_default(),
}) })
} }
@@ -843,7 +921,7 @@ mod tests {
.expect("write user compat config"); .expect("write user compat config");
fs::write( fs::write(
home.join("settings.json"), home.join("settings.json"),
r#"{"model":"sonnet","env":{"A2":"1"},"hooks":{"PreToolUse":["base"]},"permissions":{"defaultMode":"plan"}}"#, r#"{"model":"sonnet","env":{"A2":"1"},"hooks":{"PreToolUse":["base"]},"permissions":{"defaultMode":"plan","allow":["Read"],"deny":["Bash(rm -rf)"]}}"#,
) )
.expect("write user settings"); .expect("write user settings");
fs::write( fs::write(
@@ -853,7 +931,7 @@ mod tests {
.expect("write project compat config"); .expect("write project compat config");
fs::write( fs::write(
cwd.join(".claude").join("settings.json"), cwd.join(".claude").join("settings.json"),
r#"{"env":{"C":"3"},"hooks":{"PostToolUse":["project"]},"mcpServers":{"project":{"command":"uvx","args":["project"]}}}"#, r#"{"env":{"C":"3"},"hooks":{"PostToolUse":["project"],"PostToolUseFailure":["project-failure"]},"permissions":{"ask":["Edit"]},"mcpServers":{"project":{"command":"uvx","args":["project"]}}}"#,
) )
.expect("write project settings"); .expect("write project settings");
fs::write( fs::write(
@@ -898,6 +976,16 @@ mod tests {
.contains_key("PostToolUse")); .contains_key("PostToolUse"));
assert_eq!(loaded.hooks().pre_tool_use(), &["base".to_string()]); assert_eq!(loaded.hooks().pre_tool_use(), &["base".to_string()]);
assert_eq!(loaded.hooks().post_tool_use(), &["project".to_string()]); assert_eq!(loaded.hooks().post_tool_use(), &["project".to_string()]);
assert_eq!(
loaded.hooks().post_tool_use_failure(),
&["project-failure".to_string()]
);
assert_eq!(loaded.permission_rules().allow(), &["Read".to_string()]);
assert_eq!(
loaded.permission_rules().deny(),
&["Bash(rm -rf)".to_string()]
);
assert_eq!(loaded.permission_rules().ask(), &["Edit".to_string()]);
assert!(loaded.mcp().get("home").is_some()); assert!(loaded.mcp().get("home").is_some());
assert!(loaded.mcp().get("project").is_some()); assert!(loaded.mcp().get("project").is_some());

View File

@@ -5,8 +5,10 @@ use crate::compact::{
compact_session, estimate_session_tokens, CompactionConfig, CompactionResult, compact_session, estimate_session_tokens, CompactionConfig, CompactionResult,
}; };
use crate::config::RuntimeFeatureConfig; use crate::config::RuntimeFeatureConfig;
use crate::hooks::{HookRunResult, HookRunner}; use crate::hooks::{HookAbortSignal, HookProgressReporter, HookRunResult, HookRunner};
use crate::permissions::{PermissionOutcome, PermissionPolicy, PermissionPrompter}; use crate::permissions::{
PermissionContext, PermissionOutcome, PermissionPolicy, PermissionPrompter,
};
use crate::session::{ContentBlock, ConversationMessage, Session}; use crate::session::{ContentBlock, ConversationMessage, Session};
use crate::usage::{TokenUsage, UsageTracker}; use crate::usage::{TokenUsage, UsageTracker};
@@ -25,19 +27,9 @@ pub enum AssistantEvent {
input: String, input: String,
}, },
Usage(TokenUsage), Usage(TokenUsage),
PromptCache(PromptCacheEvent),
MessageStop, MessageStop,
} }
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PromptCacheEvent {
pub unexpected: bool,
pub reason: String,
pub previous_cache_read_input_tokens: u32,
pub current_cache_read_input_tokens: u32,
pub token_drop: u32,
}
pub trait ApiClient { pub trait ApiClient {
fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError>; fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError>;
} }
@@ -94,7 +86,6 @@ impl std::error::Error for RuntimeError {}
pub struct TurnSummary { pub struct TurnSummary {
pub assistant_messages: Vec<ConversationMessage>, pub assistant_messages: Vec<ConversationMessage>,
pub tool_results: Vec<ConversationMessage>, pub tool_results: Vec<ConversationMessage>,
pub prompt_cache_events: Vec<PromptCacheEvent>,
pub iterations: usize, pub iterations: usize,
pub usage: TokenUsage, pub usage: TokenUsage,
} }
@@ -108,6 +99,8 @@ pub struct ConversationRuntime<C, T> {
max_iterations: usize, max_iterations: usize,
usage_tracker: UsageTracker, usage_tracker: UsageTracker,
hook_runner: HookRunner, hook_runner: HookRunner,
hook_abort_signal: HookAbortSignal,
hook_progress_reporter: Option<Box<dyn HookProgressReporter>>,
} }
impl<C, T> ConversationRuntime<C, T> impl<C, T> ConversationRuntime<C, T>
@@ -129,18 +122,19 @@ where
tool_executor, tool_executor,
permission_policy, permission_policy,
system_prompt, system_prompt,
&RuntimeFeatureConfig::default(), RuntimeFeatureConfig::default(),
) )
} }
#[must_use] #[must_use]
#[allow(clippy::needless_pass_by_value)]
pub fn new_with_features( pub fn new_with_features(
session: Session, session: Session,
api_client: C, api_client: C,
tool_executor: T, tool_executor: T,
permission_policy: PermissionPolicy, permission_policy: PermissionPolicy,
system_prompt: Vec<String>, system_prompt: Vec<String>,
feature_config: &RuntimeFeatureConfig, feature_config: RuntimeFeatureConfig,
) -> Self { ) -> Self {
let usage_tracker = UsageTracker::from_session(&session); let usage_tracker = UsageTracker::from_session(&session);
Self { Self {
@@ -151,7 +145,9 @@ where
system_prompt, system_prompt,
max_iterations: usize::MAX, max_iterations: usize::MAX,
usage_tracker, usage_tracker,
hook_runner: HookRunner::from_feature_config(feature_config), hook_runner: HookRunner::from_feature_config(&feature_config),
hook_abort_signal: HookAbortSignal::default(),
hook_progress_reporter: None,
} }
} }
@@ -161,6 +157,93 @@ where
self self
} }
#[must_use]
pub fn with_hook_abort_signal(mut self, hook_abort_signal: HookAbortSignal) -> Self {
self.hook_abort_signal = hook_abort_signal;
self
}
#[must_use]
pub fn with_hook_progress_reporter(
mut self,
hook_progress_reporter: Box<dyn HookProgressReporter>,
) -> Self {
self.hook_progress_reporter = Some(hook_progress_reporter);
self
}
fn run_pre_tool_use_hook(&mut self, tool_name: &str, input: &str) -> HookRunResult {
if let Some(reporter) = self.hook_progress_reporter.as_mut() {
self.hook_runner.run_pre_tool_use_with_context(
tool_name,
input,
Some(&self.hook_abort_signal),
Some(reporter.as_mut()),
)
} else {
self.hook_runner.run_pre_tool_use_with_context(
tool_name,
input,
Some(&self.hook_abort_signal),
None,
)
}
}
fn run_post_tool_use_hook(
&mut self,
tool_name: &str,
input: &str,
output: &str,
is_error: bool,
) -> HookRunResult {
if let Some(reporter) = self.hook_progress_reporter.as_mut() {
self.hook_runner.run_post_tool_use_with_context(
tool_name,
input,
output,
is_error,
Some(&self.hook_abort_signal),
Some(reporter.as_mut()),
)
} else {
self.hook_runner.run_post_tool_use_with_context(
tool_name,
input,
output,
is_error,
Some(&self.hook_abort_signal),
None,
)
}
}
fn run_post_tool_use_failure_hook(
&mut self,
tool_name: &str,
input: &str,
output: &str,
) -> HookRunResult {
if let Some(reporter) = self.hook_progress_reporter.as_mut() {
self.hook_runner.run_post_tool_use_failure_with_context(
tool_name,
input,
output,
Some(&self.hook_abort_signal),
Some(reporter.as_mut()),
)
} else {
self.hook_runner.run_post_tool_use_failure_with_context(
tool_name,
input,
output,
Some(&self.hook_abort_signal),
None,
)
}
}
#[allow(clippy::too_many_lines)]
pub fn run_turn( pub fn run_turn(
&mut self, &mut self,
user_input: impl Into<String>, user_input: impl Into<String>,
@@ -172,7 +255,6 @@ where
let mut assistant_messages = Vec::new(); let mut assistant_messages = Vec::new();
let mut tool_results = Vec::new(); let mut tool_results = Vec::new();
let mut prompt_cache_events = Vec::new();
let mut iterations = 0; let mut iterations = 0;
loop { loop {
@@ -188,12 +270,10 @@ where
messages: self.session.messages.clone(), messages: self.session.messages.clone(),
}; };
let events = self.api_client.stream(request)?; let events = self.api_client.stream(request)?;
let (assistant_message, usage, turn_prompt_cache_events) = let (assistant_message, usage) = build_assistant_message(events)?;
build_assistant_message(events)?;
if let Some(usage) = usage { if let Some(usage) = usage {
self.usage_tracker.record(usage); self.usage_tracker.record(usage);
} }
prompt_cache_events.extend(turn_prompt_cache_events);
let pending_tool_uses = assistant_message let pending_tool_uses = assistant_message
.blocks .blocks
.iter() .iter()
@@ -213,55 +293,85 @@ where
} }
for (tool_use_id, tool_name, input) in pending_tool_uses { for (tool_use_id, tool_name, input) in pending_tool_uses {
let permission_outcome = if let Some(prompt) = prompter.as_mut() { let pre_hook_result = self.run_pre_tool_use_hook(&tool_name, &input);
self.permission_policy let effective_input = pre_hook_result
.authorize(&tool_name, &input, Some(*prompt)) .updated_input()
.map_or_else(|| input.clone(), ToOwned::to_owned);
let permission_context = PermissionContext::new(
pre_hook_result.permission_override(),
pre_hook_result.permission_reason().map(ToOwned::to_owned),
);
let permission_outcome = if pre_hook_result.is_cancelled() {
PermissionOutcome::Deny {
reason: format_hook_message(
&pre_hook_result,
&format!("PreToolUse hook cancelled tool `{tool_name}`"),
),
}
} else if pre_hook_result.is_denied() {
PermissionOutcome::Deny {
reason: format_hook_message(
&pre_hook_result,
&format!("PreToolUse hook denied tool `{tool_name}`"),
),
}
} else if let Some(prompt) = prompter.as_mut() {
self.permission_policy.authorize_with_context(
&tool_name,
&effective_input,
&permission_context,
Some(*prompt),
)
} else { } else {
self.permission_policy.authorize(&tool_name, &input, None) self.permission_policy.authorize_with_context(
&tool_name,
&effective_input,
&permission_context,
None,
)
}; };
let result_message = match permission_outcome { let result_message = match permission_outcome {
PermissionOutcome::Allow => { PermissionOutcome::Allow => {
let pre_hook_result = self.hook_runner.run_pre_tool_use(&tool_name, &input); let (mut output, mut is_error) =
if pre_hook_result.is_denied() { match self.tool_executor.execute(&tool_name, &effective_input) {
let deny_message = format!("PreToolUse hook denied tool `{tool_name}`"); Ok(output) => (output, false),
ConversationMessage::tool_result( Err(error) => (error.to_string(), true),
tool_use_id, };
tool_name, output = merge_hook_feedback(pre_hook_result.messages(), output, false);
format_hook_message(&pre_hook_result, &deny_message),
true, let post_hook_result = if is_error {
self.run_post_tool_use_failure_hook(
&tool_name,
&effective_input,
&output,
) )
} else { } else {
let (mut output, mut is_error) = self.run_post_tool_use_hook(
match self.tool_executor.execute(&tool_name, &input) { &tool_name,
Ok(output) => (output, false), &effective_input,
Err(error) => (error.to_string(), true), &output,
}; false,
output = merge_hook_feedback(pre_hook_result.messages(), output, false);
let post_hook_result = self
.hook_runner
.run_post_tool_use(&tool_name, &input, &output, is_error);
if post_hook_result.is_denied() {
is_error = true;
}
output = merge_hook_feedback(
post_hook_result.messages(),
output,
post_hook_result.is_denied(),
);
ConversationMessage::tool_result(
tool_use_id,
tool_name,
output,
is_error,
) )
};
if post_hook_result.is_denied() || post_hook_result.is_cancelled() {
is_error = true;
} }
output = merge_hook_feedback(
post_hook_result.messages(),
output,
post_hook_result.is_denied() || post_hook_result.is_cancelled(),
);
ConversationMessage::tool_result(tool_use_id, tool_name, output, is_error)
} }
PermissionOutcome::Deny { reason } => { PermissionOutcome::Deny { reason } => ConversationMessage::tool_result(
ConversationMessage::tool_result(tool_use_id, tool_name, reason, true) tool_use_id,
} tool_name,
merge_hook_feedback(pre_hook_result.messages(), reason, true),
true,
),
}; };
self.session.messages.push(result_message.clone()); self.session.messages.push(result_message.clone());
tool_results.push(result_message); tool_results.push(result_message);
@@ -271,7 +381,6 @@ where
Ok(TurnSummary { Ok(TurnSummary {
assistant_messages, assistant_messages,
tool_results, tool_results,
prompt_cache_events,
iterations, iterations,
usage: self.usage_tracker.cumulative_usage(), usage: self.usage_tracker.cumulative_usage(),
}) })
@@ -305,17 +414,9 @@ where
fn build_assistant_message( fn build_assistant_message(
events: Vec<AssistantEvent>, events: Vec<AssistantEvent>,
) -> Result< ) -> Result<(ConversationMessage, Option<TokenUsage>), RuntimeError> {
(
ConversationMessage,
Option<TokenUsage>,
Vec<PromptCacheEvent>,
),
RuntimeError,
> {
let mut text = String::new(); let mut text = String::new();
let mut blocks = Vec::new(); let mut blocks = Vec::new();
let mut prompt_cache_events = Vec::new();
let mut finished = false; let mut finished = false;
let mut usage = None; let mut usage = None;
@@ -327,7 +428,6 @@ fn build_assistant_message(
blocks.push(ContentBlock::ToolUse { id, name, input }); blocks.push(ContentBlock::ToolUse { id, name, input });
} }
AssistantEvent::Usage(value) => usage = Some(value), AssistantEvent::Usage(value) => usage = Some(value),
AssistantEvent::PromptCache(event) => prompt_cache_events.push(event),
AssistantEvent::MessageStop => { AssistantEvent::MessageStop => {
finished = true; finished = true;
} }
@@ -348,7 +448,6 @@ fn build_assistant_message(
Ok(( Ok((
ConversationMessage::assistant_with_usage(blocks, usage), ConversationMessage::assistant_with_usage(blocks, usage),
usage, usage,
prompt_cache_events,
)) ))
} }
@@ -421,7 +520,7 @@ impl ToolExecutor for StaticToolExecutor {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::{ use super::{
ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, PromptCacheEvent, RuntimeError, ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, RuntimeError,
StaticToolExecutor, StaticToolExecutor,
}; };
use crate::compact::CompactionConfig; use crate::compact::CompactionConfig;
@@ -478,15 +577,6 @@ mod tests {
cache_creation_input_tokens: 1, cache_creation_input_tokens: 1,
cache_read_input_tokens: 3, cache_read_input_tokens: 3,
}), }),
AssistantEvent::PromptCache(PromptCacheEvent {
unexpected: true,
reason:
"cache read tokens dropped while prompt fingerprint remained stable"
.to_string(),
previous_cache_read_input_tokens: 6_000,
current_cache_read_input_tokens: 1_000,
token_drop: 5_000,
}),
AssistantEvent::MessageStop, AssistantEvent::MessageStop,
]) ])
} }
@@ -540,10 +630,8 @@ mod tests {
assert_eq!(summary.iterations, 2); assert_eq!(summary.iterations, 2);
assert_eq!(summary.assistant_messages.len(), 2); assert_eq!(summary.assistant_messages.len(), 2);
assert_eq!(summary.tool_results.len(), 1); assert_eq!(summary.tool_results.len(), 1);
assert_eq!(summary.prompt_cache_events.len(), 1);
assert_eq!(runtime.session().messages.len(), 4); assert_eq!(runtime.session().messages.len(), 4);
assert_eq!(summary.usage.output_tokens, 10); assert_eq!(summary.usage.output_tokens, 10);
assert!(summary.prompt_cache_events[0].unexpected);
assert!(matches!( assert!(matches!(
runtime.session().messages[1].blocks[1], runtime.session().messages[1].blocks[1],
ContentBlock::ToolUse { .. } ContentBlock::ToolUse { .. }
@@ -645,9 +733,10 @@ mod tests {
}), }),
PermissionPolicy::new(PermissionMode::DangerFullAccess), PermissionPolicy::new(PermissionMode::DangerFullAccess),
vec!["system".to_string()], vec!["system".to_string()],
&RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new( RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new(
vec![shell_snippet("printf 'blocked by hook'; exit 2")], vec![shell_snippet("printf 'blocked by hook'; exit 2")],
Vec::new(), Vec::new(),
Vec::new(),
)), )),
); );
@@ -711,9 +800,10 @@ mod tests {
StaticToolExecutor::new().register("add", |_input| Ok("4".to_string())), StaticToolExecutor::new().register("add", |_input| Ok("4".to_string())),
PermissionPolicy::new(PermissionMode::DangerFullAccess), PermissionPolicy::new(PermissionMode::DangerFullAccess),
vec!["system".to_string()], vec!["system".to_string()],
&RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new( RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new(
vec![shell_snippet("printf 'pre hook ran'")], vec![shell_snippet("printf 'pre hook ran'")],
vec![shell_snippet("printf 'post hook ran'")], vec![shell_snippet("printf 'post hook ran'")],
Vec::new(),
)), )),
); );

View File

@@ -1,29 +1,90 @@
use std::ffi::OsStr; use std::ffi::OsStr;
use std::process::Command; use std::io::Write;
use std::process::{Command, Stdio};
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
use std::thread;
use std::time::Duration;
use serde_json::json; use serde_json::{json, Value};
use crate::config::{RuntimeFeatureConfig, RuntimeHookConfig}; use crate::config::{RuntimeFeatureConfig, RuntimeHookConfig};
use crate::permissions::PermissionOverride;
pub type HookPermissionDecision = PermissionOverride;
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HookEvent { pub enum HookEvent {
PreToolUse, PreToolUse,
PostToolUse, PostToolUse,
PostToolUseFailure,
} }
impl HookEvent { impl HookEvent {
fn as_str(self) -> &'static str { #[must_use]
pub fn as_str(self) -> &'static str {
match self { match self {
Self::PreToolUse => "PreToolUse", Self::PreToolUse => "PreToolUse",
Self::PostToolUse => "PostToolUse", Self::PostToolUse => "PostToolUse",
Self::PostToolUseFailure => "PostToolUseFailure",
} }
} }
} }
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HookProgressEvent {
Started {
event: HookEvent,
tool_name: String,
command: String,
},
Completed {
event: HookEvent,
tool_name: String,
command: String,
},
Cancelled {
event: HookEvent,
tool_name: String,
command: String,
},
}
pub trait HookProgressReporter {
fn on_event(&mut self, event: &HookProgressEvent);
}
#[derive(Debug, Clone, Default)]
pub struct HookAbortSignal {
aborted: Arc<AtomicBool>,
}
impl HookAbortSignal {
#[must_use]
pub fn new() -> Self {
Self::default()
}
pub fn abort(&self) {
self.aborted.store(true, Ordering::SeqCst);
}
#[must_use]
pub fn is_aborted(&self) -> bool {
self.aborted.load(Ordering::SeqCst)
}
}
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub struct HookRunResult { pub struct HookRunResult {
denied: bool, denied: bool,
cancelled: bool,
messages: Vec<String>, messages: Vec<String>,
permission_override: Option<PermissionOverride>,
permission_reason: Option<String>,
updated_input: Option<String>,
} }
impl HookRunResult { impl HookRunResult {
@@ -31,7 +92,11 @@ impl HookRunResult {
pub fn allow(messages: Vec<String>) -> Self { pub fn allow(messages: Vec<String>) -> Self {
Self { Self {
denied: false, denied: false,
cancelled: false,
messages, messages,
permission_override: None,
permission_reason: None,
updated_input: None,
} }
} }
@@ -40,10 +105,40 @@ impl HookRunResult {
self.denied self.denied
} }
#[must_use]
pub fn is_cancelled(&self) -> bool {
self.cancelled
}
#[must_use] #[must_use]
pub fn messages(&self) -> &[String] { pub fn messages(&self) -> &[String] {
&self.messages &self.messages
} }
#[must_use]
pub fn permission_override(&self) -> Option<PermissionOverride> {
self.permission_override
}
#[must_use]
pub fn permission_decision(&self) -> Option<HookPermissionDecision> {
self.permission_override
}
#[must_use]
pub fn permission_reason(&self) -> Option<&str> {
self.permission_reason.as_deref()
}
#[must_use]
pub fn updated_input(&self) -> Option<&str> {
self.updated_input.as_deref()
}
#[must_use]
pub fn updated_input_json(&self) -> Option<&str> {
self.updated_input()
}
} }
#[derive(Debug, Clone, PartialEq, Eq, Default)] #[derive(Debug, Clone, PartialEq, Eq, Default)]
@@ -64,6 +159,17 @@ impl HookRunner {
#[must_use] #[must_use]
pub fn run_pre_tool_use(&self, tool_name: &str, tool_input: &str) -> HookRunResult { pub fn run_pre_tool_use(&self, tool_name: &str, tool_input: &str) -> HookRunResult {
self.run_pre_tool_use_with_context(tool_name, tool_input, None, None)
}
#[must_use]
pub fn run_pre_tool_use_with_context(
&self,
tool_name: &str,
tool_input: &str,
abort_signal: Option<&HookAbortSignal>,
reporter: Option<&mut dyn HookProgressReporter>,
) -> HookRunResult {
Self::run_commands( Self::run_commands(
HookEvent::PreToolUse, HookEvent::PreToolUse,
self.config.pre_tool_use(), self.config.pre_tool_use(),
@@ -71,9 +177,21 @@ impl HookRunner {
tool_input, tool_input,
None, None,
false, false,
abort_signal,
reporter,
) )
} }
#[must_use]
pub fn run_pre_tool_use_with_signal(
&self,
tool_name: &str,
tool_input: &str,
abort_signal: Option<&HookAbortSignal>,
) -> HookRunResult {
self.run_pre_tool_use_with_context(tool_name, tool_input, abort_signal, None)
}
#[must_use] #[must_use]
pub fn run_post_tool_use( pub fn run_post_tool_use(
&self, &self,
@@ -81,6 +199,26 @@ impl HookRunner {
tool_input: &str, tool_input: &str,
tool_output: &str, tool_output: &str,
is_error: bool, is_error: bool,
) -> HookRunResult {
self.run_post_tool_use_with_context(
tool_name,
tool_input,
tool_output,
is_error,
None,
None,
)
}
#[must_use]
pub fn run_post_tool_use_with_context(
&self,
tool_name: &str,
tool_input: &str,
tool_output: &str,
is_error: bool,
abort_signal: Option<&HookAbortSignal>,
reporter: Option<&mut dyn HookProgressReporter>,
) -> HookRunResult { ) -> HookRunResult {
Self::run_commands( Self::run_commands(
HookEvent::PostToolUse, HookEvent::PostToolUse,
@@ -89,9 +227,79 @@ impl HookRunner {
tool_input, tool_input,
Some(tool_output), Some(tool_output),
is_error, is_error,
abort_signal,
reporter,
) )
} }
#[must_use]
pub fn run_post_tool_use_with_signal(
&self,
tool_name: &str,
tool_input: &str,
tool_output: &str,
is_error: bool,
abort_signal: Option<&HookAbortSignal>,
) -> HookRunResult {
self.run_post_tool_use_with_context(
tool_name,
tool_input,
tool_output,
is_error,
abort_signal,
None,
)
}
#[must_use]
pub fn run_post_tool_use_failure(
&self,
tool_name: &str,
tool_input: &str,
tool_error: &str,
) -> HookRunResult {
self.run_post_tool_use_failure_with_context(tool_name, tool_input, tool_error, None, None)
}
#[must_use]
pub fn run_post_tool_use_failure_with_context(
&self,
tool_name: &str,
tool_input: &str,
tool_error: &str,
abort_signal: Option<&HookAbortSignal>,
reporter: Option<&mut dyn HookProgressReporter>,
) -> HookRunResult {
Self::run_commands(
HookEvent::PostToolUseFailure,
self.config.post_tool_use_failure(),
tool_name,
tool_input,
Some(tool_error),
true,
abort_signal,
reporter,
)
}
#[must_use]
pub fn run_post_tool_use_failure_with_signal(
&self,
tool_name: &str,
tool_input: &str,
tool_error: &str,
abort_signal: Option<&HookAbortSignal>,
) -> HookRunResult {
self.run_post_tool_use_failure_with_context(
tool_name,
tool_input,
tool_error,
abort_signal,
None,
)
}
#[allow(clippy::too_many_arguments)]
fn run_commands( fn run_commands(
event: HookEvent, event: HookEvent,
commands: &[String], commands: &[String],
@@ -99,24 +307,39 @@ impl HookRunner {
tool_input: &str, tool_input: &str,
tool_output: Option<&str>, tool_output: Option<&str>,
is_error: bool, is_error: bool,
abort_signal: Option<&HookAbortSignal>,
mut reporter: Option<&mut dyn HookProgressReporter>,
) -> HookRunResult { ) -> HookRunResult {
if commands.is_empty() { if commands.is_empty() {
return HookRunResult::allow(Vec::new()); return HookRunResult::allow(Vec::new());
} }
let payload = json!({ if abort_signal.is_some_and(HookAbortSignal::is_aborted) {
"hook_event_name": event.as_str(), return HookRunResult {
"tool_name": tool_name, denied: false,
"tool_input": parse_tool_input(tool_input), cancelled: true,
"tool_input_json": tool_input, messages: vec![format!(
"tool_output": tool_output, "{} hook cancelled before execution",
"tool_result_is_error": is_error, event.as_str()
}) )],
.to_string(); permission_override: None,
permission_reason: None,
updated_input: None,
};
}
let mut messages = Vec::new(); let payload = hook_payload(event, tool_name, tool_input, tool_output, is_error).to_string();
let mut result = HookRunResult::allow(Vec::new());
for command in commands { for command in commands {
if let Some(reporter) = reporter.as_deref_mut() {
reporter.on_event(&HookProgressEvent::Started {
event,
tool_name: tool_name.to_string(),
command: command.clone(),
});
}
match Self::run_command( match Self::run_command(
command, command,
event, event,
@@ -125,29 +348,59 @@ impl HookRunner {
tool_output, tool_output,
is_error, is_error,
&payload, &payload,
abort_signal,
) { ) {
HookCommandOutcome::Allow { message } => { HookCommandOutcome::Allow { parsed } => {
if let Some(message) = message { if let Some(reporter) = reporter.as_deref_mut() {
messages.push(message); reporter.on_event(&HookProgressEvent::Completed {
event,
tool_name: tool_name.to_string(),
command: command.clone(),
});
} }
merge_parsed_hook_output(&mut result, parsed);
} }
HookCommandOutcome::Deny { message } => { HookCommandOutcome::Deny { parsed } => {
let message = message.unwrap_or_else(|| { if let Some(reporter) = reporter.as_deref_mut() {
format!("{} hook denied tool `{tool_name}`", event.as_str()) reporter.on_event(&HookProgressEvent::Completed {
}); event,
messages.push(message); tool_name: tool_name.to_string(),
return HookRunResult { command: command.clone(),
denied: true, });
messages, }
}; merge_parsed_hook_output(&mut result, parsed);
result.denied = true;
return result;
}
HookCommandOutcome::Warn { message } => {
if let Some(reporter) = reporter.as_deref_mut() {
reporter.on_event(&HookProgressEvent::Completed {
event,
tool_name: tool_name.to_string(),
command: command.clone(),
});
}
result.messages.push(message);
}
HookCommandOutcome::Cancelled { message } => {
if let Some(reporter) = reporter.as_deref_mut() {
reporter.on_event(&HookProgressEvent::Cancelled {
event,
tool_name: tool_name.to_string(),
command: command.clone(),
});
}
result.cancelled = true;
result.messages.push(message);
return result;
} }
HookCommandOutcome::Warn { message } => messages.push(message),
} }
} }
HookRunResult::allow(messages) result
} }
#[allow(clippy::too_many_arguments)]
fn run_command( fn run_command(
command: &str, command: &str,
event: HookEvent, event: HookEvent,
@@ -156,11 +409,12 @@ impl HookRunner {
tool_output: Option<&str>, tool_output: Option<&str>,
is_error: bool, is_error: bool,
payload: &str, payload: &str,
abort_signal: Option<&HookAbortSignal>,
) -> HookCommandOutcome { ) -> HookCommandOutcome {
let mut child = shell_command(command); let mut child = shell_command(command);
child.stdin(std::process::Stdio::piped()); child.stdin(Stdio::piped());
child.stdout(std::process::Stdio::piped()); child.stdout(Stdio::piped());
child.stderr(std::process::Stdio::piped()); child.stderr(Stdio::piped());
child.env("HOOK_EVENT", event.as_str()); child.env("HOOK_EVENT", event.as_str());
child.env("HOOK_TOOL_NAME", tool_name); child.env("HOOK_TOOL_NAME", tool_name);
child.env("HOOK_TOOL_INPUT", tool_input); child.env("HOOK_TOOL_INPUT", tool_input);
@@ -169,19 +423,30 @@ impl HookRunner {
child.env("HOOK_TOOL_OUTPUT", tool_output); child.env("HOOK_TOOL_OUTPUT", tool_output);
} }
match child.output_with_stdin(payload.as_bytes()) { match child.output_with_stdin(payload.as_bytes(), abort_signal) {
Ok(output) => { Ok(CommandExecution::Finished(output)) => {
let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string(); let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string();
let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string(); let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string();
let message = (!stdout.is_empty()).then_some(stdout); let parsed = parse_hook_output(&stdout);
match output.status.code() { match output.status.code() {
Some(0) => HookCommandOutcome::Allow { message }, Some(0) => {
Some(2) => HookCommandOutcome::Deny { message }, if parsed.deny {
HookCommandOutcome::Deny { parsed }
} else {
HookCommandOutcome::Allow { parsed }
}
}
Some(2) => HookCommandOutcome::Deny {
parsed: parsed.with_fallback_message(format!(
"{} hook denied tool `{tool_name}`",
event.as_str()
)),
},
Some(code) => HookCommandOutcome::Warn { Some(code) => HookCommandOutcome::Warn {
message: format_hook_warning( message: format_hook_warning(
command, command,
code, code,
message.as_deref(), parsed.primary_message(),
stderr.as_str(), stderr.as_str(),
), ),
}, },
@@ -193,6 +458,12 @@ impl HookRunner {
}, },
} }
} }
Ok(CommandExecution::Cancelled) => HookCommandOutcome::Cancelled {
message: format!(
"{} hook `{command}` cancelled while handling `{tool_name}`",
event.as_str()
),
},
Err(error) => HookCommandOutcome::Warn { Err(error) => HookCommandOutcome::Warn {
message: format!( message: format!(
"{} hook `{command}` failed to start for `{tool_name}`: {error}", "{} hook `{command}` failed to start for `{tool_name}`: {error}",
@@ -204,12 +475,131 @@ impl HookRunner {
} }
enum HookCommandOutcome { enum HookCommandOutcome {
Allow { message: Option<String> }, Allow { parsed: ParsedHookOutput },
Deny { message: Option<String> }, Deny { parsed: ParsedHookOutput },
Warn { message: String }, Warn { message: String },
Cancelled { message: String },
} }
fn parse_tool_input(tool_input: &str) -> serde_json::Value { #[derive(Debug, Clone, PartialEq, Eq, Default)]
struct ParsedHookOutput {
messages: Vec<String>,
deny: bool,
permission_override: Option<PermissionOverride>,
permission_reason: Option<String>,
updated_input: Option<String>,
}
impl ParsedHookOutput {
fn with_fallback_message(mut self, fallback: String) -> Self {
if self.messages.is_empty() {
self.messages.push(fallback);
}
self
}
fn primary_message(&self) -> Option<&str> {
self.messages.first().map(String::as_str)
}
}
fn merge_parsed_hook_output(target: &mut HookRunResult, parsed: ParsedHookOutput) {
target.messages.extend(parsed.messages);
if parsed.permission_override.is_some() {
target.permission_override = parsed.permission_override;
}
if parsed.permission_reason.is_some() {
target.permission_reason = parsed.permission_reason;
}
if parsed.updated_input.is_some() {
target.updated_input = parsed.updated_input;
}
}
fn parse_hook_output(stdout: &str) -> ParsedHookOutput {
if stdout.is_empty() {
return ParsedHookOutput::default();
}
let Ok(Value::Object(root)) = serde_json::from_str::<Value>(stdout) else {
return ParsedHookOutput {
messages: vec![stdout.to_string()],
..ParsedHookOutput::default()
};
};
let mut parsed = ParsedHookOutput::default();
if let Some(message) = root.get("systemMessage").and_then(Value::as_str) {
parsed.messages.push(message.to_string());
}
if let Some(message) = root.get("reason").and_then(Value::as_str) {
parsed.messages.push(message.to_string());
}
if root.get("continue").and_then(Value::as_bool) == Some(false)
|| root.get("decision").and_then(Value::as_str) == Some("block")
{
parsed.deny = true;
}
if let Some(Value::Object(specific)) = root.get("hookSpecificOutput") {
if let Some(Value::String(additional_context)) = specific.get("additionalContext") {
parsed.messages.push(additional_context.clone());
}
if let Some(decision) = specific.get("permissionDecision").and_then(Value::as_str) {
parsed.permission_override = match decision {
"allow" => Some(PermissionOverride::Allow),
"deny" => Some(PermissionOverride::Deny),
"ask" => Some(PermissionOverride::Ask),
_ => None,
};
}
if let Some(reason) = specific
.get("permissionDecisionReason")
.and_then(Value::as_str)
{
parsed.permission_reason = Some(reason.to_string());
}
if let Some(updated_input) = specific.get("updatedInput") {
parsed.updated_input = serde_json::to_string(updated_input).ok();
}
}
if parsed.messages.is_empty() {
parsed.messages.push(stdout.to_string());
}
parsed
}
fn hook_payload(
event: HookEvent,
tool_name: &str,
tool_input: &str,
tool_output: Option<&str>,
is_error: bool,
) -> Value {
match event {
HookEvent::PostToolUseFailure => json!({
"hook_event_name": event.as_str(),
"tool_name": tool_name,
"tool_input": parse_tool_input(tool_input),
"tool_input_json": tool_input,
"tool_error": tool_output,
"tool_result_is_error": true,
}),
_ => json!({
"hook_event_name": event.as_str(),
"tool_name": tool_name,
"tool_input": parse_tool_input(tool_input),
"tool_input_json": tool_input,
"tool_output": tool_output,
"tool_result_is_error": is_error,
}),
}
}
fn parse_tool_input(tool_input: &str) -> Value {
serde_json::from_str(tool_input).unwrap_or_else(|_| json!({ "raw": tool_input })) serde_json::from_str(tool_input).unwrap_or_else(|_| json!({ "raw": tool_input }))
} }
@@ -253,17 +643,17 @@ impl CommandWithStdin {
Self { command } Self { command }
} }
fn stdin(&mut self, cfg: std::process::Stdio) -> &mut Self { fn stdin(&mut self, cfg: Stdio) -> &mut Self {
self.command.stdin(cfg); self.command.stdin(cfg);
self self
} }
fn stdout(&mut self, cfg: std::process::Stdio) -> &mut Self { fn stdout(&mut self, cfg: Stdio) -> &mut Self {
self.command.stdout(cfg); self.command.stdout(cfg);
self self
} }
fn stderr(&mut self, cfg: std::process::Stdio) -> &mut Self { fn stderr(&mut self, cfg: Stdio) -> &mut Self {
self.command.stderr(cfg); self.command.stderr(cfg);
self self
} }
@@ -277,26 +667,64 @@ impl CommandWithStdin {
self self
} }
fn output_with_stdin(&mut self, stdin: &[u8]) -> std::io::Result<std::process::Output> { fn output_with_stdin(
&mut self,
stdin: &[u8],
abort_signal: Option<&HookAbortSignal>,
) -> std::io::Result<CommandExecution> {
let mut child = self.command.spawn()?; let mut child = self.command.spawn()?;
if let Some(mut child_stdin) = child.stdin.take() { if let Some(mut child_stdin) = child.stdin.take() {
use std::io::Write;
child_stdin.write_all(stdin)?; child_stdin.write_all(stdin)?;
} }
child.wait_with_output()
loop {
if abort_signal.is_some_and(HookAbortSignal::is_aborted) {
let _ = child.kill();
let _ = child.wait_with_output();
return Ok(CommandExecution::Cancelled);
}
match child.try_wait()? {
Some(_) => return child.wait_with_output().map(CommandExecution::Finished),
None => thread::sleep(Duration::from_millis(20)),
}
}
} }
} }
enum CommandExecution {
Finished(std::process::Output),
Cancelled,
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::{HookRunResult, HookRunner}; use std::thread;
use std::time::Duration;
use super::{
HookAbortSignal, HookEvent, HookProgressEvent, HookProgressReporter, HookRunResult,
HookRunner,
};
use crate::config::{RuntimeFeatureConfig, RuntimeHookConfig}; use crate::config::{RuntimeFeatureConfig, RuntimeHookConfig};
use crate::permissions::PermissionOverride;
struct RecordingReporter {
events: Vec<HookProgressEvent>,
}
impl HookProgressReporter for RecordingReporter {
fn on_event(&mut self, event: &HookProgressEvent) {
self.events.push(event.clone());
}
}
#[test] #[test]
fn allows_exit_code_zero_and_captures_stdout() { fn allows_exit_code_zero_and_captures_stdout() {
let runner = HookRunner::new(RuntimeHookConfig::new( let runner = HookRunner::new(RuntimeHookConfig::new(
vec![shell_snippet("printf 'pre ok'")], vec![shell_snippet("printf 'pre ok'")],
Vec::new(), Vec::new(),
Vec::new(),
)); ));
let result = runner.run_pre_tool_use("Read", r#"{"path":"README.md"}"#); let result = runner.run_pre_tool_use("Read", r#"{"path":"README.md"}"#);
@@ -309,6 +737,7 @@ mod tests {
let runner = HookRunner::new(RuntimeHookConfig::new( let runner = HookRunner::new(RuntimeHookConfig::new(
vec![shell_snippet("printf 'blocked by hook'; exit 2")], vec![shell_snippet("printf 'blocked by hook'; exit 2")],
Vec::new(), Vec::new(),
Vec::new(),
)); ));
let result = runner.run_pre_tool_use("Bash", r#"{"command":"pwd"}"#); let result = runner.run_pre_tool_use("Bash", r#"{"command":"pwd"}"#);
@@ -323,6 +752,7 @@ mod tests {
RuntimeHookConfig::new( RuntimeHookConfig::new(
vec![shell_snippet("printf 'warning hook'; exit 1")], vec![shell_snippet("printf 'warning hook'; exit 1")],
Vec::new(), Vec::new(),
Vec::new(),
), ),
)); ));
@@ -335,6 +765,82 @@ mod tests {
.any(|message| message.contains("allowing tool execution to continue"))); .any(|message| message.contains("allowing tool execution to continue")));
} }
#[test]
fn parses_pre_hook_permission_override_and_updated_input() {
let runner = HookRunner::new(RuntimeHookConfig::new(
vec![shell_snippet(
r#"printf '%s' '{"systemMessage":"updated","hookSpecificOutput":{"permissionDecision":"allow","permissionDecisionReason":"hook ok","updatedInput":{"command":"git status"}}}'"#,
)],
Vec::new(),
Vec::new(),
));
let result = runner.run_pre_tool_use("bash", r#"{"command":"pwd"}"#);
assert_eq!(
result.permission_override(),
Some(PermissionOverride::Allow)
);
assert_eq!(result.permission_reason(), Some("hook ok"));
assert_eq!(result.updated_input(), Some(r#"{"command":"git status"}"#));
assert!(result.messages().iter().any(|message| message == "updated"));
}
#[test]
fn runs_post_tool_use_failure_hooks() {
let runner = HookRunner::new(RuntimeHookConfig::new(
Vec::new(),
Vec::new(),
vec![shell_snippet("printf 'failure hook ran'")],
));
let result =
runner.run_post_tool_use_failure("bash", r#"{"command":"false"}"#, "command failed");
assert!(!result.is_denied());
assert_eq!(result.messages(), &["failure hook ran".to_string()]);
}
#[test]
fn abort_signal_cancels_long_running_hook_and_reports_progress() {
let runner = HookRunner::new(RuntimeHookConfig::new(
vec![shell_snippet("sleep 5")],
Vec::new(),
Vec::new(),
));
let abort_signal = HookAbortSignal::new();
let abort_signal_for_thread = abort_signal.clone();
let mut reporter = RecordingReporter { events: Vec::new() };
thread::spawn(move || {
thread::sleep(Duration::from_millis(100));
abort_signal_for_thread.abort();
});
let result = runner.run_pre_tool_use_with_context(
"bash",
r#"{"command":"sleep 5"}"#,
Some(&abort_signal),
Some(&mut reporter),
);
assert!(result.is_cancelled());
assert!(reporter.events.iter().any(|event| matches!(
event,
HookProgressEvent::Started {
event: HookEvent::PreToolUse,
..
}
)));
assert!(reporter.events.iter().any(|event| matches!(
event,
HookProgressEvent::Cancelled {
event: HookEvent::PreToolUse,
..
}
)));
}
#[cfg(windows)] #[cfg(windows)]
fn shell_snippet(script: &str) -> String { fn shell_snippet(script: &str) -> String {
script.replace('\'', "\"") script.replace('\'', "\"")

View File

@@ -28,18 +28,20 @@ pub use config::{
McpConfigCollection, McpOAuthConfig, McpRemoteServerConfig, McpSdkServerConfig, McpConfigCollection, McpOAuthConfig, McpRemoteServerConfig, McpSdkServerConfig,
McpServerConfig, McpStdioServerConfig, McpTransport, McpWebSocketServerConfig, OAuthConfig, McpServerConfig, McpStdioServerConfig, McpTransport, McpWebSocketServerConfig, OAuthConfig,
ResolvedPermissionMode, RuntimeConfig, RuntimeFeatureConfig, RuntimeHookConfig, ResolvedPermissionMode, RuntimeConfig, RuntimeFeatureConfig, RuntimeHookConfig,
ScopedMcpServerConfig, CLAUDE_CODE_SETTINGS_SCHEMA_NAME, RuntimePermissionRuleConfig, ScopedMcpServerConfig, CLAUDE_CODE_SETTINGS_SCHEMA_NAME,
}; };
pub use conversation::{ pub use conversation::{
ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, PromptCacheEvent, RuntimeError, ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, RuntimeError, StaticToolExecutor,
StaticToolExecutor, ToolError, ToolExecutor, TurnSummary, ToolError, ToolExecutor, TurnSummary,
}; };
pub use file_ops::{ pub use file_ops::{
edit_file, glob_search, grep_search, read_file, write_file, EditFileOutput, GlobSearchOutput, edit_file, glob_search, grep_search, read_file, write_file, EditFileOutput, GlobSearchOutput,
GrepSearchInput, GrepSearchOutput, ReadFileOutput, StructuredPatchHunk, TextFilePayload, GrepSearchInput, GrepSearchOutput, ReadFileOutput, StructuredPatchHunk, TextFilePayload,
WriteFileOutput, WriteFileOutput,
}; };
pub use hooks::{HookEvent, HookRunResult, HookRunner}; pub use hooks::{
HookAbortSignal, HookEvent, HookProgressEvent, HookProgressReporter, HookRunResult, HookRunner,
};
pub use mcp::{ pub use mcp::{
mcp_server_signature, mcp_tool_name, mcp_tool_prefix, normalize_name_for_mcp, mcp_server_signature, mcp_tool_name, mcp_tool_prefix, normalize_name_for_mcp,
scoped_mcp_config_hash, unwrap_ccr_proxy_url, scoped_mcp_config_hash, unwrap_ccr_proxy_url,
@@ -64,8 +66,8 @@ pub use oauth::{
PkceChallengeMethod, PkceCodePair, PkceChallengeMethod, PkceCodePair,
}; };
pub use permissions::{ pub use permissions::{
PermissionMode, PermissionOutcome, PermissionPolicy, PermissionPromptDecision, PermissionContext, PermissionMode, PermissionOutcome, PermissionOverride, PermissionPolicy,
PermissionPrompter, PermissionRequest, PermissionPromptDecision, PermissionPrompter, PermissionRequest,
}; };
pub use prompt::{ pub use prompt::{
load_system_prompt, prepend_bullets, ContextFile, ProjectContext, PromptBuildError, load_system_prompt, prepend_bullets, ContextFile, ProjectContext, PromptBuildError,

View File

@@ -1,5 +1,9 @@
use std::collections::BTreeMap; use std::collections::BTreeMap;
use serde_json::Value;
use crate::config::RuntimePermissionRuleConfig;
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum PermissionMode { pub enum PermissionMode {
ReadOnly, ReadOnly,
@@ -22,12 +26,49 @@ impl PermissionMode {
} }
} }
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PermissionOverride {
Allow,
Deny,
Ask,
}
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct PermissionContext {
override_decision: Option<PermissionOverride>,
override_reason: Option<String>,
}
impl PermissionContext {
#[must_use]
pub fn new(
override_decision: Option<PermissionOverride>,
override_reason: Option<String>,
) -> Self {
Self {
override_decision,
override_reason,
}
}
#[must_use]
pub fn override_decision(&self) -> Option<PermissionOverride> {
self.override_decision
}
#[must_use]
pub fn override_reason(&self) -> Option<&str> {
self.override_reason.as_deref()
}
}
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub struct PermissionRequest { pub struct PermissionRequest {
pub tool_name: String, pub tool_name: String,
pub input: String, pub input: String,
pub current_mode: PermissionMode, pub current_mode: PermissionMode,
pub required_mode: PermissionMode, pub required_mode: PermissionMode,
pub reason: Option<String>,
} }
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
@@ -50,6 +91,9 @@ pub enum PermissionOutcome {
pub struct PermissionPolicy { pub struct PermissionPolicy {
active_mode: PermissionMode, active_mode: PermissionMode,
tool_requirements: BTreeMap<String, PermissionMode>, tool_requirements: BTreeMap<String, PermissionMode>,
allow_rules: Vec<PermissionRule>,
deny_rules: Vec<PermissionRule>,
ask_rules: Vec<PermissionRule>,
} }
impl PermissionPolicy { impl PermissionPolicy {
@@ -58,6 +102,9 @@ impl PermissionPolicy {
Self { Self {
active_mode, active_mode,
tool_requirements: BTreeMap::new(), tool_requirements: BTreeMap::new(),
allow_rules: Vec::new(),
deny_rules: Vec::new(),
ask_rules: Vec::new(),
} }
} }
@@ -72,6 +119,26 @@ impl PermissionPolicy {
self self
} }
#[must_use]
pub fn with_permission_rules(mut self, config: &RuntimePermissionRuleConfig) -> Self {
self.allow_rules = config
.allow()
.iter()
.map(|rule| PermissionRule::parse(rule))
.collect();
self.deny_rules = config
.deny()
.iter()
.map(|rule| PermissionRule::parse(rule))
.collect();
self.ask_rules = config
.ask()
.iter()
.map(|rule| PermissionRule::parse(rule))
.collect();
self
}
#[must_use] #[must_use]
pub fn active_mode(&self) -> PermissionMode { pub fn active_mode(&self) -> PermissionMode {
self.active_mode self.active_mode
@@ -90,38 +157,121 @@ impl PermissionPolicy {
&self, &self,
tool_name: &str, tool_name: &str,
input: &str, input: &str,
mut prompter: Option<&mut dyn PermissionPrompter>, prompter: Option<&mut dyn PermissionPrompter>,
) -> PermissionOutcome { ) -> PermissionOutcome {
let current_mode = self.active_mode(); self.authorize_with_context(tool_name, input, &PermissionContext::default(), prompter)
let required_mode = self.required_mode_for(tool_name); }
if current_mode == PermissionMode::Allow || current_mode >= required_mode {
return PermissionOutcome::Allow; #[must_use]
#[allow(clippy::too_many_lines)]
pub fn authorize_with_context(
&self,
tool_name: &str,
input: &str,
context: &PermissionContext,
prompter: Option<&mut dyn PermissionPrompter>,
) -> PermissionOutcome {
if let Some(rule) = Self::find_matching_rule(&self.deny_rules, tool_name, input) {
return PermissionOutcome::Deny {
reason: format!(
"Permission to use {tool_name} has been denied by rule '{}'",
rule.raw
),
};
} }
let request = PermissionRequest { let current_mode = self.active_mode();
tool_name: tool_name.to_string(), let required_mode = self.required_mode_for(tool_name);
input: input.to_string(), let ask_rule = Self::find_matching_rule(&self.ask_rules, tool_name, input);
current_mode, let allow_rule = Self::find_matching_rule(&self.allow_rules, tool_name, input);
required_mode,
}; match context.override_decision() {
Some(PermissionOverride::Deny) => {
return PermissionOutcome::Deny {
reason: context.override_reason().map_or_else(
|| format!("tool '{tool_name}' denied by hook"),
ToOwned::to_owned,
),
};
}
Some(PermissionOverride::Ask) => {
let reason = context.override_reason().map_or_else(
|| format!("tool '{tool_name}' requires approval due to hook guidance"),
ToOwned::to_owned,
);
return Self::prompt_or_deny(
tool_name,
input,
current_mode,
required_mode,
Some(reason),
prompter,
);
}
Some(PermissionOverride::Allow) => {
if let Some(rule) = ask_rule {
let reason = format!(
"tool '{tool_name}' requires approval due to ask rule '{}'",
rule.raw
);
return Self::prompt_or_deny(
tool_name,
input,
current_mode,
required_mode,
Some(reason),
prompter,
);
}
if allow_rule.is_some()
|| current_mode == PermissionMode::Allow
|| current_mode >= required_mode
{
return PermissionOutcome::Allow;
}
}
None => {}
}
if let Some(rule) = ask_rule {
let reason = format!(
"tool '{tool_name}' requires approval due to ask rule '{}'",
rule.raw
);
return Self::prompt_or_deny(
tool_name,
input,
current_mode,
required_mode,
Some(reason),
prompter,
);
}
if allow_rule.is_some()
|| current_mode == PermissionMode::Allow
|| current_mode >= required_mode
{
return PermissionOutcome::Allow;
}
if current_mode == PermissionMode::Prompt if current_mode == PermissionMode::Prompt
|| (current_mode == PermissionMode::WorkspaceWrite || (current_mode == PermissionMode::WorkspaceWrite
&& required_mode == PermissionMode::DangerFullAccess) && required_mode == PermissionMode::DangerFullAccess)
{ {
return match prompter.as_mut() { let reason = Some(format!(
Some(prompter) => match prompter.decide(&request) { "tool '{tool_name}' requires approval to escalate from {} to {}",
PermissionPromptDecision::Allow => PermissionOutcome::Allow, current_mode.as_str(),
PermissionPromptDecision::Deny { reason } => PermissionOutcome::Deny { reason }, required_mode.as_str()
}, ));
None => PermissionOutcome::Deny { return Self::prompt_or_deny(
reason: format!( tool_name,
"tool '{tool_name}' requires approval to escalate from {} to {}", input,
current_mode.as_str(), current_mode,
required_mode.as_str() required_mode,
), reason,
}, prompter,
}; );
} }
PermissionOutcome::Deny { PermissionOutcome::Deny {
@@ -132,14 +282,191 @@ impl PermissionPolicy {
), ),
} }
} }
fn prompt_or_deny(
tool_name: &str,
input: &str,
current_mode: PermissionMode,
required_mode: PermissionMode,
reason: Option<String>,
mut prompter: Option<&mut dyn PermissionPrompter>,
) -> PermissionOutcome {
let request = PermissionRequest {
tool_name: tool_name.to_string(),
input: input.to_string(),
current_mode,
required_mode,
reason: reason.clone(),
};
match prompter.as_mut() {
Some(prompter) => match prompter.decide(&request) {
PermissionPromptDecision::Allow => PermissionOutcome::Allow,
PermissionPromptDecision::Deny { reason } => PermissionOutcome::Deny { reason },
},
None => PermissionOutcome::Deny {
reason: reason.unwrap_or_else(|| {
format!(
"tool '{tool_name}' requires approval to run while mode is {}",
current_mode.as_str()
)
}),
},
}
}
fn find_matching_rule<'a>(
rules: &'a [PermissionRule],
tool_name: &str,
input: &str,
) -> Option<&'a PermissionRule> {
rules.iter().find(|rule| rule.matches(tool_name, input))
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
struct PermissionRule {
raw: String,
tool_name: String,
matcher: PermissionRuleMatcher,
}
#[derive(Debug, Clone, PartialEq, Eq)]
enum PermissionRuleMatcher {
Any,
Exact(String),
Prefix(String),
}
impl PermissionRule {
fn parse(raw: &str) -> Self {
let trimmed = raw.trim();
let open = find_first_unescaped(trimmed, '(');
let close = find_last_unescaped(trimmed, ')');
if let (Some(open), Some(close)) = (open, close) {
if close == trimmed.len() - 1 && open < close {
let tool_name = trimmed[..open].trim();
let content = &trimmed[open + 1..close];
if !tool_name.is_empty() {
let matcher = parse_rule_matcher(content);
return Self {
raw: trimmed.to_string(),
tool_name: tool_name.to_string(),
matcher,
};
}
}
}
Self {
raw: trimmed.to_string(),
tool_name: trimmed.to_string(),
matcher: PermissionRuleMatcher::Any,
}
}
fn matches(&self, tool_name: &str, input: &str) -> bool {
if self.tool_name != tool_name {
return false;
}
match &self.matcher {
PermissionRuleMatcher::Any => true,
PermissionRuleMatcher::Exact(expected) => {
extract_permission_subject(input).is_some_and(|candidate| candidate == *expected)
}
PermissionRuleMatcher::Prefix(prefix) => extract_permission_subject(input)
.is_some_and(|candidate| candidate.starts_with(prefix)),
}
}
}
fn parse_rule_matcher(content: &str) -> PermissionRuleMatcher {
let unescaped = unescape_rule_content(content.trim());
if unescaped.is_empty() || unescaped == "*" {
PermissionRuleMatcher::Any
} else if let Some(prefix) = unescaped.strip_suffix(":*") {
PermissionRuleMatcher::Prefix(prefix.to_string())
} else {
PermissionRuleMatcher::Exact(unescaped)
}
}
fn unescape_rule_content(content: &str) -> String {
content
.replace(r"\(", "(")
.replace(r"\)", ")")
.replace(r"\\", r"\")
}
fn find_first_unescaped(value: &str, needle: char) -> Option<usize> {
let mut escaped = false;
for (idx, ch) in value.char_indices() {
if ch == '\\' {
escaped = !escaped;
continue;
}
if ch == needle && !escaped {
return Some(idx);
}
escaped = false;
}
None
}
fn find_last_unescaped(value: &str, needle: char) -> Option<usize> {
let chars = value.char_indices().collect::<Vec<_>>();
for (pos, (idx, ch)) in chars.iter().enumerate().rev() {
if *ch != needle {
continue;
}
let mut backslashes = 0;
for (_, prev) in chars[..pos].iter().rev() {
if *prev == '\\' {
backslashes += 1;
} else {
break;
}
}
if backslashes % 2 == 0 {
return Some(*idx);
}
}
None
}
fn extract_permission_subject(input: &str) -> Option<String> {
let parsed = serde_json::from_str::<Value>(input).ok();
if let Some(Value::Object(object)) = parsed {
for key in [
"command",
"path",
"file_path",
"filePath",
"notebook_path",
"notebookPath",
"url",
"pattern",
"code",
"message",
] {
if let Some(value) = object.get(key).and_then(Value::as_str) {
return Some(value.to_string());
}
}
}
(!input.trim().is_empty()).then(|| input.to_string())
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::{ use super::{
PermissionMode, PermissionOutcome, PermissionPolicy, PermissionPromptDecision, PermissionContext, PermissionMode, PermissionOutcome, PermissionOverride, PermissionPolicy,
PermissionPrompter, PermissionRequest, PermissionPromptDecision, PermissionPrompter, PermissionRequest,
}; };
use crate::config::RuntimePermissionRuleConfig;
struct RecordingPrompter { struct RecordingPrompter {
seen: Vec<PermissionRequest>, seen: Vec<PermissionRequest>,
@@ -229,4 +556,120 @@ mod tests {
PermissionOutcome::Deny { reason } if reason == "not now" PermissionOutcome::Deny { reason } if reason == "not now"
)); ));
} }
#[test]
fn applies_rule_based_denials_and_allows() {
let rules = RuntimePermissionRuleConfig::new(
vec!["bash(git:*)".to_string()],
vec!["bash(rm -rf:*)".to_string()],
Vec::new(),
);
let policy = PermissionPolicy::new(PermissionMode::ReadOnly)
.with_tool_requirement("bash", PermissionMode::DangerFullAccess)
.with_permission_rules(&rules);
assert_eq!(
policy.authorize("bash", r#"{"command":"git status"}"#, None),
PermissionOutcome::Allow
);
assert!(matches!(
policy.authorize("bash", r#"{"command":"rm -rf /tmp/x"}"#, None),
PermissionOutcome::Deny { reason } if reason.contains("denied by rule")
));
}
#[test]
fn ask_rules_force_prompt_even_when_mode_allows() {
let rules = RuntimePermissionRuleConfig::new(
Vec::new(),
Vec::new(),
vec!["bash(git:*)".to_string()],
);
let policy = PermissionPolicy::new(PermissionMode::DangerFullAccess)
.with_tool_requirement("bash", PermissionMode::DangerFullAccess)
.with_permission_rules(&rules);
let mut prompter = RecordingPrompter {
seen: Vec::new(),
allow: true,
};
let outcome = policy.authorize("bash", r#"{"command":"git status"}"#, Some(&mut prompter));
assert_eq!(outcome, PermissionOutcome::Allow);
assert_eq!(prompter.seen.len(), 1);
assert!(prompter.seen[0]
.reason
.as_deref()
.is_some_and(|reason| reason.contains("ask rule")));
}
#[test]
fn hook_allow_still_respects_ask_rules() {
let rules = RuntimePermissionRuleConfig::new(
Vec::new(),
Vec::new(),
vec!["bash(git:*)".to_string()],
);
let policy = PermissionPolicy::new(PermissionMode::ReadOnly)
.with_tool_requirement("bash", PermissionMode::DangerFullAccess)
.with_permission_rules(&rules);
let context = PermissionContext::new(
Some(PermissionOverride::Allow),
Some("hook approved".to_string()),
);
let mut prompter = RecordingPrompter {
seen: Vec::new(),
allow: true,
};
let outcome = policy.authorize_with_context(
"bash",
r#"{"command":"git status"}"#,
&context,
Some(&mut prompter),
);
assert_eq!(outcome, PermissionOutcome::Allow);
assert_eq!(prompter.seen.len(), 1);
}
#[test]
fn hook_deny_short_circuits_permission_flow() {
let policy = PermissionPolicy::new(PermissionMode::DangerFullAccess)
.with_tool_requirement("bash", PermissionMode::DangerFullAccess);
let context = PermissionContext::new(
Some(PermissionOverride::Deny),
Some("blocked by hook".to_string()),
);
assert_eq!(
policy.authorize_with_context("bash", "{}", &context, None),
PermissionOutcome::Deny {
reason: "blocked by hook".to_string(),
}
);
}
#[test]
fn hook_ask_forces_prompt() {
let policy = PermissionPolicy::new(PermissionMode::DangerFullAccess)
.with_tool_requirement("bash", PermissionMode::DangerFullAccess);
let context = PermissionContext::new(
Some(PermissionOverride::Ask),
Some("hook requested confirmation".to_string()),
);
let mut prompter = RecordingPrompter {
seen: Vec::new(),
allow: true,
};
let outcome = policy.authorize_with_context("bash", "{}", &context, Some(&mut prompter));
assert_eq!(outcome, PermissionOutcome::Allow);
assert_eq!(prompter.seen.len(), 1);
assert_eq!(
prompter.seen[0].reason.as_deref(),
Some("hook requested confirmation")
);
}
} }

View File

@@ -19,7 +19,7 @@ rustyline = "15"
runtime = { path = "../runtime" } runtime = { path = "../runtime" }
serde_json = "1" serde_json = "1"
syntect = "5" syntect = "5"
tokio = { version = "1", features = ["rt-multi-thread", "time"] } tokio = { version = "1", features = ["rt-multi-thread", "signal", "time"] }
tools = { path = "../tools" } tools = { path = "../tools" }
[lints] [lints]

View File

@@ -9,13 +9,14 @@ use std::io::{self, Read, Write};
use std::net::TcpListener; use std::net::TcpListener;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::process::Command; use std::process::Command;
use std::sync::mpsc::{self, Receiver, Sender};
use std::thread::{self, JoinHandle};
use std::time::{SystemTime, UNIX_EPOCH}; use std::time::{SystemTime, UNIX_EPOCH};
use api::{ use api::{
resolve_startup_auth_source, AnthropicClient, AuthSource, ContentBlockDelta, InputContentBlock, resolve_startup_auth_source, AnthropicClient, AuthSource, ContentBlockDelta, InputContentBlock,
InputMessage, MessageRequest, MessageResponse, OutputContentBlock, PromptCache, InputMessage, MessageRequest, MessageResponse, OutputContentBlock,
PromptCacheRecord, StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, ToolResultContentBlock,
ToolResultContentBlock,
}; };
use commands::{ use commands::{
@@ -29,8 +30,8 @@ use runtime::{
parse_oauth_callback_request_target, save_oauth_credentials, ApiClient, ApiRequest, parse_oauth_callback_request_target, save_oauth_credentials, ApiClient, ApiRequest,
AssistantEvent, CompactionConfig, ConfigLoader, ConfigSource, ContentBlock, AssistantEvent, CompactionConfig, ConfigLoader, ConfigSource, ContentBlock,
ConversationMessage, ConversationRuntime, MessageRole, OAuthAuthorizationRequest, OAuthConfig, ConversationMessage, ConversationRuntime, MessageRole, OAuthAuthorizationRequest, OAuthConfig,
OAuthTokenExchangeRequest, PermissionMode, PermissionPolicy, ProjectContext, PromptCacheEvent, OAuthTokenExchangeRequest, PermissionMode, PermissionPolicy, ProjectContext, RuntimeError,
RuntimeError, Session, TokenUsage, ToolError, ToolExecutor, UsageTracker, Session, TokenUsage, ToolError, ToolExecutor, UsageTracker,
}; };
use serde_json::json; use serde_json::json;
use tools::{execute_tool, mvp_tool_specs, ToolSpec}; use tools::{execute_tool, mvp_tool_specs, ToolSpec};
@@ -985,6 +986,61 @@ struct LiveCli {
session: SessionHandle, session: SessionHandle,
} }
struct HookAbortMonitor {
stop_tx: Option<Sender<()>>,
join_handle: Option<JoinHandle<()>>,
}
impl HookAbortMonitor {
fn spawn(abort_signal: runtime::HookAbortSignal) -> Self {
Self::spawn_with_waiter(abort_signal, move |stop_rx, abort_signal| {
let Ok(runtime) = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
else {
return;
};
runtime.block_on(async move {
let wait_for_stop = tokio::task::spawn_blocking(move || {
let _ = stop_rx.recv();
});
tokio::select! {
result = tokio::signal::ctrl_c() => {
if result.is_ok() {
abort_signal.abort();
}
}
_ = wait_for_stop => {}
}
});
})
}
fn spawn_with_waiter<F>(abort_signal: runtime::HookAbortSignal, wait_for_interrupt: F) -> Self
where
F: FnOnce(Receiver<()>, runtime::HookAbortSignal) + Send + 'static,
{
let (stop_tx, stop_rx) = mpsc::channel();
let join_handle = thread::spawn(move || wait_for_interrupt(stop_rx, abort_signal));
Self {
stop_tx: Some(stop_tx),
join_handle: Some(join_handle),
}
}
fn stop(mut self) {
if let Some(stop_tx) = self.stop_tx.take() {
let _ = stop_tx.send(());
}
if let Some(join_handle) = self.join_handle.take() {
let _ = join_handle.join();
}
}
}
impl LiveCli { impl LiveCli {
fn new( fn new(
model: String, model: String,
@@ -996,7 +1052,6 @@ impl LiveCli {
let session = create_managed_session_handle()?; let session = create_managed_session_handle()?;
let runtime = build_runtime( let runtime = build_runtime(
Session::new(), Session::new(),
session.id.clone(),
model.clone(), model.clone(),
system_prompt.clone(), system_prompt.clone(),
enable_tools, enable_tools,
@@ -1041,7 +1096,34 @@ impl LiveCli {
) )
} }
fn prepare_turn_runtime(
&self,
emit_output: bool,
) -> Result<
(
ConversationRuntime<AnthropicRuntimeClient, CliToolExecutor>,
HookAbortMonitor,
),
Box<dyn std::error::Error>,
> {
let hook_abort_signal = runtime::HookAbortSignal::new();
let runtime = build_runtime(
self.runtime.session().clone(),
self.model.clone(),
self.system_prompt.clone(),
true,
emit_output,
self.allowed_tools.clone(),
self.permission_mode,
)?
.with_hook_abort_signal(hook_abort_signal.clone());
let hook_abort_monitor = HookAbortMonitor::spawn(hook_abort_signal);
Ok((runtime, hook_abort_monitor))
}
fn run_turn(&mut self, input: &str) -> Result<(), Box<dyn std::error::Error>> { fn run_turn(&mut self, input: &str) -> Result<(), Box<dyn std::error::Error>> {
let (mut runtime, hook_abort_monitor) = self.prepare_turn_runtime(true)?;
let mut spinner = Spinner::new(); let mut spinner = Spinner::new();
let mut stdout = io::stdout(); let mut stdout = io::stdout();
spinner.tick( spinner.tick(
@@ -1050,16 +1132,17 @@ impl LiveCli {
&mut stdout, &mut stdout,
)?; )?;
let mut permission_prompter = CliPermissionPrompter::new(self.permission_mode); let mut permission_prompter = CliPermissionPrompter::new(self.permission_mode);
let result = self.runtime.run_turn(input, Some(&mut permission_prompter)); let result = runtime.run_turn(input, Some(&mut permission_prompter));
hook_abort_monitor.stop();
self.runtime = runtime;
match result { match result {
Ok(summary) => { Ok(_) => {
spinner.finish( spinner.finish(
"✨ Done", "✨ Done",
TerminalRenderer::new().color_theme(), TerminalRenderer::new().color_theme(),
&mut stdout, &mut stdout,
)?; )?;
println!(); println!();
print_prompt_cache_events(&summary);
self.persist_session()?; self.persist_session()?;
Ok(()) Ok(())
} }
@@ -1086,19 +1169,11 @@ impl LiveCli {
} }
fn run_prompt_json(&mut self, input: &str) -> Result<(), Box<dyn std::error::Error>> { fn run_prompt_json(&mut self, input: &str) -> Result<(), Box<dyn std::error::Error>> {
let session = self.runtime.session().clone(); let (mut runtime, hook_abort_monitor) = self.prepare_turn_runtime(false)?;
let mut runtime = build_runtime(
session,
self.session.id.clone(),
self.model.clone(),
self.system_prompt.clone(),
true,
false,
self.allowed_tools.clone(),
self.permission_mode,
)?;
let mut permission_prompter = CliPermissionPrompter::new(self.permission_mode); let mut permission_prompter = CliPermissionPrompter::new(self.permission_mode);
let summary = runtime.run_turn(input, Some(&mut permission_prompter))?; let result = runtime.run_turn(input, Some(&mut permission_prompter));
hook_abort_monitor.stop();
let summary = result?;
self.runtime = runtime; self.runtime = runtime;
self.persist_session()?; self.persist_session()?;
println!( println!(
@@ -1109,7 +1184,6 @@ impl LiveCli {
"iterations": summary.iterations, "iterations": summary.iterations,
"tool_uses": collect_tool_uses(&summary), "tool_uses": collect_tool_uses(&summary),
"tool_results": collect_tool_results(&summary), "tool_results": collect_tool_results(&summary),
"prompt_cache_events": collect_prompt_cache_events(&summary),
"usage": { "usage": {
"input_tokens": summary.usage.input_tokens, "input_tokens": summary.usage.input_tokens,
"output_tokens": summary.usage.output_tokens, "output_tokens": summary.usage.output_tokens,
@@ -1237,7 +1311,6 @@ impl LiveCli {
let message_count = session.messages.len(); let message_count = session.messages.len();
self.runtime = build_runtime( self.runtime = build_runtime(
session, session,
self.session.id.clone(),
model.clone(), model.clone(),
self.system_prompt.clone(), self.system_prompt.clone(),
true, true,
@@ -1281,7 +1354,6 @@ impl LiveCli {
self.permission_mode = permission_mode_from_label(normalized); self.permission_mode = permission_mode_from_label(normalized);
self.runtime = build_runtime( self.runtime = build_runtime(
session, session,
self.session.id.clone(),
self.model.clone(), self.model.clone(),
self.system_prompt.clone(), self.system_prompt.clone(),
true, true,
@@ -1307,7 +1379,6 @@ impl LiveCli {
self.session = create_managed_session_handle()?; self.session = create_managed_session_handle()?;
self.runtime = build_runtime( self.runtime = build_runtime(
Session::new(), Session::new(),
self.session.id.clone(),
self.model.clone(), self.model.clone(),
self.system_prompt.clone(), self.system_prompt.clone(),
true, true,
@@ -1343,7 +1414,6 @@ impl LiveCli {
let message_count = session.messages.len(); let message_count = session.messages.len();
self.runtime = build_runtime( self.runtime = build_runtime(
session, session,
handle.id.clone(),
self.model.clone(), self.model.clone(),
self.system_prompt.clone(), self.system_prompt.clone(),
true, true,
@@ -1416,7 +1486,6 @@ impl LiveCli {
let message_count = session.messages.len(); let message_count = session.messages.len();
self.runtime = build_runtime( self.runtime = build_runtime(
session, session,
handle.id.clone(),
self.model.clone(), self.model.clone(),
self.system_prompt.clone(), self.system_prompt.clone(),
true, true,
@@ -1447,7 +1516,6 @@ impl LiveCli {
let skipped = removed == 0; let skipped = removed == 0;
self.runtime = build_runtime( self.runtime = build_runtime(
result.compacted_session, result.compacted_session,
self.session.id.clone(),
self.model.clone(), self.model.clone(),
self.system_prompt.clone(), self.system_prompt.clone(),
true, true,
@@ -1923,10 +1991,8 @@ fn build_runtime_feature_config(
.clone()) .clone())
} }
#[allow(clippy::too_many_arguments)]
fn build_runtime( fn build_runtime(
session: Session, session: Session,
session_id: String,
model: String, model: String,
system_prompt: Vec<String>, system_prompt: Vec<String>,
enable_tools: bool, enable_tools: bool,
@@ -1935,20 +2001,52 @@ fn build_runtime(
permission_mode: PermissionMode, permission_mode: PermissionMode,
) -> Result<ConversationRuntime<AnthropicRuntimeClient, CliToolExecutor>, Box<dyn std::error::Error>> ) -> Result<ConversationRuntime<AnthropicRuntimeClient, CliToolExecutor>, Box<dyn std::error::Error>>
{ {
Ok(ConversationRuntime::new_with_features( let feature_config = build_runtime_feature_config()?;
let mut runtime = ConversationRuntime::new_with_features(
session, session,
AnthropicRuntimeClient::new( AnthropicRuntimeClient::new(model, enable_tools, emit_output, allowed_tools.clone())?,
model,
enable_tools,
emit_output,
allowed_tools.clone(),
session_id,
)?,
CliToolExecutor::new(allowed_tools, emit_output), CliToolExecutor::new(allowed_tools, emit_output),
permission_policy(permission_mode), permission_policy(permission_mode, &feature_config),
system_prompt, system_prompt,
&build_runtime_feature_config()?, feature_config,
)) );
if emit_output {
runtime = runtime.with_hook_progress_reporter(Box::new(CliHookProgressReporter));
}
Ok(runtime)
}
struct CliHookProgressReporter;
impl runtime::HookProgressReporter for CliHookProgressReporter {
fn on_event(&mut self, event: &runtime::HookProgressEvent) {
match event {
runtime::HookProgressEvent::Started {
event,
tool_name,
command,
} => eprintln!(
"[hook {event_name}] {tool_name}: {command}",
event_name = event.as_str()
),
runtime::HookProgressEvent::Completed {
event,
tool_name,
command,
} => eprintln!(
"[hook done {event_name}] {tool_name}: {command}",
event_name = event.as_str()
),
runtime::HookProgressEvent::Cancelled {
event,
tool_name,
command,
} => eprintln!(
"[hook cancelled {event_name}] {tool_name}: {command}",
event_name = event.as_str()
),
}
}
} }
struct CliPermissionPrompter { struct CliPermissionPrompter {
@@ -1971,6 +2069,9 @@ impl runtime::PermissionPrompter for CliPermissionPrompter {
println!(" Tool {}", request.tool_name); println!(" Tool {}", request.tool_name);
println!(" Current mode {}", self.current_mode.as_str()); println!(" Current mode {}", self.current_mode.as_str());
println!(" Required mode {}", request.required_mode.as_str()); println!(" Required mode {}", request.required_mode.as_str());
if let Some(reason) = &request.reason {
println!(" Reason {reason}");
}
println!(" Input {}", request.input); println!(" Input {}", request.input);
print!("Approve this tool call? [y/N]: "); print!("Approve this tool call? [y/N]: ");
let _ = io::stdout().flush(); let _ = io::stdout().flush();
@@ -2012,13 +2113,11 @@ impl AnthropicRuntimeClient {
enable_tools: bool, enable_tools: bool,
emit_output: bool, emit_output: bool,
allowed_tools: Option<AllowedToolSet>, allowed_tools: Option<AllowedToolSet>,
session_id: impl Into<String>,
) -> Result<Self, Box<dyn std::error::Error>> { ) -> Result<Self, Box<dyn std::error::Error>> {
Ok(Self { Ok(Self {
runtime: tokio::runtime::Runtime::new()?, runtime: tokio::runtime::Runtime::new()?,
client: AnthropicClient::from_auth(resolve_cli_auth_source()?) client: AnthropicClient::from_auth(resolve_cli_auth_source()?)
.with_base_url(api::read_base_url()) .with_base_url(api::read_base_url()),
.with_prompt_cache(PromptCache::new(session_id)),
model, model,
enable_tools, enable_tools,
emit_output, emit_output,
@@ -2133,8 +2232,8 @@ impl ApiClient for AnthropicRuntimeClient {
events.push(AssistantEvent::Usage(TokenUsage { events.push(AssistantEvent::Usage(TokenUsage {
input_tokens: delta.usage.input_tokens, input_tokens: delta.usage.input_tokens,
output_tokens: delta.usage.output_tokens, output_tokens: delta.usage.output_tokens,
cache_creation_input_tokens: delta.usage.cache_creation_input_tokens, cache_creation_input_tokens: 0,
cache_read_input_tokens: delta.usage.cache_read_input_tokens, cache_read_input_tokens: 0,
})); }));
} }
ApiStreamEvent::MessageStop(_) => { ApiStreamEvent::MessageStop(_) => {
@@ -2149,8 +2248,6 @@ impl ApiClient for AnthropicRuntimeClient {
} }
} }
push_prompt_cache_record(&self.client, &mut events);
if !saw_stop if !saw_stop
&& events.iter().any(|event| { && events.iter().any(|event| {
matches!(event, AssistantEvent::TextDelta(text) if !text.is_empty()) matches!(event, AssistantEvent::TextDelta(text) if !text.is_empty())
@@ -2175,9 +2272,7 @@ impl ApiClient for AnthropicRuntimeClient {
}) })
.await .await
.map_err(|error| RuntimeError::new(error.to_string()))?; .map_err(|error| RuntimeError::new(error.to_string()))?;
let mut events = response_to_events(response, out)?; response_to_events(response, out)
push_prompt_cache_record(&self.client, &mut events);
Ok(events)
}) })
} }
} }
@@ -2238,39 +2333,6 @@ fn collect_tool_results(summary: &runtime::TurnSummary) -> Vec<serde_json::Value
.collect() .collect()
} }
fn collect_prompt_cache_events(summary: &runtime::TurnSummary) -> Vec<serde_json::Value> {
summary
.prompt_cache_events
.iter()
.map(|event| {
json!({
"unexpected": event.unexpected,
"reason": event.reason,
"previous_cache_read_input_tokens": event.previous_cache_read_input_tokens,
"current_cache_read_input_tokens": event.current_cache_read_input_tokens,
"token_drop": event.token_drop,
})
})
.collect()
}
fn print_prompt_cache_events(summary: &runtime::TurnSummary) {
for event in &summary.prompt_cache_events {
let label = if event.unexpected {
"Prompt cache break"
} else {
"Prompt cache invalidation"
};
println!(
"{label}: {} (cache read {} -> {}, drop {})",
event.reason,
event.previous_cache_read_input_tokens,
event.current_cache_read_input_tokens,
event.token_drop,
);
}
}
fn slash_command_completion_candidates() -> Vec<String> { fn slash_command_completion_candidates() -> Vec<String> {
slash_command_specs() slash_command_specs()
.iter() .iter()
@@ -2417,19 +2479,19 @@ fn first_visible_line(text: &str) -> &str {
} }
fn format_bash_result(icon: &str, parsed: &serde_json::Value) -> String { fn format_bash_result(icon: &str, parsed: &serde_json::Value) -> String {
use std::fmt::Write as _;
let mut lines = vec![format!("{icon} \x1b[38;5;245mbash\x1b[0m")]; let mut lines = vec![format!("{icon} \x1b[38;5;245mbash\x1b[0m")];
if let Some(task_id) = parsed if let Some(task_id) = parsed
.get("backgroundTaskId") .get("backgroundTaskId")
.and_then(|value| value.as_str()) .and_then(|value| value.as_str())
{ {
use std::fmt::Write as _;
let _ = write!(lines[0], " backgrounded ({task_id})"); let _ = write!(lines[0], " backgrounded ({task_id})");
} else if let Some(status) = parsed } else if let Some(status) = parsed
.get("returnCodeInterpretation") .get("returnCodeInterpretation")
.and_then(|value| value.as_str()) .and_then(|value| value.as_str())
.filter(|status| !status.is_empty()) .filter(|status| !status.is_empty())
{ {
use std::fmt::Write as _;
let _ = write!(lines[0], " {status}"); let _ = write!(lines[0], " {status}");
} }
@@ -2680,26 +2742,6 @@ fn response_to_events(
Ok(events) Ok(events)
} }
fn push_prompt_cache_record(client: &AnthropicClient, events: &mut Vec<AssistantEvent>) {
if let Some(event) = client
.take_last_prompt_cache_record()
.and_then(prompt_cache_record_to_runtime_event)
{
events.push(AssistantEvent::PromptCache(event));
}
}
fn prompt_cache_record_to_runtime_event(record: PromptCacheRecord) -> Option<PromptCacheEvent> {
let cache_break = record.cache_break?;
Some(PromptCacheEvent {
unexpected: cache_break.unexpected,
reason: cache_break.reason,
previous_cache_read_input_tokens: cache_break.previous_cache_read_input_tokens,
current_cache_read_input_tokens: cache_break.current_cache_read_input_tokens,
token_drop: cache_break.token_drop,
})
}
struct CliToolExecutor { struct CliToolExecutor {
renderer: TerminalRenderer, renderer: TerminalRenderer,
emit_output: bool, emit_output: bool,
@@ -2752,12 +2794,14 @@ impl ToolExecutor for CliToolExecutor {
} }
} }
fn permission_policy(mode: PermissionMode) -> PermissionPolicy { fn permission_policy(
tool_permission_specs() mode: PermissionMode,
.into_iter() feature_config: &runtime::RuntimeFeatureConfig,
.fold(PermissionPolicy::new(mode), |policy, spec| { ) -> PermissionPolicy {
policy.with_tool_requirement(spec.name, spec.required_permission) tool_permission_specs().into_iter().fold(
}) PermissionPolicy::new(mode).with_permission_rules(feature_config.permission_rules()),
|policy, spec| policy.with_tool_requirement(spec.name, spec.required_permission),
)
} }
fn tool_permission_specs() -> Vec<ToolSpec> { fn tool_permission_specs() -> Vec<ToolSpec> {
@@ -2906,12 +2950,17 @@ mod tests {
normalize_permission_mode, parse_args, parse_git_status_metadata, print_help_to, normalize_permission_mode, parse_args, parse_git_status_metadata, print_help_to,
push_output_block, render_config_report, render_memory_report, render_repl_help, push_output_block, render_config_report, render_memory_report, render_repl_help,
resolve_model_alias, response_to_events, resume_supported_slash_commands, status_context, resolve_model_alias, response_to_events, resume_supported_slash_commands, status_context,
CliAction, CliOutputFormat, SlashCommand, StatusUsage, DEFAULT_MODEL, CliAction, CliOutputFormat, HookAbortMonitor, SlashCommand, StatusUsage, DEFAULT_MODEL,
}; };
use api::{MessageResponse, OutputContentBlock, Usage}; use api::{MessageResponse, OutputContentBlock, Usage};
use runtime::{AssistantEvent, ContentBlock, ConversationMessage, MessageRole, PermissionMode}; use runtime::{
AssistantEvent, ContentBlock, ConversationMessage, HookAbortSignal, MessageRole,
PermissionMode,
};
use serde_json::json; use serde_json::json;
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::mpsc;
use std::time::Duration;
#[test] #[test]
fn defaults_to_repl_when_no_args() { fn defaults_to_repl_when_no_args() {
@@ -3570,4 +3619,43 @@ mod tests {
if name == "read_file" && input == "{\"path\":\"rust/Cargo.toml\"}" if name == "read_file" && input == "{\"path\":\"rust/Cargo.toml\"}"
)); ));
} }
#[test]
fn hook_abort_monitor_stops_without_aborting() {
let abort_signal = HookAbortSignal::new();
let (ready_tx, ready_rx) = mpsc::channel();
let monitor = HookAbortMonitor::spawn_with_waiter(
abort_signal.clone(),
move |stop_rx, abort_signal| {
ready_tx.send(()).expect("ready signal");
let _ = stop_rx.recv();
assert!(!abort_signal.is_aborted());
},
);
ready_rx.recv().expect("waiter should be ready");
monitor.stop();
assert!(!abort_signal.is_aborted());
}
#[test]
fn hook_abort_monitor_propagates_interrupt() {
let abort_signal = HookAbortSignal::new();
let (done_tx, done_rx) = mpsc::channel();
let monitor = HookAbortMonitor::spawn_with_waiter(
abort_signal.clone(),
move |_stop_rx, abort_signal| {
abort_signal.abort();
done_tx.send(()).expect("done signal");
},
);
done_rx
.recv_timeout(Duration::from_secs(1))
.expect("interrupt should complete");
monitor.stop();
assert!(abort_signal.is_aborted());
}
} }

View File

@@ -5,15 +5,15 @@ use std::time::{Duration, Instant};
use api::{ use api::{
read_base_url, AnthropicClient, ContentBlockDelta, InputContentBlock, InputMessage, read_base_url, AnthropicClient, ContentBlockDelta, InputContentBlock, InputMessage,
MessageRequest, MessageResponse, OutputContentBlock, PromptCache, PromptCacheRecord, MessageRequest, MessageResponse, OutputContentBlock, StreamEvent as ApiStreamEvent, ToolChoice,
StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, ToolResultContentBlock, ToolDefinition, ToolResultContentBlock,
}; };
use reqwest::blocking::Client; use reqwest::blocking::Client;
use runtime::{ use runtime::{
edit_file, execute_bash, glob_search, grep_search, load_system_prompt, read_file, write_file, edit_file, execute_bash, glob_search, grep_search, load_system_prompt, read_file, write_file,
ApiClient, ApiRequest, AssistantEvent, BashCommandInput, ContentBlock, ConversationMessage, ApiClient, ApiRequest, AssistantEvent, BashCommandInput, ContentBlock, ConversationMessage,
ConversationRuntime, GrepSearchInput, MessageRole, PermissionMode, PermissionPolicy, ConversationRuntime, GrepSearchInput, MessageRole, PermissionMode, PermissionPolicy,
PromptCacheEvent, RuntimeError, Session, TokenUsage, ToolError, ToolExecutor, RuntimeError, Session, TokenUsage, ToolError, ToolExecutor,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::{json, Value}; use serde_json::{json, Value};
@@ -1466,8 +1466,7 @@ fn build_agent_runtime(
.clone() .clone()
.unwrap_or_else(|| DEFAULT_AGENT_MODEL.to_string()); .unwrap_or_else(|| DEFAULT_AGENT_MODEL.to_string());
let allowed_tools = job.allowed_tools.clone(); let allowed_tools = job.allowed_tools.clone();
let api_client = let api_client = AnthropicRuntimeClient::new(model, allowed_tools.clone())?;
AnthropicRuntimeClient::new(model, allowed_tools.clone(), job.manifest.agent_id.clone())?;
let tool_executor = SubagentToolExecutor::new(allowed_tools); let tool_executor = SubagentToolExecutor::new(allowed_tools);
Ok(ConversationRuntime::new( Ok(ConversationRuntime::new(
Session::new(), Session::new(),
@@ -1644,15 +1643,10 @@ struct AnthropicRuntimeClient {
} }
impl AnthropicRuntimeClient { impl AnthropicRuntimeClient {
fn new( fn new(model: String, allowed_tools: BTreeSet<String>) -> Result<Self, String> {
model: String,
allowed_tools: BTreeSet<String>,
session_id: impl Into<String>,
) -> Result<Self, String> {
let client = AnthropicClient::from_env() let client = AnthropicClient::from_env()
.map_err(|error| error.to_string())? .map_err(|error| error.to_string())?
.with_base_url(read_base_url()) .with_base_url(read_base_url());
.with_prompt_cache(PromptCache::new(session_id));
Ok(Self { Ok(Self {
runtime: tokio::runtime::Runtime::new().map_err(|error| error.to_string())?, runtime: tokio::runtime::Runtime::new().map_err(|error| error.to_string())?,
client, client,
@@ -1663,7 +1657,6 @@ impl AnthropicRuntimeClient {
} }
impl ApiClient for AnthropicRuntimeClient { impl ApiClient for AnthropicRuntimeClient {
#[allow(clippy::too_many_lines)]
fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> { fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
let tools = tool_specs_for_allowed_tools(Some(&self.allowed_tools)) let tools = tool_specs_for_allowed_tools(Some(&self.allowed_tools))
.into_iter() .into_iter()
@@ -1733,8 +1726,8 @@ impl ApiClient for AnthropicRuntimeClient {
events.push(AssistantEvent::Usage(TokenUsage { events.push(AssistantEvent::Usage(TokenUsage {
input_tokens: delta.usage.input_tokens, input_tokens: delta.usage.input_tokens,
output_tokens: delta.usage.output_tokens, output_tokens: delta.usage.output_tokens,
cache_creation_input_tokens: delta.usage.cache_creation_input_tokens, cache_creation_input_tokens: 0,
cache_read_input_tokens: delta.usage.cache_read_input_tokens, cache_read_input_tokens: 0,
})); }));
} }
ApiStreamEvent::MessageStop(_) => { ApiStreamEvent::MessageStop(_) => {
@@ -1744,8 +1737,6 @@ impl ApiClient for AnthropicRuntimeClient {
} }
} }
push_prompt_cache_record(&self.client, &mut events);
if !saw_stop if !saw_stop
&& events.iter().any(|event| { && events.iter().any(|event| {
matches!(event, AssistantEvent::TextDelta(text) if !text.is_empty()) matches!(event, AssistantEvent::TextDelta(text) if !text.is_empty())
@@ -1770,9 +1761,7 @@ impl ApiClient for AnthropicRuntimeClient {
}) })
.await .await
.map_err(|error| RuntimeError::new(error.to_string()))?; .map_err(|error| RuntimeError::new(error.to_string()))?;
let mut events = response_to_events(response); Ok(response_to_events(response))
push_prompt_cache_record(&self.client, &mut events);
Ok(events)
}) })
} }
} }
@@ -1895,26 +1884,6 @@ fn response_to_events(response: MessageResponse) -> Vec<AssistantEvent> {
events events
} }
fn push_prompt_cache_record(client: &AnthropicClient, events: &mut Vec<AssistantEvent>) {
if let Some(event) = client
.take_last_prompt_cache_record()
.and_then(prompt_cache_record_to_runtime_event)
{
events.push(AssistantEvent::PromptCache(event));
}
}
fn prompt_cache_record_to_runtime_event(record: PromptCacheRecord) -> Option<PromptCacheEvent> {
let cache_break = record.cache_break?;
Some(PromptCacheEvent {
unexpected: cache_break.unexpected,
reason: cache_break.reason,
previous_cache_read_input_tokens: cache_break.previous_cache_read_input_tokens,
current_cache_read_input_tokens: cache_break.current_cache_read_input_tokens,
token_drop: cache_break.token_drop,
})
}
fn final_assistant_text(summary: &runtime::TurnSummary) -> String { fn final_assistant_text(summary: &runtime::TurnSummary) -> String {
summary summary
.assistant_messages .assistant_messages