feat: cache-tracking progress

This commit is contained in:
Yeachan-Heo
2026-04-01 06:15:13 +00:00
parent 26344c578b
commit c9d214c8d1
7 changed files with 238 additions and 52 deletions

View File

@@ -1,4 +1,5 @@
use std::collections::VecDeque; use std::collections::VecDeque;
use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime, UNIX_EPOCH}; use std::time::{Duration, SystemTime, UNIX_EPOCH};
use runtime::{ use runtime::{
@@ -8,7 +9,7 @@ use runtime::{
use serde::Deserialize; use serde::Deserialize;
use crate::error::ApiError; use crate::error::ApiError;
use crate::prompt_cache::{PromptCache, PromptCacheStats}; use crate::prompt_cache::{PromptCache, PromptCacheRecord, PromptCacheStats};
use crate::sse::SseParser; use crate::sse::SseParser;
use crate::types::{MessageRequest, MessageResponse, StreamEvent, Usage}; use crate::types::{MessageRequest, MessageResponse, StreamEvent, Usage};
@@ -110,6 +111,7 @@ pub struct AnthropicClient {
initial_backoff: Duration, initial_backoff: Duration,
max_backoff: Duration, max_backoff: Duration,
prompt_cache: Option<PromptCache>, prompt_cache: Option<PromptCache>,
last_prompt_cache_record: Arc<Mutex<Option<PromptCacheRecord>>>,
} }
impl AnthropicClient { impl AnthropicClient {
@@ -123,6 +125,7 @@ impl AnthropicClient {
initial_backoff: DEFAULT_INITIAL_BACKOFF, initial_backoff: DEFAULT_INITIAL_BACKOFF,
max_backoff: DEFAULT_MAX_BACKOFF, max_backoff: DEFAULT_MAX_BACKOFF,
prompt_cache: None, prompt_cache: None,
last_prompt_cache_record: Arc::new(Mutex::new(None)),
} }
} }
@@ -136,6 +139,7 @@ impl AnthropicClient {
initial_backoff: DEFAULT_INITIAL_BACKOFF, initial_backoff: DEFAULT_INITIAL_BACKOFF,
max_backoff: DEFAULT_MAX_BACKOFF, max_backoff: DEFAULT_MAX_BACKOFF,
prompt_cache: None, prompt_cache: None,
last_prompt_cache_record: Arc::new(Mutex::new(None)),
} }
} }
@@ -209,6 +213,14 @@ impl AnthropicClient {
self.prompt_cache.as_ref().map(PromptCache::stats) self.prompt_cache.as_ref().map(PromptCache::stats)
} }
#[must_use]
pub fn take_last_prompt_cache_record(&self) -> Option<PromptCacheRecord> {
self.last_prompt_cache_record()
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.take()
}
#[must_use] #[must_use]
pub fn auth_source(&self) -> &AuthSource { pub fn auth_source(&self) -> &AuthSource {
&self.auth &self.auth
@@ -218,12 +230,16 @@ impl AnthropicClient {
&self, &self,
request: &MessageRequest, request: &MessageRequest,
) -> Result<MessageResponse, ApiError> { ) -> Result<MessageResponse, ApiError> {
self.store_last_prompt_cache_record(None);
let request = MessageRequest { let request = MessageRequest {
stream: false, stream: false,
..request.clone() ..request.clone()
}; };
if let Some(prompt_cache) = &self.prompt_cache { if let Some(prompt_cache) = &self.prompt_cache {
if let Some(response) = prompt_cache.lookup_completion(&request) { if let Some(response) = prompt_cache.lookup_completion(&request) {
self.store_last_prompt_cache_record(Some(prompt_cache_record_from_stats(
prompt_cache.stats(),
)));
return Ok(response); return Ok(response);
} }
} }
@@ -237,7 +253,8 @@ impl AnthropicClient {
response.request_id = request_id; response.request_id = request_id;
} }
if let Some(prompt_cache) = &self.prompt_cache { if let Some(prompt_cache) = &self.prompt_cache {
let _ = prompt_cache.record_response(&request, &response); let record = prompt_cache.record_response(&request, &response);
self.store_last_prompt_cache_record(Some(record));
} }
Ok(response) Ok(response)
} }
@@ -246,6 +263,7 @@ impl AnthropicClient {
&self, &self,
request: &MessageRequest, request: &MessageRequest,
) -> Result<MessageStream, ApiError> { ) -> Result<MessageStream, ApiError> {
self.store_last_prompt_cache_record(None);
let response = self let response = self
.send_with_retry(&request.clone().with_streaming()) .send_with_retry(&request.clone().with_streaming())
.await?; .await?;
@@ -263,10 +281,22 @@ impl AnthropicClient {
request: request.clone().with_streaming(), request: request.clone().with_streaming(),
last_usage: None, last_usage: None,
finalized: false, finalized: false,
last_record: self.last_prompt_cache_record.clone(),
}), }),
}) })
} }
fn store_last_prompt_cache_record(&self, record: Option<PromptCacheRecord>) {
*self
.last_prompt_cache_record()
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner) = record;
}
fn last_prompt_cache_record(&self) -> &Arc<Mutex<Option<PromptCacheRecord>>> {
&self.last_prompt_cache_record
}
pub async fn exchange_oauth_code( pub async fn exchange_oauth_code(
&self, &self,
config: &OAuthConfig, config: &OAuthConfig,
@@ -615,6 +645,7 @@ struct StreamCacheTracking {
request: MessageRequest, request: MessageRequest,
last_usage: Option<Usage>, last_usage: Option<Usage>,
finalized: bool, finalized: bool,
last_record: Arc<Mutex<Option<PromptCacheRecord>>>,
} }
impl StreamCacheTracking { impl StreamCacheTracking {
@@ -638,12 +669,23 @@ impl StreamCacheTracking {
return; return;
} }
if let Some(usage) = &self.last_usage { if let Some(usage) = &self.last_usage {
let _ = self.prompt_cache.record_usage(&self.request, usage); let record = self.prompt_cache.record_usage(&self.request, usage);
*self
.last_record
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner) = Some(record);
} }
self.finalized = true; self.finalized = true;
} }
} }
fn prompt_cache_record_from_stats(stats: PromptCacheStats) -> PromptCacheRecord {
PromptCacheRecord {
cache_break: None,
stats,
}
}
async fn expect_success(response: reqwest::Response) -> Result<reqwest::Response, ApiError> { async fn expect_success(response: reqwest::Response) -> Result<reqwest::Response, ApiError> {
let status = response.status(); let status = response.status();
if status.is_success() { if status.is_success() {

View File

@@ -25,9 +25,19 @@ pub enum AssistantEvent {
input: String, input: String,
}, },
Usage(TokenUsage), Usage(TokenUsage),
PromptCache(PromptCacheEvent),
MessageStop, MessageStop,
} }
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PromptCacheEvent {
pub unexpected: bool,
pub reason: String,
pub previous_cache_read_input_tokens: u32,
pub current_cache_read_input_tokens: u32,
pub token_drop: u32,
}
pub trait ApiClient { pub trait ApiClient {
fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError>; fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError>;
} }
@@ -84,6 +94,7 @@ impl std::error::Error for RuntimeError {}
pub struct TurnSummary { pub struct TurnSummary {
pub assistant_messages: Vec<ConversationMessage>, pub assistant_messages: Vec<ConversationMessage>,
pub tool_results: Vec<ConversationMessage>, pub tool_results: Vec<ConversationMessage>,
pub prompt_cache_events: Vec<PromptCacheEvent>,
pub iterations: usize, pub iterations: usize,
pub usage: TokenUsage, pub usage: TokenUsage,
} }
@@ -118,7 +129,7 @@ where
tool_executor, tool_executor,
permission_policy, permission_policy,
system_prompt, system_prompt,
RuntimeFeatureConfig::default(), &RuntimeFeatureConfig::default(),
) )
} }
@@ -129,7 +140,7 @@ where
tool_executor: T, tool_executor: T,
permission_policy: PermissionPolicy, permission_policy: PermissionPolicy,
system_prompt: Vec<String>, system_prompt: Vec<String>,
feature_config: RuntimeFeatureConfig, feature_config: &RuntimeFeatureConfig,
) -> Self { ) -> Self {
let usage_tracker = UsageTracker::from_session(&session); let usage_tracker = UsageTracker::from_session(&session);
Self { Self {
@@ -140,7 +151,7 @@ where
system_prompt, system_prompt,
max_iterations: usize::MAX, max_iterations: usize::MAX,
usage_tracker, usage_tracker,
hook_runner: HookRunner::from_feature_config(&feature_config), hook_runner: HookRunner::from_feature_config(feature_config),
} }
} }
@@ -161,6 +172,7 @@ where
let mut assistant_messages = Vec::new(); let mut assistant_messages = Vec::new();
let mut tool_results = Vec::new(); let mut tool_results = Vec::new();
let mut prompt_cache_events = Vec::new();
let mut iterations = 0; let mut iterations = 0;
loop { loop {
@@ -176,10 +188,12 @@ where
messages: self.session.messages.clone(), messages: self.session.messages.clone(),
}; };
let events = self.api_client.stream(request)?; let events = self.api_client.stream(request)?;
let (assistant_message, usage) = build_assistant_message(events)?; let (assistant_message, usage, turn_prompt_cache_events) =
build_assistant_message(events)?;
if let Some(usage) = usage { if let Some(usage) = usage {
self.usage_tracker.record(usage); self.usage_tracker.record(usage);
} }
prompt_cache_events.extend(turn_prompt_cache_events);
let pending_tool_uses = assistant_message let pending_tool_uses = assistant_message
.blocks .blocks
.iter() .iter()
@@ -257,6 +271,7 @@ where
Ok(TurnSummary { Ok(TurnSummary {
assistant_messages, assistant_messages,
tool_results, tool_results,
prompt_cache_events,
iterations, iterations,
usage: self.usage_tracker.cumulative_usage(), usage: self.usage_tracker.cumulative_usage(),
}) })
@@ -290,9 +305,17 @@ where
fn build_assistant_message( fn build_assistant_message(
events: Vec<AssistantEvent>, events: Vec<AssistantEvent>,
) -> Result<(ConversationMessage, Option<TokenUsage>), RuntimeError> { ) -> Result<
(
ConversationMessage,
Option<TokenUsage>,
Vec<PromptCacheEvent>,
),
RuntimeError,
> {
let mut text = String::new(); let mut text = String::new();
let mut blocks = Vec::new(); let mut blocks = Vec::new();
let mut prompt_cache_events = Vec::new();
let mut finished = false; let mut finished = false;
let mut usage = None; let mut usage = None;
@@ -304,6 +327,7 @@ fn build_assistant_message(
blocks.push(ContentBlock::ToolUse { id, name, input }); blocks.push(ContentBlock::ToolUse { id, name, input });
} }
AssistantEvent::Usage(value) => usage = Some(value), AssistantEvent::Usage(value) => usage = Some(value),
AssistantEvent::PromptCache(event) => prompt_cache_events.push(event),
AssistantEvent::MessageStop => { AssistantEvent::MessageStop => {
finished = true; finished = true;
} }
@@ -324,6 +348,7 @@ fn build_assistant_message(
Ok(( Ok((
ConversationMessage::assistant_with_usage(blocks, usage), ConversationMessage::assistant_with_usage(blocks, usage),
usage, usage,
prompt_cache_events,
)) ))
} }
@@ -396,7 +421,7 @@ impl ToolExecutor for StaticToolExecutor {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::{ use super::{
ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, RuntimeError, ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, PromptCacheEvent, RuntimeError,
StaticToolExecutor, StaticToolExecutor,
}; };
use crate::compact::CompactionConfig; use crate::compact::CompactionConfig;
@@ -453,6 +478,15 @@ mod tests {
cache_creation_input_tokens: 1, cache_creation_input_tokens: 1,
cache_read_input_tokens: 3, cache_read_input_tokens: 3,
}), }),
AssistantEvent::PromptCache(PromptCacheEvent {
unexpected: true,
reason:
"cache read tokens dropped while prompt fingerprint remained stable"
.to_string(),
previous_cache_read_input_tokens: 6_000,
current_cache_read_input_tokens: 1_000,
token_drop: 5_000,
}),
AssistantEvent::MessageStop, AssistantEvent::MessageStop,
]) ])
} }
@@ -506,8 +540,10 @@ mod tests {
assert_eq!(summary.iterations, 2); assert_eq!(summary.iterations, 2);
assert_eq!(summary.assistant_messages.len(), 2); assert_eq!(summary.assistant_messages.len(), 2);
assert_eq!(summary.tool_results.len(), 1); assert_eq!(summary.tool_results.len(), 1);
assert_eq!(summary.prompt_cache_events.len(), 1);
assert_eq!(runtime.session().messages.len(), 4); assert_eq!(runtime.session().messages.len(), 4);
assert_eq!(summary.usage.output_tokens, 10); assert_eq!(summary.usage.output_tokens, 10);
assert!(summary.prompt_cache_events[0].unexpected);
assert!(matches!( assert!(matches!(
runtime.session().messages[1].blocks[1], runtime.session().messages[1].blocks[1],
ContentBlock::ToolUse { .. } ContentBlock::ToolUse { .. }
@@ -609,7 +645,7 @@ mod tests {
}), }),
PermissionPolicy::new(PermissionMode::DangerFullAccess), PermissionPolicy::new(PermissionMode::DangerFullAccess),
vec!["system".to_string()], vec!["system".to_string()],
RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new( &RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new(
vec![shell_snippet("printf 'blocked by hook'; exit 2")], vec![shell_snippet("printf 'blocked by hook'; exit 2")],
Vec::new(), Vec::new(),
)), )),
@@ -675,7 +711,7 @@ mod tests {
StaticToolExecutor::new().register("add", |_input| Ok("4".to_string())), StaticToolExecutor::new().register("add", |_input| Ok("4".to_string())),
PermissionPolicy::new(PermissionMode::DangerFullAccess), PermissionPolicy::new(PermissionMode::DangerFullAccess),
vec!["system".to_string()], vec!["system".to_string()],
RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new( &RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new(
vec![shell_snippet("printf 'pre hook ran'")], vec![shell_snippet("printf 'pre hook ran'")],
vec![shell_snippet("printf 'post hook ran'")], vec![shell_snippet("printf 'post hook ran'")],
)), )),
@@ -697,7 +733,7 @@ mod tests {
"post hook should preserve non-error result: {output:?}" "post hook should preserve non-error result: {output:?}"
); );
assert!( assert!(
output.contains("4"), output.contains('4'),
"tool output missing value: {output:?}" "tool output missing value: {output:?}"
); );
assert!( assert!(

View File

@@ -64,7 +64,7 @@ impl HookRunner {
#[must_use] #[must_use]
pub fn run_pre_tool_use(&self, tool_name: &str, tool_input: &str) -> HookRunResult { pub fn run_pre_tool_use(&self, tool_name: &str, tool_input: &str) -> HookRunResult {
self.run_commands( Self::run_commands(
HookEvent::PreToolUse, HookEvent::PreToolUse,
self.config.pre_tool_use(), self.config.pre_tool_use(),
tool_name, tool_name,
@@ -82,7 +82,7 @@ impl HookRunner {
tool_output: &str, tool_output: &str,
is_error: bool, is_error: bool,
) -> HookRunResult { ) -> HookRunResult {
self.run_commands( Self::run_commands(
HookEvent::PostToolUse, HookEvent::PostToolUse,
self.config.post_tool_use(), self.config.post_tool_use(),
tool_name, tool_name,
@@ -93,7 +93,6 @@ impl HookRunner {
} }
fn run_commands( fn run_commands(
&self,
event: HookEvent, event: HookEvent,
commands: &[String], commands: &[String],
tool_name: &str, tool_name: &str,
@@ -118,7 +117,7 @@ impl HookRunner {
let mut messages = Vec::new(); let mut messages = Vec::new();
for command in commands { for command in commands {
match self.run_command( match Self::run_command(
command, command,
event, event,
tool_name, tool_name,
@@ -150,7 +149,6 @@ impl HookRunner {
} }
fn run_command( fn run_command(
&self,
command: &str, command: &str,
event: HookEvent, event: HookEvent,
tool_name: &str, tool_name: &str,

View File

@@ -31,8 +31,8 @@ pub use config::{
ScopedMcpServerConfig, CLAUDE_CODE_SETTINGS_SCHEMA_NAME, ScopedMcpServerConfig, CLAUDE_CODE_SETTINGS_SCHEMA_NAME,
}; };
pub use conversation::{ pub use conversation::{
ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, RuntimeError, StaticToolExecutor, ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, PromptCacheEvent, RuntimeError,
ToolError, ToolExecutor, TurnSummary, StaticToolExecutor, ToolError, ToolExecutor, TurnSummary,
}; };
pub use file_ops::{ pub use file_ops::{
edit_file, glob_search, grep_search, read_file, write_file, EditFileOutput, GlobSearchOutput, edit_file, glob_search, grep_search, read_file, write_file, EditFileOutput, GlobSearchOutput,

View File

@@ -13,8 +13,9 @@ use std::time::{SystemTime, UNIX_EPOCH};
use api::{ use api::{
resolve_startup_auth_source, AnthropicClient, AuthSource, ContentBlockDelta, InputContentBlock, resolve_startup_auth_source, AnthropicClient, AuthSource, ContentBlockDelta, InputContentBlock,
InputMessage, MessageRequest, MessageResponse, OutputContentBlock, InputMessage, MessageRequest, MessageResponse, OutputContentBlock, PromptCache,
StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, ToolResultContentBlock, PromptCacheRecord, StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition,
ToolResultContentBlock,
}; };
use commands::{ use commands::{
@@ -28,8 +29,8 @@ use runtime::{
parse_oauth_callback_request_target, save_oauth_credentials, ApiClient, ApiRequest, parse_oauth_callback_request_target, save_oauth_credentials, ApiClient, ApiRequest,
AssistantEvent, CompactionConfig, ConfigLoader, ConfigSource, ContentBlock, AssistantEvent, CompactionConfig, ConfigLoader, ConfigSource, ContentBlock,
ConversationMessage, ConversationRuntime, MessageRole, OAuthAuthorizationRequest, OAuthConfig, ConversationMessage, ConversationRuntime, MessageRole, OAuthAuthorizationRequest, OAuthConfig,
OAuthTokenExchangeRequest, PermissionMode, PermissionPolicy, ProjectContext, RuntimeError, OAuthTokenExchangeRequest, PermissionMode, PermissionPolicy, ProjectContext, PromptCacheEvent,
Session, TokenUsage, ToolError, ToolExecutor, UsageTracker, RuntimeError, Session, TokenUsage, ToolError, ToolExecutor, UsageTracker,
}; };
use serde_json::json; use serde_json::json;
use tools::{execute_tool, mvp_tool_specs, ToolSpec}; use tools::{execute_tool, mvp_tool_specs, ToolSpec};
@@ -995,6 +996,7 @@ impl LiveCli {
let session = create_managed_session_handle()?; let session = create_managed_session_handle()?;
let runtime = build_runtime( let runtime = build_runtime(
Session::new(), Session::new(),
session.id.clone(),
model.clone(), model.clone(),
system_prompt.clone(), system_prompt.clone(),
enable_tools, enable_tools,
@@ -1050,13 +1052,14 @@ impl LiveCli {
let mut permission_prompter = CliPermissionPrompter::new(self.permission_mode); let mut permission_prompter = CliPermissionPrompter::new(self.permission_mode);
let result = self.runtime.run_turn(input, Some(&mut permission_prompter)); let result = self.runtime.run_turn(input, Some(&mut permission_prompter));
match result { match result {
Ok(_) => { Ok(summary) => {
spinner.finish( spinner.finish(
"✨ Done", "✨ Done",
TerminalRenderer::new().color_theme(), TerminalRenderer::new().color_theme(),
&mut stdout, &mut stdout,
)?; )?;
println!(); println!();
print_prompt_cache_events(&summary);
self.persist_session()?; self.persist_session()?;
Ok(()) Ok(())
} }
@@ -1086,6 +1089,7 @@ impl LiveCli {
let session = self.runtime.session().clone(); let session = self.runtime.session().clone();
let mut runtime = build_runtime( let mut runtime = build_runtime(
session, session,
self.session.id.clone(),
self.model.clone(), self.model.clone(),
self.system_prompt.clone(), self.system_prompt.clone(),
true, true,
@@ -1105,6 +1109,7 @@ impl LiveCli {
"iterations": summary.iterations, "iterations": summary.iterations,
"tool_uses": collect_tool_uses(&summary), "tool_uses": collect_tool_uses(&summary),
"tool_results": collect_tool_results(&summary), "tool_results": collect_tool_results(&summary),
"prompt_cache_events": collect_prompt_cache_events(&summary),
"usage": { "usage": {
"input_tokens": summary.usage.input_tokens, "input_tokens": summary.usage.input_tokens,
"output_tokens": summary.usage.output_tokens, "output_tokens": summary.usage.output_tokens,
@@ -1232,6 +1237,7 @@ impl LiveCli {
let message_count = session.messages.len(); let message_count = session.messages.len();
self.runtime = build_runtime( self.runtime = build_runtime(
session, session,
self.session.id.clone(),
model.clone(), model.clone(),
self.system_prompt.clone(), self.system_prompt.clone(),
true, true,
@@ -1275,6 +1281,7 @@ impl LiveCli {
self.permission_mode = permission_mode_from_label(normalized); self.permission_mode = permission_mode_from_label(normalized);
self.runtime = build_runtime( self.runtime = build_runtime(
session, session,
self.session.id.clone(),
self.model.clone(), self.model.clone(),
self.system_prompt.clone(), self.system_prompt.clone(),
true, true,
@@ -1300,6 +1307,7 @@ impl LiveCli {
self.session = create_managed_session_handle()?; self.session = create_managed_session_handle()?;
self.runtime = build_runtime( self.runtime = build_runtime(
Session::new(), Session::new(),
self.session.id.clone(),
self.model.clone(), self.model.clone(),
self.system_prompt.clone(), self.system_prompt.clone(),
true, true,
@@ -1335,6 +1343,7 @@ impl LiveCli {
let message_count = session.messages.len(); let message_count = session.messages.len();
self.runtime = build_runtime( self.runtime = build_runtime(
session, session,
handle.id.clone(),
self.model.clone(), self.model.clone(),
self.system_prompt.clone(), self.system_prompt.clone(),
true, true,
@@ -1407,6 +1416,7 @@ impl LiveCli {
let message_count = session.messages.len(); let message_count = session.messages.len();
self.runtime = build_runtime( self.runtime = build_runtime(
session, session,
handle.id.clone(),
self.model.clone(), self.model.clone(),
self.system_prompt.clone(), self.system_prompt.clone(),
true, true,
@@ -1437,6 +1447,7 @@ impl LiveCli {
let skipped = removed == 0; let skipped = removed == 0;
self.runtime = build_runtime( self.runtime = build_runtime(
result.compacted_session, result.compacted_session,
self.session.id.clone(),
self.model.clone(), self.model.clone(),
self.system_prompt.clone(), self.system_prompt.clone(),
true, true,
@@ -1912,8 +1923,10 @@ fn build_runtime_feature_config(
.clone()) .clone())
} }
#[allow(clippy::too_many_arguments)]
fn build_runtime( fn build_runtime(
session: Session, session: Session,
session_id: String,
model: String, model: String,
system_prompt: Vec<String>, system_prompt: Vec<String>,
enable_tools: bool, enable_tools: bool,
@@ -1924,11 +1937,17 @@ fn build_runtime(
{ {
Ok(ConversationRuntime::new_with_features( Ok(ConversationRuntime::new_with_features(
session, session,
AnthropicRuntimeClient::new(model, enable_tools, emit_output, allowed_tools.clone())?, AnthropicRuntimeClient::new(
model,
enable_tools,
emit_output,
allowed_tools.clone(),
session_id,
)?,
CliToolExecutor::new(allowed_tools, emit_output), CliToolExecutor::new(allowed_tools, emit_output),
permission_policy(permission_mode), permission_policy(permission_mode),
system_prompt, system_prompt,
build_runtime_feature_config()?, &build_runtime_feature_config()?,
)) ))
} }
@@ -1993,11 +2012,13 @@ impl AnthropicRuntimeClient {
enable_tools: bool, enable_tools: bool,
emit_output: bool, emit_output: bool,
allowed_tools: Option<AllowedToolSet>, allowed_tools: Option<AllowedToolSet>,
session_id: impl Into<String>,
) -> Result<Self, Box<dyn std::error::Error>> { ) -> Result<Self, Box<dyn std::error::Error>> {
Ok(Self { Ok(Self {
runtime: tokio::runtime::Runtime::new()?, runtime: tokio::runtime::Runtime::new()?,
client: AnthropicClient::from_auth(resolve_cli_auth_source()?) client: AnthropicClient::from_auth(resolve_cli_auth_source()?)
.with_base_url(api::read_base_url()), .with_base_url(api::read_base_url())
.with_prompt_cache(PromptCache::new(session_id)),
model, model,
enable_tools, enable_tools,
emit_output, emit_output,
@@ -2112,8 +2133,8 @@ impl ApiClient for AnthropicRuntimeClient {
events.push(AssistantEvent::Usage(TokenUsage { events.push(AssistantEvent::Usage(TokenUsage {
input_tokens: delta.usage.input_tokens, input_tokens: delta.usage.input_tokens,
output_tokens: delta.usage.output_tokens, output_tokens: delta.usage.output_tokens,
cache_creation_input_tokens: 0, cache_creation_input_tokens: delta.usage.cache_creation_input_tokens,
cache_read_input_tokens: 0, cache_read_input_tokens: delta.usage.cache_read_input_tokens,
})); }));
} }
ApiStreamEvent::MessageStop(_) => { ApiStreamEvent::MessageStop(_) => {
@@ -2128,6 +2149,8 @@ impl ApiClient for AnthropicRuntimeClient {
} }
} }
push_prompt_cache_record(&self.client, &mut events);
if !saw_stop if !saw_stop
&& events.iter().any(|event| { && events.iter().any(|event| {
matches!(event, AssistantEvent::TextDelta(text) if !text.is_empty()) matches!(event, AssistantEvent::TextDelta(text) if !text.is_empty())
@@ -2152,7 +2175,9 @@ impl ApiClient for AnthropicRuntimeClient {
}) })
.await .await
.map_err(|error| RuntimeError::new(error.to_string()))?; .map_err(|error| RuntimeError::new(error.to_string()))?;
response_to_events(response, out) let mut events = response_to_events(response, out)?;
push_prompt_cache_record(&self.client, &mut events);
Ok(events)
}) })
} }
} }
@@ -2213,6 +2238,39 @@ fn collect_tool_results(summary: &runtime::TurnSummary) -> Vec<serde_json::Value
.collect() .collect()
} }
fn collect_prompt_cache_events(summary: &runtime::TurnSummary) -> Vec<serde_json::Value> {
summary
.prompt_cache_events
.iter()
.map(|event| {
json!({
"unexpected": event.unexpected,
"reason": event.reason,
"previous_cache_read_input_tokens": event.previous_cache_read_input_tokens,
"current_cache_read_input_tokens": event.current_cache_read_input_tokens,
"token_drop": event.token_drop,
})
})
.collect()
}
fn print_prompt_cache_events(summary: &runtime::TurnSummary) {
for event in &summary.prompt_cache_events {
let label = if event.unexpected {
"Prompt cache break"
} else {
"Prompt cache invalidation"
};
println!(
"{label}: {} (cache read {} -> {}, drop {})",
event.reason,
event.previous_cache_read_input_tokens,
event.current_cache_read_input_tokens,
event.token_drop,
);
}
}
fn slash_command_completion_candidates() -> Vec<String> { fn slash_command_completion_candidates() -> Vec<String> {
slash_command_specs() slash_command_specs()
.iter() .iter()
@@ -2359,18 +2417,20 @@ fn first_visible_line(text: &str) -> &str {
} }
fn format_bash_result(icon: &str, parsed: &serde_json::Value) -> String { fn format_bash_result(icon: &str, parsed: &serde_json::Value) -> String {
use std::fmt::Write as _;
let mut lines = vec![format!("{icon} \x1b[38;5;245mbash\x1b[0m")]; let mut lines = vec![format!("{icon} \x1b[38;5;245mbash\x1b[0m")];
if let Some(task_id) = parsed if let Some(task_id) = parsed
.get("backgroundTaskId") .get("backgroundTaskId")
.and_then(|value| value.as_str()) .and_then(|value| value.as_str())
{ {
lines[0].push_str(&format!(" backgrounded ({task_id})")); let _ = write!(lines[0], " backgrounded ({task_id})");
} else if let Some(status) = parsed } else if let Some(status) = parsed
.get("returnCodeInterpretation") .get("returnCodeInterpretation")
.and_then(|value| value.as_str()) .and_then(|value| value.as_str())
.filter(|status| !status.is_empty()) .filter(|status| !status.is_empty())
{ {
lines[0].push_str(&format!(" {status}")); let _ = write!(lines[0], " {status}");
} }
if let Some(stdout) = parsed.get("stdout").and_then(|value| value.as_str()) { if let Some(stdout) = parsed.get("stdout").and_then(|value| value.as_str()) {
@@ -2392,15 +2452,15 @@ fn format_read_result(icon: &str, parsed: &serde_json::Value) -> String {
let path = extract_tool_path(file); let path = extract_tool_path(file);
let start_line = file let start_line = file
.get("startLine") .get("startLine")
.and_then(|value| value.as_u64()) .and_then(serde_json::Value::as_u64)
.unwrap_or(1); .unwrap_or(1);
let num_lines = file let num_lines = file
.get("numLines") .get("numLines")
.and_then(|value| value.as_u64()) .and_then(serde_json::Value::as_u64)
.unwrap_or(0); .unwrap_or(0);
let total_lines = file let total_lines = file
.get("totalLines") .get("totalLines")
.and_then(|value| value.as_u64()) .and_then(serde_json::Value::as_u64)
.unwrap_or(num_lines); .unwrap_or(num_lines);
let content = file let content = file
.get("content") .get("content")
@@ -2426,8 +2486,7 @@ fn format_write_result(icon: &str, parsed: &serde_json::Value) -> String {
let line_count = parsed let line_count = parsed
.get("content") .get("content")
.and_then(|value| value.as_str()) .and_then(|value| value.as_str())
.map(|content| content.lines().count()) .map_or(0, |content| content.lines().count());
.unwrap_or(0);
format!( format!(
"{icon} \x1b[1;32m✏ {} {path}\x1b[0m \x1b[2m({line_count} lines)\x1b[0m", "{icon} \x1b[1;32m✏ {} {path}\x1b[0m \x1b[2m({line_count} lines)\x1b[0m",
if kind == "create" { "Wrote" } else { "Updated" }, if kind == "create" { "Wrote" } else { "Updated" },
@@ -2458,7 +2517,7 @@ fn format_edit_result(icon: &str, parsed: &serde_json::Value) -> String {
let path = extract_tool_path(parsed); let path = extract_tool_path(parsed);
let suffix = if parsed let suffix = if parsed
.get("replaceAll") .get("replaceAll")
.and_then(|value| value.as_bool()) .and_then(serde_json::Value::as_bool)
.unwrap_or(false) .unwrap_or(false)
{ {
" (replace all)" " (replace all)"
@@ -2486,7 +2545,7 @@ fn format_edit_result(icon: &str, parsed: &serde_json::Value) -> String {
fn format_glob_result(icon: &str, parsed: &serde_json::Value) -> String { fn format_glob_result(icon: &str, parsed: &serde_json::Value) -> String {
let num_files = parsed let num_files = parsed
.get("numFiles") .get("numFiles")
.and_then(|value| value.as_u64()) .and_then(serde_json::Value::as_u64)
.unwrap_or(0); .unwrap_or(0);
let filenames = parsed let filenames = parsed
.get("filenames") .get("filenames")
@@ -2510,11 +2569,11 @@ fn format_glob_result(icon: &str, parsed: &serde_json::Value) -> String {
fn format_grep_result(icon: &str, parsed: &serde_json::Value) -> String { fn format_grep_result(icon: &str, parsed: &serde_json::Value) -> String {
let num_matches = parsed let num_matches = parsed
.get("numMatches") .get("numMatches")
.and_then(|value| value.as_u64()) .and_then(serde_json::Value::as_u64)
.unwrap_or(0); .unwrap_or(0);
let num_files = parsed let num_files = parsed
.get("numFiles") .get("numFiles")
.and_then(|value| value.as_u64()) .and_then(serde_json::Value::as_u64)
.unwrap_or(0); .unwrap_or(0);
let content = parsed let content = parsed
.get("content") .get("content")
@@ -2621,6 +2680,26 @@ fn response_to_events(
Ok(events) Ok(events)
} }
fn push_prompt_cache_record(client: &AnthropicClient, events: &mut Vec<AssistantEvent>) {
if let Some(event) = client
.take_last_prompt_cache_record()
.and_then(prompt_cache_record_to_runtime_event)
{
events.push(AssistantEvent::PromptCache(event));
}
}
fn prompt_cache_record_to_runtime_event(record: PromptCacheRecord) -> Option<PromptCacheEvent> {
let cache_break = record.cache_break?;
Some(PromptCacheEvent {
unexpected: cache_break.unexpected,
reason: cache_break.reason,
previous_cache_read_input_tokens: cache_break.previous_cache_read_input_tokens,
current_cache_read_input_tokens: cache_break.current_cache_read_input_tokens,
token_drop: cache_break.token_drop,
})
}
struct CliToolExecutor { struct CliToolExecutor {
renderer: TerminalRenderer, renderer: TerminalRenderer,
emit_output: bool, emit_output: bool,

View File

@@ -286,7 +286,7 @@ impl TerminalRenderer {
) { ) {
match event { match event {
Event::Start(Tag::Heading { level, .. }) => { Event::Start(Tag::Heading { level, .. }) => {
self.start_heading(state, level as u8, output) Self::start_heading(state, level as u8, output);
} }
Event::End(TagEnd::Paragraph) => output.push_str("\n\n"), Event::End(TagEnd::Paragraph) => output.push_str("\n\n"),
Event::Start(Tag::BlockQuote(..)) => self.start_quote(state, output), Event::Start(Tag::BlockQuote(..)) => self.start_quote(state, output),
@@ -426,7 +426,7 @@ impl TerminalRenderer {
} }
} }
fn start_heading(&self, state: &mut RenderState, level: u8, output: &mut String) { fn start_heading(state: &mut RenderState, level: u8, output: &mut String) {
state.heading_level = Some(level); state.heading_level = Some(level);
if !output.is_empty() { if !output.is_empty() {
output.push('\n'); output.push('\n');

View File

@@ -5,15 +5,15 @@ use std::time::{Duration, Instant};
use api::{ use api::{
read_base_url, AnthropicClient, ContentBlockDelta, InputContentBlock, InputMessage, read_base_url, AnthropicClient, ContentBlockDelta, InputContentBlock, InputMessage,
MessageRequest, MessageResponse, OutputContentBlock, StreamEvent as ApiStreamEvent, ToolChoice, MessageRequest, MessageResponse, OutputContentBlock, PromptCache, PromptCacheRecord,
ToolDefinition, ToolResultContentBlock, StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, ToolResultContentBlock,
}; };
use reqwest::blocking::Client; use reqwest::blocking::Client;
use runtime::{ use runtime::{
edit_file, execute_bash, glob_search, grep_search, load_system_prompt, read_file, write_file, edit_file, execute_bash, glob_search, grep_search, load_system_prompt, read_file, write_file,
ApiClient, ApiRequest, AssistantEvent, BashCommandInput, ContentBlock, ConversationMessage, ApiClient, ApiRequest, AssistantEvent, BashCommandInput, ContentBlock, ConversationMessage,
ConversationRuntime, GrepSearchInput, MessageRole, PermissionMode, PermissionPolicy, ConversationRuntime, GrepSearchInput, MessageRole, PermissionMode, PermissionPolicy,
RuntimeError, Session, TokenUsage, ToolError, ToolExecutor, PromptCacheEvent, RuntimeError, Session, TokenUsage, ToolError, ToolExecutor,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::{json, Value}; use serde_json::{json, Value};
@@ -1466,7 +1466,8 @@ fn build_agent_runtime(
.clone() .clone()
.unwrap_or_else(|| DEFAULT_AGENT_MODEL.to_string()); .unwrap_or_else(|| DEFAULT_AGENT_MODEL.to_string());
let allowed_tools = job.allowed_tools.clone(); let allowed_tools = job.allowed_tools.clone();
let api_client = AnthropicRuntimeClient::new(model, allowed_tools.clone())?; let api_client =
AnthropicRuntimeClient::new(model, allowed_tools.clone(), job.manifest.agent_id.clone())?;
let tool_executor = SubagentToolExecutor::new(allowed_tools); let tool_executor = SubagentToolExecutor::new(allowed_tools);
Ok(ConversationRuntime::new( Ok(ConversationRuntime::new(
Session::new(), Session::new(),
@@ -1643,10 +1644,15 @@ struct AnthropicRuntimeClient {
} }
impl AnthropicRuntimeClient { impl AnthropicRuntimeClient {
fn new(model: String, allowed_tools: BTreeSet<String>) -> Result<Self, String> { fn new(
model: String,
allowed_tools: BTreeSet<String>,
session_id: impl Into<String>,
) -> Result<Self, String> {
let client = AnthropicClient::from_env() let client = AnthropicClient::from_env()
.map_err(|error| error.to_string())? .map_err(|error| error.to_string())?
.with_base_url(read_base_url()); .with_base_url(read_base_url())
.with_prompt_cache(PromptCache::new(session_id));
Ok(Self { Ok(Self {
runtime: tokio::runtime::Runtime::new().map_err(|error| error.to_string())?, runtime: tokio::runtime::Runtime::new().map_err(|error| error.to_string())?,
client, client,
@@ -1657,6 +1663,7 @@ impl AnthropicRuntimeClient {
} }
impl ApiClient for AnthropicRuntimeClient { impl ApiClient for AnthropicRuntimeClient {
#[allow(clippy::too_many_lines)]
fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> { fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
let tools = tool_specs_for_allowed_tools(Some(&self.allowed_tools)) let tools = tool_specs_for_allowed_tools(Some(&self.allowed_tools))
.into_iter() .into_iter()
@@ -1726,8 +1733,8 @@ impl ApiClient for AnthropicRuntimeClient {
events.push(AssistantEvent::Usage(TokenUsage { events.push(AssistantEvent::Usage(TokenUsage {
input_tokens: delta.usage.input_tokens, input_tokens: delta.usage.input_tokens,
output_tokens: delta.usage.output_tokens, output_tokens: delta.usage.output_tokens,
cache_creation_input_tokens: 0, cache_creation_input_tokens: delta.usage.cache_creation_input_tokens,
cache_read_input_tokens: 0, cache_read_input_tokens: delta.usage.cache_read_input_tokens,
})); }));
} }
ApiStreamEvent::MessageStop(_) => { ApiStreamEvent::MessageStop(_) => {
@@ -1737,6 +1744,8 @@ impl ApiClient for AnthropicRuntimeClient {
} }
} }
push_prompt_cache_record(&self.client, &mut events);
if !saw_stop if !saw_stop
&& events.iter().any(|event| { && events.iter().any(|event| {
matches!(event, AssistantEvent::TextDelta(text) if !text.is_empty()) matches!(event, AssistantEvent::TextDelta(text) if !text.is_empty())
@@ -1761,7 +1770,9 @@ impl ApiClient for AnthropicRuntimeClient {
}) })
.await .await
.map_err(|error| RuntimeError::new(error.to_string()))?; .map_err(|error| RuntimeError::new(error.to_string()))?;
Ok(response_to_events(response)) let mut events = response_to_events(response);
push_prompt_cache_record(&self.client, &mut events);
Ok(events)
}) })
} }
} }
@@ -1884,6 +1895,26 @@ fn response_to_events(response: MessageResponse) -> Vec<AssistantEvent> {
events events
} }
fn push_prompt_cache_record(client: &AnthropicClient, events: &mut Vec<AssistantEvent>) {
if let Some(event) = client
.take_last_prompt_cache_record()
.and_then(prompt_cache_record_to_runtime_event)
{
events.push(AssistantEvent::PromptCache(event));
}
}
fn prompt_cache_record_to_runtime_event(record: PromptCacheRecord) -> Option<PromptCacheEvent> {
let cache_break = record.cache_break?;
Some(PromptCacheEvent {
unexpected: cache_break.unexpected,
reason: cache_break.reason,
previous_cache_read_input_tokens: cache_break.previous_cache_read_input_tokens,
current_cache_read_input_tokens: cache_break.current_cache_read_input_tokens,
token_drop: cache_break.token_drop,
})
}
fn final_assistant_text(summary: &runtime::TurnSummary) -> String { fn final_assistant_text(summary: &runtime::TurnSummary) -> String {
summary summary
.assistant_messages .assistant_messages