feat: cache-tracking progress
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
use std::collections::VecDeque;
|
use std::collections::VecDeque;
|
||||||
|
use std::sync::{Arc, Mutex};
|
||||||
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||||
|
|
||||||
use runtime::{
|
use runtime::{
|
||||||
@@ -8,7 +9,7 @@ use runtime::{
|
|||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
|
|
||||||
use crate::error::ApiError;
|
use crate::error::ApiError;
|
||||||
use crate::prompt_cache::{PromptCache, PromptCacheStats};
|
use crate::prompt_cache::{PromptCache, PromptCacheRecord, PromptCacheStats};
|
||||||
use crate::sse::SseParser;
|
use crate::sse::SseParser;
|
||||||
use crate::types::{MessageRequest, MessageResponse, StreamEvent, Usage};
|
use crate::types::{MessageRequest, MessageResponse, StreamEvent, Usage};
|
||||||
|
|
||||||
@@ -110,6 +111,7 @@ pub struct AnthropicClient {
|
|||||||
initial_backoff: Duration,
|
initial_backoff: Duration,
|
||||||
max_backoff: Duration,
|
max_backoff: Duration,
|
||||||
prompt_cache: Option<PromptCache>,
|
prompt_cache: Option<PromptCache>,
|
||||||
|
last_prompt_cache_record: Arc<Mutex<Option<PromptCacheRecord>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AnthropicClient {
|
impl AnthropicClient {
|
||||||
@@ -123,6 +125,7 @@ impl AnthropicClient {
|
|||||||
initial_backoff: DEFAULT_INITIAL_BACKOFF,
|
initial_backoff: DEFAULT_INITIAL_BACKOFF,
|
||||||
max_backoff: DEFAULT_MAX_BACKOFF,
|
max_backoff: DEFAULT_MAX_BACKOFF,
|
||||||
prompt_cache: None,
|
prompt_cache: None,
|
||||||
|
last_prompt_cache_record: Arc::new(Mutex::new(None)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -136,6 +139,7 @@ impl AnthropicClient {
|
|||||||
initial_backoff: DEFAULT_INITIAL_BACKOFF,
|
initial_backoff: DEFAULT_INITIAL_BACKOFF,
|
||||||
max_backoff: DEFAULT_MAX_BACKOFF,
|
max_backoff: DEFAULT_MAX_BACKOFF,
|
||||||
prompt_cache: None,
|
prompt_cache: None,
|
||||||
|
last_prompt_cache_record: Arc::new(Mutex::new(None)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -209,6 +213,14 @@ impl AnthropicClient {
|
|||||||
self.prompt_cache.as_ref().map(PromptCache::stats)
|
self.prompt_cache.as_ref().map(PromptCache::stats)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn take_last_prompt_cache_record(&self) -> Option<PromptCacheRecord> {
|
||||||
|
self.last_prompt_cache_record()
|
||||||
|
.lock()
|
||||||
|
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||||
|
.take()
|
||||||
|
}
|
||||||
|
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn auth_source(&self) -> &AuthSource {
|
pub fn auth_source(&self) -> &AuthSource {
|
||||||
&self.auth
|
&self.auth
|
||||||
@@ -218,12 +230,16 @@ impl AnthropicClient {
|
|||||||
&self,
|
&self,
|
||||||
request: &MessageRequest,
|
request: &MessageRequest,
|
||||||
) -> Result<MessageResponse, ApiError> {
|
) -> Result<MessageResponse, ApiError> {
|
||||||
|
self.store_last_prompt_cache_record(None);
|
||||||
let request = MessageRequest {
|
let request = MessageRequest {
|
||||||
stream: false,
|
stream: false,
|
||||||
..request.clone()
|
..request.clone()
|
||||||
};
|
};
|
||||||
if let Some(prompt_cache) = &self.prompt_cache {
|
if let Some(prompt_cache) = &self.prompt_cache {
|
||||||
if let Some(response) = prompt_cache.lookup_completion(&request) {
|
if let Some(response) = prompt_cache.lookup_completion(&request) {
|
||||||
|
self.store_last_prompt_cache_record(Some(prompt_cache_record_from_stats(
|
||||||
|
prompt_cache.stats(),
|
||||||
|
)));
|
||||||
return Ok(response);
|
return Ok(response);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -237,7 +253,8 @@ impl AnthropicClient {
|
|||||||
response.request_id = request_id;
|
response.request_id = request_id;
|
||||||
}
|
}
|
||||||
if let Some(prompt_cache) = &self.prompt_cache {
|
if let Some(prompt_cache) = &self.prompt_cache {
|
||||||
let _ = prompt_cache.record_response(&request, &response);
|
let record = prompt_cache.record_response(&request, &response);
|
||||||
|
self.store_last_prompt_cache_record(Some(record));
|
||||||
}
|
}
|
||||||
Ok(response)
|
Ok(response)
|
||||||
}
|
}
|
||||||
@@ -246,6 +263,7 @@ impl AnthropicClient {
|
|||||||
&self,
|
&self,
|
||||||
request: &MessageRequest,
|
request: &MessageRequest,
|
||||||
) -> Result<MessageStream, ApiError> {
|
) -> Result<MessageStream, ApiError> {
|
||||||
|
self.store_last_prompt_cache_record(None);
|
||||||
let response = self
|
let response = self
|
||||||
.send_with_retry(&request.clone().with_streaming())
|
.send_with_retry(&request.clone().with_streaming())
|
||||||
.await?;
|
.await?;
|
||||||
@@ -263,10 +281,22 @@ impl AnthropicClient {
|
|||||||
request: request.clone().with_streaming(),
|
request: request.clone().with_streaming(),
|
||||||
last_usage: None,
|
last_usage: None,
|
||||||
finalized: false,
|
finalized: false,
|
||||||
|
last_record: self.last_prompt_cache_record.clone(),
|
||||||
}),
|
}),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn store_last_prompt_cache_record(&self, record: Option<PromptCacheRecord>) {
|
||||||
|
*self
|
||||||
|
.last_prompt_cache_record()
|
||||||
|
.lock()
|
||||||
|
.unwrap_or_else(std::sync::PoisonError::into_inner) = record;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn last_prompt_cache_record(&self) -> &Arc<Mutex<Option<PromptCacheRecord>>> {
|
||||||
|
&self.last_prompt_cache_record
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn exchange_oauth_code(
|
pub async fn exchange_oauth_code(
|
||||||
&self,
|
&self,
|
||||||
config: &OAuthConfig,
|
config: &OAuthConfig,
|
||||||
@@ -615,6 +645,7 @@ struct StreamCacheTracking {
|
|||||||
request: MessageRequest,
|
request: MessageRequest,
|
||||||
last_usage: Option<Usage>,
|
last_usage: Option<Usage>,
|
||||||
finalized: bool,
|
finalized: bool,
|
||||||
|
last_record: Arc<Mutex<Option<PromptCacheRecord>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StreamCacheTracking {
|
impl StreamCacheTracking {
|
||||||
@@ -638,12 +669,23 @@ impl StreamCacheTracking {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if let Some(usage) = &self.last_usage {
|
if let Some(usage) = &self.last_usage {
|
||||||
let _ = self.prompt_cache.record_usage(&self.request, usage);
|
let record = self.prompt_cache.record_usage(&self.request, usage);
|
||||||
|
*self
|
||||||
|
.last_record
|
||||||
|
.lock()
|
||||||
|
.unwrap_or_else(std::sync::PoisonError::into_inner) = Some(record);
|
||||||
}
|
}
|
||||||
self.finalized = true;
|
self.finalized = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn prompt_cache_record_from_stats(stats: PromptCacheStats) -> PromptCacheRecord {
|
||||||
|
PromptCacheRecord {
|
||||||
|
cache_break: None,
|
||||||
|
stats,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async fn expect_success(response: reqwest::Response) -> Result<reqwest::Response, ApiError> {
|
async fn expect_success(response: reqwest::Response) -> Result<reqwest::Response, ApiError> {
|
||||||
let status = response.status();
|
let status = response.status();
|
||||||
if status.is_success() {
|
if status.is_success() {
|
||||||
|
|||||||
@@ -25,9 +25,19 @@ pub enum AssistantEvent {
|
|||||||
input: String,
|
input: String,
|
||||||
},
|
},
|
||||||
Usage(TokenUsage),
|
Usage(TokenUsage),
|
||||||
|
PromptCache(PromptCacheEvent),
|
||||||
MessageStop,
|
MessageStop,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct PromptCacheEvent {
|
||||||
|
pub unexpected: bool,
|
||||||
|
pub reason: String,
|
||||||
|
pub previous_cache_read_input_tokens: u32,
|
||||||
|
pub current_cache_read_input_tokens: u32,
|
||||||
|
pub token_drop: u32,
|
||||||
|
}
|
||||||
|
|
||||||
pub trait ApiClient {
|
pub trait ApiClient {
|
||||||
fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError>;
|
fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError>;
|
||||||
}
|
}
|
||||||
@@ -84,6 +94,7 @@ impl std::error::Error for RuntimeError {}
|
|||||||
pub struct TurnSummary {
|
pub struct TurnSummary {
|
||||||
pub assistant_messages: Vec<ConversationMessage>,
|
pub assistant_messages: Vec<ConversationMessage>,
|
||||||
pub tool_results: Vec<ConversationMessage>,
|
pub tool_results: Vec<ConversationMessage>,
|
||||||
|
pub prompt_cache_events: Vec<PromptCacheEvent>,
|
||||||
pub iterations: usize,
|
pub iterations: usize,
|
||||||
pub usage: TokenUsage,
|
pub usage: TokenUsage,
|
||||||
}
|
}
|
||||||
@@ -118,7 +129,7 @@ where
|
|||||||
tool_executor,
|
tool_executor,
|
||||||
permission_policy,
|
permission_policy,
|
||||||
system_prompt,
|
system_prompt,
|
||||||
RuntimeFeatureConfig::default(),
|
&RuntimeFeatureConfig::default(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -129,7 +140,7 @@ where
|
|||||||
tool_executor: T,
|
tool_executor: T,
|
||||||
permission_policy: PermissionPolicy,
|
permission_policy: PermissionPolicy,
|
||||||
system_prompt: Vec<String>,
|
system_prompt: Vec<String>,
|
||||||
feature_config: RuntimeFeatureConfig,
|
feature_config: &RuntimeFeatureConfig,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let usage_tracker = UsageTracker::from_session(&session);
|
let usage_tracker = UsageTracker::from_session(&session);
|
||||||
Self {
|
Self {
|
||||||
@@ -140,7 +151,7 @@ where
|
|||||||
system_prompt,
|
system_prompt,
|
||||||
max_iterations: usize::MAX,
|
max_iterations: usize::MAX,
|
||||||
usage_tracker,
|
usage_tracker,
|
||||||
hook_runner: HookRunner::from_feature_config(&feature_config),
|
hook_runner: HookRunner::from_feature_config(feature_config),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -161,6 +172,7 @@ where
|
|||||||
|
|
||||||
let mut assistant_messages = Vec::new();
|
let mut assistant_messages = Vec::new();
|
||||||
let mut tool_results = Vec::new();
|
let mut tool_results = Vec::new();
|
||||||
|
let mut prompt_cache_events = Vec::new();
|
||||||
let mut iterations = 0;
|
let mut iterations = 0;
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
@@ -176,10 +188,12 @@ where
|
|||||||
messages: self.session.messages.clone(),
|
messages: self.session.messages.clone(),
|
||||||
};
|
};
|
||||||
let events = self.api_client.stream(request)?;
|
let events = self.api_client.stream(request)?;
|
||||||
let (assistant_message, usage) = build_assistant_message(events)?;
|
let (assistant_message, usage, turn_prompt_cache_events) =
|
||||||
|
build_assistant_message(events)?;
|
||||||
if let Some(usage) = usage {
|
if let Some(usage) = usage {
|
||||||
self.usage_tracker.record(usage);
|
self.usage_tracker.record(usage);
|
||||||
}
|
}
|
||||||
|
prompt_cache_events.extend(turn_prompt_cache_events);
|
||||||
let pending_tool_uses = assistant_message
|
let pending_tool_uses = assistant_message
|
||||||
.blocks
|
.blocks
|
||||||
.iter()
|
.iter()
|
||||||
@@ -257,6 +271,7 @@ where
|
|||||||
Ok(TurnSummary {
|
Ok(TurnSummary {
|
||||||
assistant_messages,
|
assistant_messages,
|
||||||
tool_results,
|
tool_results,
|
||||||
|
prompt_cache_events,
|
||||||
iterations,
|
iterations,
|
||||||
usage: self.usage_tracker.cumulative_usage(),
|
usage: self.usage_tracker.cumulative_usage(),
|
||||||
})
|
})
|
||||||
@@ -290,9 +305,17 @@ where
|
|||||||
|
|
||||||
fn build_assistant_message(
|
fn build_assistant_message(
|
||||||
events: Vec<AssistantEvent>,
|
events: Vec<AssistantEvent>,
|
||||||
) -> Result<(ConversationMessage, Option<TokenUsage>), RuntimeError> {
|
) -> Result<
|
||||||
|
(
|
||||||
|
ConversationMessage,
|
||||||
|
Option<TokenUsage>,
|
||||||
|
Vec<PromptCacheEvent>,
|
||||||
|
),
|
||||||
|
RuntimeError,
|
||||||
|
> {
|
||||||
let mut text = String::new();
|
let mut text = String::new();
|
||||||
let mut blocks = Vec::new();
|
let mut blocks = Vec::new();
|
||||||
|
let mut prompt_cache_events = Vec::new();
|
||||||
let mut finished = false;
|
let mut finished = false;
|
||||||
let mut usage = None;
|
let mut usage = None;
|
||||||
|
|
||||||
@@ -304,6 +327,7 @@ fn build_assistant_message(
|
|||||||
blocks.push(ContentBlock::ToolUse { id, name, input });
|
blocks.push(ContentBlock::ToolUse { id, name, input });
|
||||||
}
|
}
|
||||||
AssistantEvent::Usage(value) => usage = Some(value),
|
AssistantEvent::Usage(value) => usage = Some(value),
|
||||||
|
AssistantEvent::PromptCache(event) => prompt_cache_events.push(event),
|
||||||
AssistantEvent::MessageStop => {
|
AssistantEvent::MessageStop => {
|
||||||
finished = true;
|
finished = true;
|
||||||
}
|
}
|
||||||
@@ -324,6 +348,7 @@ fn build_assistant_message(
|
|||||||
Ok((
|
Ok((
|
||||||
ConversationMessage::assistant_with_usage(blocks, usage),
|
ConversationMessage::assistant_with_usage(blocks, usage),
|
||||||
usage,
|
usage,
|
||||||
|
prompt_cache_events,
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -396,7 +421,7 @@ impl ToolExecutor for StaticToolExecutor {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::{
|
use super::{
|
||||||
ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, RuntimeError,
|
ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, PromptCacheEvent, RuntimeError,
|
||||||
StaticToolExecutor,
|
StaticToolExecutor,
|
||||||
};
|
};
|
||||||
use crate::compact::CompactionConfig;
|
use crate::compact::CompactionConfig;
|
||||||
@@ -453,6 +478,15 @@ mod tests {
|
|||||||
cache_creation_input_tokens: 1,
|
cache_creation_input_tokens: 1,
|
||||||
cache_read_input_tokens: 3,
|
cache_read_input_tokens: 3,
|
||||||
}),
|
}),
|
||||||
|
AssistantEvent::PromptCache(PromptCacheEvent {
|
||||||
|
unexpected: true,
|
||||||
|
reason:
|
||||||
|
"cache read tokens dropped while prompt fingerprint remained stable"
|
||||||
|
.to_string(),
|
||||||
|
previous_cache_read_input_tokens: 6_000,
|
||||||
|
current_cache_read_input_tokens: 1_000,
|
||||||
|
token_drop: 5_000,
|
||||||
|
}),
|
||||||
AssistantEvent::MessageStop,
|
AssistantEvent::MessageStop,
|
||||||
])
|
])
|
||||||
}
|
}
|
||||||
@@ -506,8 +540,10 @@ mod tests {
|
|||||||
assert_eq!(summary.iterations, 2);
|
assert_eq!(summary.iterations, 2);
|
||||||
assert_eq!(summary.assistant_messages.len(), 2);
|
assert_eq!(summary.assistant_messages.len(), 2);
|
||||||
assert_eq!(summary.tool_results.len(), 1);
|
assert_eq!(summary.tool_results.len(), 1);
|
||||||
|
assert_eq!(summary.prompt_cache_events.len(), 1);
|
||||||
assert_eq!(runtime.session().messages.len(), 4);
|
assert_eq!(runtime.session().messages.len(), 4);
|
||||||
assert_eq!(summary.usage.output_tokens, 10);
|
assert_eq!(summary.usage.output_tokens, 10);
|
||||||
|
assert!(summary.prompt_cache_events[0].unexpected);
|
||||||
assert!(matches!(
|
assert!(matches!(
|
||||||
runtime.session().messages[1].blocks[1],
|
runtime.session().messages[1].blocks[1],
|
||||||
ContentBlock::ToolUse { .. }
|
ContentBlock::ToolUse { .. }
|
||||||
@@ -609,7 +645,7 @@ mod tests {
|
|||||||
}),
|
}),
|
||||||
PermissionPolicy::new(PermissionMode::DangerFullAccess),
|
PermissionPolicy::new(PermissionMode::DangerFullAccess),
|
||||||
vec!["system".to_string()],
|
vec!["system".to_string()],
|
||||||
RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new(
|
&RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new(
|
||||||
vec![shell_snippet("printf 'blocked by hook'; exit 2")],
|
vec![shell_snippet("printf 'blocked by hook'; exit 2")],
|
||||||
Vec::new(),
|
Vec::new(),
|
||||||
)),
|
)),
|
||||||
@@ -675,7 +711,7 @@ mod tests {
|
|||||||
StaticToolExecutor::new().register("add", |_input| Ok("4".to_string())),
|
StaticToolExecutor::new().register("add", |_input| Ok("4".to_string())),
|
||||||
PermissionPolicy::new(PermissionMode::DangerFullAccess),
|
PermissionPolicy::new(PermissionMode::DangerFullAccess),
|
||||||
vec!["system".to_string()],
|
vec!["system".to_string()],
|
||||||
RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new(
|
&RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new(
|
||||||
vec![shell_snippet("printf 'pre hook ran'")],
|
vec![shell_snippet("printf 'pre hook ran'")],
|
||||||
vec![shell_snippet("printf 'post hook ran'")],
|
vec![shell_snippet("printf 'post hook ran'")],
|
||||||
)),
|
)),
|
||||||
@@ -697,7 +733,7 @@ mod tests {
|
|||||||
"post hook should preserve non-error result: {output:?}"
|
"post hook should preserve non-error result: {output:?}"
|
||||||
);
|
);
|
||||||
assert!(
|
assert!(
|
||||||
output.contains("4"),
|
output.contains('4'),
|
||||||
"tool output missing value: {output:?}"
|
"tool output missing value: {output:?}"
|
||||||
);
|
);
|
||||||
assert!(
|
assert!(
|
||||||
|
|||||||
@@ -64,7 +64,7 @@ impl HookRunner {
|
|||||||
|
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn run_pre_tool_use(&self, tool_name: &str, tool_input: &str) -> HookRunResult {
|
pub fn run_pre_tool_use(&self, tool_name: &str, tool_input: &str) -> HookRunResult {
|
||||||
self.run_commands(
|
Self::run_commands(
|
||||||
HookEvent::PreToolUse,
|
HookEvent::PreToolUse,
|
||||||
self.config.pre_tool_use(),
|
self.config.pre_tool_use(),
|
||||||
tool_name,
|
tool_name,
|
||||||
@@ -82,7 +82,7 @@ impl HookRunner {
|
|||||||
tool_output: &str,
|
tool_output: &str,
|
||||||
is_error: bool,
|
is_error: bool,
|
||||||
) -> HookRunResult {
|
) -> HookRunResult {
|
||||||
self.run_commands(
|
Self::run_commands(
|
||||||
HookEvent::PostToolUse,
|
HookEvent::PostToolUse,
|
||||||
self.config.post_tool_use(),
|
self.config.post_tool_use(),
|
||||||
tool_name,
|
tool_name,
|
||||||
@@ -93,7 +93,6 @@ impl HookRunner {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn run_commands(
|
fn run_commands(
|
||||||
&self,
|
|
||||||
event: HookEvent,
|
event: HookEvent,
|
||||||
commands: &[String],
|
commands: &[String],
|
||||||
tool_name: &str,
|
tool_name: &str,
|
||||||
@@ -118,7 +117,7 @@ impl HookRunner {
|
|||||||
let mut messages = Vec::new();
|
let mut messages = Vec::new();
|
||||||
|
|
||||||
for command in commands {
|
for command in commands {
|
||||||
match self.run_command(
|
match Self::run_command(
|
||||||
command,
|
command,
|
||||||
event,
|
event,
|
||||||
tool_name,
|
tool_name,
|
||||||
@@ -150,7 +149,6 @@ impl HookRunner {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn run_command(
|
fn run_command(
|
||||||
&self,
|
|
||||||
command: &str,
|
command: &str,
|
||||||
event: HookEvent,
|
event: HookEvent,
|
||||||
tool_name: &str,
|
tool_name: &str,
|
||||||
|
|||||||
@@ -31,8 +31,8 @@ pub use config::{
|
|||||||
ScopedMcpServerConfig, CLAUDE_CODE_SETTINGS_SCHEMA_NAME,
|
ScopedMcpServerConfig, CLAUDE_CODE_SETTINGS_SCHEMA_NAME,
|
||||||
};
|
};
|
||||||
pub use conversation::{
|
pub use conversation::{
|
||||||
ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, RuntimeError, StaticToolExecutor,
|
ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, PromptCacheEvent, RuntimeError,
|
||||||
ToolError, ToolExecutor, TurnSummary,
|
StaticToolExecutor, ToolError, ToolExecutor, TurnSummary,
|
||||||
};
|
};
|
||||||
pub use file_ops::{
|
pub use file_ops::{
|
||||||
edit_file, glob_search, grep_search, read_file, write_file, EditFileOutput, GlobSearchOutput,
|
edit_file, glob_search, grep_search, read_file, write_file, EditFileOutput, GlobSearchOutput,
|
||||||
|
|||||||
@@ -13,8 +13,9 @@ use std::time::{SystemTime, UNIX_EPOCH};
|
|||||||
|
|
||||||
use api::{
|
use api::{
|
||||||
resolve_startup_auth_source, AnthropicClient, AuthSource, ContentBlockDelta, InputContentBlock,
|
resolve_startup_auth_source, AnthropicClient, AuthSource, ContentBlockDelta, InputContentBlock,
|
||||||
InputMessage, MessageRequest, MessageResponse, OutputContentBlock,
|
InputMessage, MessageRequest, MessageResponse, OutputContentBlock, PromptCache,
|
||||||
StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, ToolResultContentBlock,
|
PromptCacheRecord, StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition,
|
||||||
|
ToolResultContentBlock,
|
||||||
};
|
};
|
||||||
|
|
||||||
use commands::{
|
use commands::{
|
||||||
@@ -28,8 +29,8 @@ use runtime::{
|
|||||||
parse_oauth_callback_request_target, save_oauth_credentials, ApiClient, ApiRequest,
|
parse_oauth_callback_request_target, save_oauth_credentials, ApiClient, ApiRequest,
|
||||||
AssistantEvent, CompactionConfig, ConfigLoader, ConfigSource, ContentBlock,
|
AssistantEvent, CompactionConfig, ConfigLoader, ConfigSource, ContentBlock,
|
||||||
ConversationMessage, ConversationRuntime, MessageRole, OAuthAuthorizationRequest, OAuthConfig,
|
ConversationMessage, ConversationRuntime, MessageRole, OAuthAuthorizationRequest, OAuthConfig,
|
||||||
OAuthTokenExchangeRequest, PermissionMode, PermissionPolicy, ProjectContext, RuntimeError,
|
OAuthTokenExchangeRequest, PermissionMode, PermissionPolicy, ProjectContext, PromptCacheEvent,
|
||||||
Session, TokenUsage, ToolError, ToolExecutor, UsageTracker,
|
RuntimeError, Session, TokenUsage, ToolError, ToolExecutor, UsageTracker,
|
||||||
};
|
};
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use tools::{execute_tool, mvp_tool_specs, ToolSpec};
|
use tools::{execute_tool, mvp_tool_specs, ToolSpec};
|
||||||
@@ -995,6 +996,7 @@ impl LiveCli {
|
|||||||
let session = create_managed_session_handle()?;
|
let session = create_managed_session_handle()?;
|
||||||
let runtime = build_runtime(
|
let runtime = build_runtime(
|
||||||
Session::new(),
|
Session::new(),
|
||||||
|
session.id.clone(),
|
||||||
model.clone(),
|
model.clone(),
|
||||||
system_prompt.clone(),
|
system_prompt.clone(),
|
||||||
enable_tools,
|
enable_tools,
|
||||||
@@ -1050,13 +1052,14 @@ impl LiveCli {
|
|||||||
let mut permission_prompter = CliPermissionPrompter::new(self.permission_mode);
|
let mut permission_prompter = CliPermissionPrompter::new(self.permission_mode);
|
||||||
let result = self.runtime.run_turn(input, Some(&mut permission_prompter));
|
let result = self.runtime.run_turn(input, Some(&mut permission_prompter));
|
||||||
match result {
|
match result {
|
||||||
Ok(_) => {
|
Ok(summary) => {
|
||||||
spinner.finish(
|
spinner.finish(
|
||||||
"✨ Done",
|
"✨ Done",
|
||||||
TerminalRenderer::new().color_theme(),
|
TerminalRenderer::new().color_theme(),
|
||||||
&mut stdout,
|
&mut stdout,
|
||||||
)?;
|
)?;
|
||||||
println!();
|
println!();
|
||||||
|
print_prompt_cache_events(&summary);
|
||||||
self.persist_session()?;
|
self.persist_session()?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@@ -1086,6 +1089,7 @@ impl LiveCli {
|
|||||||
let session = self.runtime.session().clone();
|
let session = self.runtime.session().clone();
|
||||||
let mut runtime = build_runtime(
|
let mut runtime = build_runtime(
|
||||||
session,
|
session,
|
||||||
|
self.session.id.clone(),
|
||||||
self.model.clone(),
|
self.model.clone(),
|
||||||
self.system_prompt.clone(),
|
self.system_prompt.clone(),
|
||||||
true,
|
true,
|
||||||
@@ -1105,6 +1109,7 @@ impl LiveCli {
|
|||||||
"iterations": summary.iterations,
|
"iterations": summary.iterations,
|
||||||
"tool_uses": collect_tool_uses(&summary),
|
"tool_uses": collect_tool_uses(&summary),
|
||||||
"tool_results": collect_tool_results(&summary),
|
"tool_results": collect_tool_results(&summary),
|
||||||
|
"prompt_cache_events": collect_prompt_cache_events(&summary),
|
||||||
"usage": {
|
"usage": {
|
||||||
"input_tokens": summary.usage.input_tokens,
|
"input_tokens": summary.usage.input_tokens,
|
||||||
"output_tokens": summary.usage.output_tokens,
|
"output_tokens": summary.usage.output_tokens,
|
||||||
@@ -1232,6 +1237,7 @@ impl LiveCli {
|
|||||||
let message_count = session.messages.len();
|
let message_count = session.messages.len();
|
||||||
self.runtime = build_runtime(
|
self.runtime = build_runtime(
|
||||||
session,
|
session,
|
||||||
|
self.session.id.clone(),
|
||||||
model.clone(),
|
model.clone(),
|
||||||
self.system_prompt.clone(),
|
self.system_prompt.clone(),
|
||||||
true,
|
true,
|
||||||
@@ -1275,6 +1281,7 @@ impl LiveCli {
|
|||||||
self.permission_mode = permission_mode_from_label(normalized);
|
self.permission_mode = permission_mode_from_label(normalized);
|
||||||
self.runtime = build_runtime(
|
self.runtime = build_runtime(
|
||||||
session,
|
session,
|
||||||
|
self.session.id.clone(),
|
||||||
self.model.clone(),
|
self.model.clone(),
|
||||||
self.system_prompt.clone(),
|
self.system_prompt.clone(),
|
||||||
true,
|
true,
|
||||||
@@ -1300,6 +1307,7 @@ impl LiveCli {
|
|||||||
self.session = create_managed_session_handle()?;
|
self.session = create_managed_session_handle()?;
|
||||||
self.runtime = build_runtime(
|
self.runtime = build_runtime(
|
||||||
Session::new(),
|
Session::new(),
|
||||||
|
self.session.id.clone(),
|
||||||
self.model.clone(),
|
self.model.clone(),
|
||||||
self.system_prompt.clone(),
|
self.system_prompt.clone(),
|
||||||
true,
|
true,
|
||||||
@@ -1335,6 +1343,7 @@ impl LiveCli {
|
|||||||
let message_count = session.messages.len();
|
let message_count = session.messages.len();
|
||||||
self.runtime = build_runtime(
|
self.runtime = build_runtime(
|
||||||
session,
|
session,
|
||||||
|
handle.id.clone(),
|
||||||
self.model.clone(),
|
self.model.clone(),
|
||||||
self.system_prompt.clone(),
|
self.system_prompt.clone(),
|
||||||
true,
|
true,
|
||||||
@@ -1407,6 +1416,7 @@ impl LiveCli {
|
|||||||
let message_count = session.messages.len();
|
let message_count = session.messages.len();
|
||||||
self.runtime = build_runtime(
|
self.runtime = build_runtime(
|
||||||
session,
|
session,
|
||||||
|
handle.id.clone(),
|
||||||
self.model.clone(),
|
self.model.clone(),
|
||||||
self.system_prompt.clone(),
|
self.system_prompt.clone(),
|
||||||
true,
|
true,
|
||||||
@@ -1437,6 +1447,7 @@ impl LiveCli {
|
|||||||
let skipped = removed == 0;
|
let skipped = removed == 0;
|
||||||
self.runtime = build_runtime(
|
self.runtime = build_runtime(
|
||||||
result.compacted_session,
|
result.compacted_session,
|
||||||
|
self.session.id.clone(),
|
||||||
self.model.clone(),
|
self.model.clone(),
|
||||||
self.system_prompt.clone(),
|
self.system_prompt.clone(),
|
||||||
true,
|
true,
|
||||||
@@ -1912,8 +1923,10 @@ fn build_runtime_feature_config(
|
|||||||
.clone())
|
.clone())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
fn build_runtime(
|
fn build_runtime(
|
||||||
session: Session,
|
session: Session,
|
||||||
|
session_id: String,
|
||||||
model: String,
|
model: String,
|
||||||
system_prompt: Vec<String>,
|
system_prompt: Vec<String>,
|
||||||
enable_tools: bool,
|
enable_tools: bool,
|
||||||
@@ -1924,11 +1937,17 @@ fn build_runtime(
|
|||||||
{
|
{
|
||||||
Ok(ConversationRuntime::new_with_features(
|
Ok(ConversationRuntime::new_with_features(
|
||||||
session,
|
session,
|
||||||
AnthropicRuntimeClient::new(model, enable_tools, emit_output, allowed_tools.clone())?,
|
AnthropicRuntimeClient::new(
|
||||||
|
model,
|
||||||
|
enable_tools,
|
||||||
|
emit_output,
|
||||||
|
allowed_tools.clone(),
|
||||||
|
session_id,
|
||||||
|
)?,
|
||||||
CliToolExecutor::new(allowed_tools, emit_output),
|
CliToolExecutor::new(allowed_tools, emit_output),
|
||||||
permission_policy(permission_mode),
|
permission_policy(permission_mode),
|
||||||
system_prompt,
|
system_prompt,
|
||||||
build_runtime_feature_config()?,
|
&build_runtime_feature_config()?,
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1993,11 +2012,13 @@ impl AnthropicRuntimeClient {
|
|||||||
enable_tools: bool,
|
enable_tools: bool,
|
||||||
emit_output: bool,
|
emit_output: bool,
|
||||||
allowed_tools: Option<AllowedToolSet>,
|
allowed_tools: Option<AllowedToolSet>,
|
||||||
|
session_id: impl Into<String>,
|
||||||
) -> Result<Self, Box<dyn std::error::Error>> {
|
) -> Result<Self, Box<dyn std::error::Error>> {
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
runtime: tokio::runtime::Runtime::new()?,
|
runtime: tokio::runtime::Runtime::new()?,
|
||||||
client: AnthropicClient::from_auth(resolve_cli_auth_source()?)
|
client: AnthropicClient::from_auth(resolve_cli_auth_source()?)
|
||||||
.with_base_url(api::read_base_url()),
|
.with_base_url(api::read_base_url())
|
||||||
|
.with_prompt_cache(PromptCache::new(session_id)),
|
||||||
model,
|
model,
|
||||||
enable_tools,
|
enable_tools,
|
||||||
emit_output,
|
emit_output,
|
||||||
@@ -2112,8 +2133,8 @@ impl ApiClient for AnthropicRuntimeClient {
|
|||||||
events.push(AssistantEvent::Usage(TokenUsage {
|
events.push(AssistantEvent::Usage(TokenUsage {
|
||||||
input_tokens: delta.usage.input_tokens,
|
input_tokens: delta.usage.input_tokens,
|
||||||
output_tokens: delta.usage.output_tokens,
|
output_tokens: delta.usage.output_tokens,
|
||||||
cache_creation_input_tokens: 0,
|
cache_creation_input_tokens: delta.usage.cache_creation_input_tokens,
|
||||||
cache_read_input_tokens: 0,
|
cache_read_input_tokens: delta.usage.cache_read_input_tokens,
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
ApiStreamEvent::MessageStop(_) => {
|
ApiStreamEvent::MessageStop(_) => {
|
||||||
@@ -2128,6 +2149,8 @@ impl ApiClient for AnthropicRuntimeClient {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
push_prompt_cache_record(&self.client, &mut events);
|
||||||
|
|
||||||
if !saw_stop
|
if !saw_stop
|
||||||
&& events.iter().any(|event| {
|
&& events.iter().any(|event| {
|
||||||
matches!(event, AssistantEvent::TextDelta(text) if !text.is_empty())
|
matches!(event, AssistantEvent::TextDelta(text) if !text.is_empty())
|
||||||
@@ -2152,7 +2175,9 @@ impl ApiClient for AnthropicRuntimeClient {
|
|||||||
})
|
})
|
||||||
.await
|
.await
|
||||||
.map_err(|error| RuntimeError::new(error.to_string()))?;
|
.map_err(|error| RuntimeError::new(error.to_string()))?;
|
||||||
response_to_events(response, out)
|
let mut events = response_to_events(response, out)?;
|
||||||
|
push_prompt_cache_record(&self.client, &mut events);
|
||||||
|
Ok(events)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -2213,6 +2238,39 @@ fn collect_tool_results(summary: &runtime::TurnSummary) -> Vec<serde_json::Value
|
|||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn collect_prompt_cache_events(summary: &runtime::TurnSummary) -> Vec<serde_json::Value> {
|
||||||
|
summary
|
||||||
|
.prompt_cache_events
|
||||||
|
.iter()
|
||||||
|
.map(|event| {
|
||||||
|
json!({
|
||||||
|
"unexpected": event.unexpected,
|
||||||
|
"reason": event.reason,
|
||||||
|
"previous_cache_read_input_tokens": event.previous_cache_read_input_tokens,
|
||||||
|
"current_cache_read_input_tokens": event.current_cache_read_input_tokens,
|
||||||
|
"token_drop": event.token_drop,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn print_prompt_cache_events(summary: &runtime::TurnSummary) {
|
||||||
|
for event in &summary.prompt_cache_events {
|
||||||
|
let label = if event.unexpected {
|
||||||
|
"Prompt cache break"
|
||||||
|
} else {
|
||||||
|
"Prompt cache invalidation"
|
||||||
|
};
|
||||||
|
println!(
|
||||||
|
"{label}: {} (cache read {} -> {}, drop {})",
|
||||||
|
event.reason,
|
||||||
|
event.previous_cache_read_input_tokens,
|
||||||
|
event.current_cache_read_input_tokens,
|
||||||
|
event.token_drop,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn slash_command_completion_candidates() -> Vec<String> {
|
fn slash_command_completion_candidates() -> Vec<String> {
|
||||||
slash_command_specs()
|
slash_command_specs()
|
||||||
.iter()
|
.iter()
|
||||||
@@ -2359,18 +2417,20 @@ fn first_visible_line(text: &str) -> &str {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn format_bash_result(icon: &str, parsed: &serde_json::Value) -> String {
|
fn format_bash_result(icon: &str, parsed: &serde_json::Value) -> String {
|
||||||
|
use std::fmt::Write as _;
|
||||||
|
|
||||||
let mut lines = vec![format!("{icon} \x1b[38;5;245mbash\x1b[0m")];
|
let mut lines = vec![format!("{icon} \x1b[38;5;245mbash\x1b[0m")];
|
||||||
if let Some(task_id) = parsed
|
if let Some(task_id) = parsed
|
||||||
.get("backgroundTaskId")
|
.get("backgroundTaskId")
|
||||||
.and_then(|value| value.as_str())
|
.and_then(|value| value.as_str())
|
||||||
{
|
{
|
||||||
lines[0].push_str(&format!(" backgrounded ({task_id})"));
|
let _ = write!(lines[0], " backgrounded ({task_id})");
|
||||||
} else if let Some(status) = parsed
|
} else if let Some(status) = parsed
|
||||||
.get("returnCodeInterpretation")
|
.get("returnCodeInterpretation")
|
||||||
.and_then(|value| value.as_str())
|
.and_then(|value| value.as_str())
|
||||||
.filter(|status| !status.is_empty())
|
.filter(|status| !status.is_empty())
|
||||||
{
|
{
|
||||||
lines[0].push_str(&format!(" {status}"));
|
let _ = write!(lines[0], " {status}");
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(stdout) = parsed.get("stdout").and_then(|value| value.as_str()) {
|
if let Some(stdout) = parsed.get("stdout").and_then(|value| value.as_str()) {
|
||||||
@@ -2392,15 +2452,15 @@ fn format_read_result(icon: &str, parsed: &serde_json::Value) -> String {
|
|||||||
let path = extract_tool_path(file);
|
let path = extract_tool_path(file);
|
||||||
let start_line = file
|
let start_line = file
|
||||||
.get("startLine")
|
.get("startLine")
|
||||||
.and_then(|value| value.as_u64())
|
.and_then(serde_json::Value::as_u64)
|
||||||
.unwrap_or(1);
|
.unwrap_or(1);
|
||||||
let num_lines = file
|
let num_lines = file
|
||||||
.get("numLines")
|
.get("numLines")
|
||||||
.and_then(|value| value.as_u64())
|
.and_then(serde_json::Value::as_u64)
|
||||||
.unwrap_or(0);
|
.unwrap_or(0);
|
||||||
let total_lines = file
|
let total_lines = file
|
||||||
.get("totalLines")
|
.get("totalLines")
|
||||||
.and_then(|value| value.as_u64())
|
.and_then(serde_json::Value::as_u64)
|
||||||
.unwrap_or(num_lines);
|
.unwrap_or(num_lines);
|
||||||
let content = file
|
let content = file
|
||||||
.get("content")
|
.get("content")
|
||||||
@@ -2426,8 +2486,7 @@ fn format_write_result(icon: &str, parsed: &serde_json::Value) -> String {
|
|||||||
let line_count = parsed
|
let line_count = parsed
|
||||||
.get("content")
|
.get("content")
|
||||||
.and_then(|value| value.as_str())
|
.and_then(|value| value.as_str())
|
||||||
.map(|content| content.lines().count())
|
.map_or(0, |content| content.lines().count());
|
||||||
.unwrap_or(0);
|
|
||||||
format!(
|
format!(
|
||||||
"{icon} \x1b[1;32m✏️ {} {path}\x1b[0m \x1b[2m({line_count} lines)\x1b[0m",
|
"{icon} \x1b[1;32m✏️ {} {path}\x1b[0m \x1b[2m({line_count} lines)\x1b[0m",
|
||||||
if kind == "create" { "Wrote" } else { "Updated" },
|
if kind == "create" { "Wrote" } else { "Updated" },
|
||||||
@@ -2458,7 +2517,7 @@ fn format_edit_result(icon: &str, parsed: &serde_json::Value) -> String {
|
|||||||
let path = extract_tool_path(parsed);
|
let path = extract_tool_path(parsed);
|
||||||
let suffix = if parsed
|
let suffix = if parsed
|
||||||
.get("replaceAll")
|
.get("replaceAll")
|
||||||
.and_then(|value| value.as_bool())
|
.and_then(serde_json::Value::as_bool)
|
||||||
.unwrap_or(false)
|
.unwrap_or(false)
|
||||||
{
|
{
|
||||||
" (replace all)"
|
" (replace all)"
|
||||||
@@ -2486,7 +2545,7 @@ fn format_edit_result(icon: &str, parsed: &serde_json::Value) -> String {
|
|||||||
fn format_glob_result(icon: &str, parsed: &serde_json::Value) -> String {
|
fn format_glob_result(icon: &str, parsed: &serde_json::Value) -> String {
|
||||||
let num_files = parsed
|
let num_files = parsed
|
||||||
.get("numFiles")
|
.get("numFiles")
|
||||||
.and_then(|value| value.as_u64())
|
.and_then(serde_json::Value::as_u64)
|
||||||
.unwrap_or(0);
|
.unwrap_or(0);
|
||||||
let filenames = parsed
|
let filenames = parsed
|
||||||
.get("filenames")
|
.get("filenames")
|
||||||
@@ -2510,11 +2569,11 @@ fn format_glob_result(icon: &str, parsed: &serde_json::Value) -> String {
|
|||||||
fn format_grep_result(icon: &str, parsed: &serde_json::Value) -> String {
|
fn format_grep_result(icon: &str, parsed: &serde_json::Value) -> String {
|
||||||
let num_matches = parsed
|
let num_matches = parsed
|
||||||
.get("numMatches")
|
.get("numMatches")
|
||||||
.and_then(|value| value.as_u64())
|
.and_then(serde_json::Value::as_u64)
|
||||||
.unwrap_or(0);
|
.unwrap_or(0);
|
||||||
let num_files = parsed
|
let num_files = parsed
|
||||||
.get("numFiles")
|
.get("numFiles")
|
||||||
.and_then(|value| value.as_u64())
|
.and_then(serde_json::Value::as_u64)
|
||||||
.unwrap_or(0);
|
.unwrap_or(0);
|
||||||
let content = parsed
|
let content = parsed
|
||||||
.get("content")
|
.get("content")
|
||||||
@@ -2621,6 +2680,26 @@ fn response_to_events(
|
|||||||
Ok(events)
|
Ok(events)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn push_prompt_cache_record(client: &AnthropicClient, events: &mut Vec<AssistantEvent>) {
|
||||||
|
if let Some(event) = client
|
||||||
|
.take_last_prompt_cache_record()
|
||||||
|
.and_then(prompt_cache_record_to_runtime_event)
|
||||||
|
{
|
||||||
|
events.push(AssistantEvent::PromptCache(event));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn prompt_cache_record_to_runtime_event(record: PromptCacheRecord) -> Option<PromptCacheEvent> {
|
||||||
|
let cache_break = record.cache_break?;
|
||||||
|
Some(PromptCacheEvent {
|
||||||
|
unexpected: cache_break.unexpected,
|
||||||
|
reason: cache_break.reason,
|
||||||
|
previous_cache_read_input_tokens: cache_break.previous_cache_read_input_tokens,
|
||||||
|
current_cache_read_input_tokens: cache_break.current_cache_read_input_tokens,
|
||||||
|
token_drop: cache_break.token_drop,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
struct CliToolExecutor {
|
struct CliToolExecutor {
|
||||||
renderer: TerminalRenderer,
|
renderer: TerminalRenderer,
|
||||||
emit_output: bool,
|
emit_output: bool,
|
||||||
|
|||||||
@@ -286,7 +286,7 @@ impl TerminalRenderer {
|
|||||||
) {
|
) {
|
||||||
match event {
|
match event {
|
||||||
Event::Start(Tag::Heading { level, .. }) => {
|
Event::Start(Tag::Heading { level, .. }) => {
|
||||||
self.start_heading(state, level as u8, output)
|
Self::start_heading(state, level as u8, output);
|
||||||
}
|
}
|
||||||
Event::End(TagEnd::Paragraph) => output.push_str("\n\n"),
|
Event::End(TagEnd::Paragraph) => output.push_str("\n\n"),
|
||||||
Event::Start(Tag::BlockQuote(..)) => self.start_quote(state, output),
|
Event::Start(Tag::BlockQuote(..)) => self.start_quote(state, output),
|
||||||
@@ -426,7 +426,7 @@ impl TerminalRenderer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn start_heading(&self, state: &mut RenderState, level: u8, output: &mut String) {
|
fn start_heading(state: &mut RenderState, level: u8, output: &mut String) {
|
||||||
state.heading_level = Some(level);
|
state.heading_level = Some(level);
|
||||||
if !output.is_empty() {
|
if !output.is_empty() {
|
||||||
output.push('\n');
|
output.push('\n');
|
||||||
|
|||||||
@@ -5,15 +5,15 @@ use std::time::{Duration, Instant};
|
|||||||
|
|
||||||
use api::{
|
use api::{
|
||||||
read_base_url, AnthropicClient, ContentBlockDelta, InputContentBlock, InputMessage,
|
read_base_url, AnthropicClient, ContentBlockDelta, InputContentBlock, InputMessage,
|
||||||
MessageRequest, MessageResponse, OutputContentBlock, StreamEvent as ApiStreamEvent, ToolChoice,
|
MessageRequest, MessageResponse, OutputContentBlock, PromptCache, PromptCacheRecord,
|
||||||
ToolDefinition, ToolResultContentBlock,
|
StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, ToolResultContentBlock,
|
||||||
};
|
};
|
||||||
use reqwest::blocking::Client;
|
use reqwest::blocking::Client;
|
||||||
use runtime::{
|
use runtime::{
|
||||||
edit_file, execute_bash, glob_search, grep_search, load_system_prompt, read_file, write_file,
|
edit_file, execute_bash, glob_search, grep_search, load_system_prompt, read_file, write_file,
|
||||||
ApiClient, ApiRequest, AssistantEvent, BashCommandInput, ContentBlock, ConversationMessage,
|
ApiClient, ApiRequest, AssistantEvent, BashCommandInput, ContentBlock, ConversationMessage,
|
||||||
ConversationRuntime, GrepSearchInput, MessageRole, PermissionMode, PermissionPolicy,
|
ConversationRuntime, GrepSearchInput, MessageRole, PermissionMode, PermissionPolicy,
|
||||||
RuntimeError, Session, TokenUsage, ToolError, ToolExecutor,
|
PromptCacheEvent, RuntimeError, Session, TokenUsage, ToolError, ToolExecutor,
|
||||||
};
|
};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::{json, Value};
|
use serde_json::{json, Value};
|
||||||
@@ -1466,7 +1466,8 @@ fn build_agent_runtime(
|
|||||||
.clone()
|
.clone()
|
||||||
.unwrap_or_else(|| DEFAULT_AGENT_MODEL.to_string());
|
.unwrap_or_else(|| DEFAULT_AGENT_MODEL.to_string());
|
||||||
let allowed_tools = job.allowed_tools.clone();
|
let allowed_tools = job.allowed_tools.clone();
|
||||||
let api_client = AnthropicRuntimeClient::new(model, allowed_tools.clone())?;
|
let api_client =
|
||||||
|
AnthropicRuntimeClient::new(model, allowed_tools.clone(), job.manifest.agent_id.clone())?;
|
||||||
let tool_executor = SubagentToolExecutor::new(allowed_tools);
|
let tool_executor = SubagentToolExecutor::new(allowed_tools);
|
||||||
Ok(ConversationRuntime::new(
|
Ok(ConversationRuntime::new(
|
||||||
Session::new(),
|
Session::new(),
|
||||||
@@ -1643,10 +1644,15 @@ struct AnthropicRuntimeClient {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl AnthropicRuntimeClient {
|
impl AnthropicRuntimeClient {
|
||||||
fn new(model: String, allowed_tools: BTreeSet<String>) -> Result<Self, String> {
|
fn new(
|
||||||
|
model: String,
|
||||||
|
allowed_tools: BTreeSet<String>,
|
||||||
|
session_id: impl Into<String>,
|
||||||
|
) -> Result<Self, String> {
|
||||||
let client = AnthropicClient::from_env()
|
let client = AnthropicClient::from_env()
|
||||||
.map_err(|error| error.to_string())?
|
.map_err(|error| error.to_string())?
|
||||||
.with_base_url(read_base_url());
|
.with_base_url(read_base_url())
|
||||||
|
.with_prompt_cache(PromptCache::new(session_id));
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
runtime: tokio::runtime::Runtime::new().map_err(|error| error.to_string())?,
|
runtime: tokio::runtime::Runtime::new().map_err(|error| error.to_string())?,
|
||||||
client,
|
client,
|
||||||
@@ -1657,6 +1663,7 @@ impl AnthropicRuntimeClient {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl ApiClient for AnthropicRuntimeClient {
|
impl ApiClient for AnthropicRuntimeClient {
|
||||||
|
#[allow(clippy::too_many_lines)]
|
||||||
fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
|
fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
|
||||||
let tools = tool_specs_for_allowed_tools(Some(&self.allowed_tools))
|
let tools = tool_specs_for_allowed_tools(Some(&self.allowed_tools))
|
||||||
.into_iter()
|
.into_iter()
|
||||||
@@ -1726,8 +1733,8 @@ impl ApiClient for AnthropicRuntimeClient {
|
|||||||
events.push(AssistantEvent::Usage(TokenUsage {
|
events.push(AssistantEvent::Usage(TokenUsage {
|
||||||
input_tokens: delta.usage.input_tokens,
|
input_tokens: delta.usage.input_tokens,
|
||||||
output_tokens: delta.usage.output_tokens,
|
output_tokens: delta.usage.output_tokens,
|
||||||
cache_creation_input_tokens: 0,
|
cache_creation_input_tokens: delta.usage.cache_creation_input_tokens,
|
||||||
cache_read_input_tokens: 0,
|
cache_read_input_tokens: delta.usage.cache_read_input_tokens,
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
ApiStreamEvent::MessageStop(_) => {
|
ApiStreamEvent::MessageStop(_) => {
|
||||||
@@ -1737,6 +1744,8 @@ impl ApiClient for AnthropicRuntimeClient {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
push_prompt_cache_record(&self.client, &mut events);
|
||||||
|
|
||||||
if !saw_stop
|
if !saw_stop
|
||||||
&& events.iter().any(|event| {
|
&& events.iter().any(|event| {
|
||||||
matches!(event, AssistantEvent::TextDelta(text) if !text.is_empty())
|
matches!(event, AssistantEvent::TextDelta(text) if !text.is_empty())
|
||||||
@@ -1761,7 +1770,9 @@ impl ApiClient for AnthropicRuntimeClient {
|
|||||||
})
|
})
|
||||||
.await
|
.await
|
||||||
.map_err(|error| RuntimeError::new(error.to_string()))?;
|
.map_err(|error| RuntimeError::new(error.to_string()))?;
|
||||||
Ok(response_to_events(response))
|
let mut events = response_to_events(response);
|
||||||
|
push_prompt_cache_record(&self.client, &mut events);
|
||||||
|
Ok(events)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1884,6 +1895,26 @@ fn response_to_events(response: MessageResponse) -> Vec<AssistantEvent> {
|
|||||||
events
|
events
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn push_prompt_cache_record(client: &AnthropicClient, events: &mut Vec<AssistantEvent>) {
|
||||||
|
if let Some(event) = client
|
||||||
|
.take_last_prompt_cache_record()
|
||||||
|
.and_then(prompt_cache_record_to_runtime_event)
|
||||||
|
{
|
||||||
|
events.push(AssistantEvent::PromptCache(event));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn prompt_cache_record_to_runtime_event(record: PromptCacheRecord) -> Option<PromptCacheEvent> {
|
||||||
|
let cache_break = record.cache_break?;
|
||||||
|
Some(PromptCacheEvent {
|
||||||
|
unexpected: cache_break.unexpected,
|
||||||
|
reason: cache_break.reason,
|
||||||
|
previous_cache_read_input_tokens: cache_break.previous_cache_read_input_tokens,
|
||||||
|
current_cache_read_input_tokens: cache_break.current_cache_read_input_tokens,
|
||||||
|
token_drop: cache_break.token_drop,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
fn final_assistant_text(summary: &runtime::TurnSummary) -> String {
|
fn final_assistant_text(summary: &runtime::TurnSummary) -> String {
|
||||||
summary
|
summary
|
||||||
.assistant_messages
|
.assistant_messages
|
||||||
|
|||||||
Reference in New Issue
Block a user