// - api: tool_use parsing, message_delta, request_id tracking, retry logic
// - tools: extended tool suite (WebSearch, WebFetch, Agent, etc.)
// - cli: live streamed conversations, session restore, compact commands
// - runtime: config loading, system prompt builder, token usage, compaction
//
// 584 lines · 18 KiB · Rust
use std::collections::BTreeMap;
|
|
use std::fmt::{Display, Formatter};
|
|
|
|
use crate::compact::{
|
|
compact_session, estimate_session_tokens, CompactionConfig, CompactionResult,
|
|
};
|
|
use crate::permissions::{PermissionOutcome, PermissionPolicy, PermissionPrompter};
|
|
use crate::session::{ContentBlock, ConversationMessage, Session};
|
|
use crate::usage::{TokenUsage, UsageTracker};
|
|
|
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
pub struct ApiRequest {
|
|
pub system_prompt: Vec<String>,
|
|
pub messages: Vec<ConversationMessage>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
pub enum AssistantEvent {
|
|
TextDelta(String),
|
|
ToolUse {
|
|
id: String,
|
|
name: String,
|
|
input: String,
|
|
},
|
|
Usage(TokenUsage),
|
|
MessageStop,
|
|
}
|
|
|
|
pub trait ApiClient {
|
|
fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError>;
|
|
}
|
|
|
|
pub trait ToolExecutor {
|
|
fn execute(&mut self, tool_name: &str, input: &str) -> Result<String, ToolError>;
|
|
}
|
|
|
|
/// Error produced by a tool invocation; wraps a human-readable message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolError {
    message: String,
}

impl ToolError {
    /// Builds an error from anything convertible into a `String`.
    #[must_use]
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }
}

impl Display for ToolError {
    /// Writes the stored message verbatim.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

impl std::error::Error for ToolError {}
|
|
|
|
/// Error raised by the conversation runtime; wraps a human-readable message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RuntimeError {
    message: String,
}

impl RuntimeError {
    /// Builds an error from anything convertible into a `String`.
    #[must_use]
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }
}

impl Display for RuntimeError {
    /// Writes the stored message verbatim.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

impl std::error::Error for RuntimeError {}
|
|
|
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
pub struct TurnSummary {
|
|
pub assistant_messages: Vec<ConversationMessage>,
|
|
pub tool_results: Vec<ConversationMessage>,
|
|
pub iterations: usize,
|
|
pub usage: TokenUsage,
|
|
}
|
|
|
|
pub struct ConversationRuntime<C, T> {
|
|
session: Session,
|
|
api_client: C,
|
|
tool_executor: T,
|
|
permission_policy: PermissionPolicy,
|
|
system_prompt: Vec<String>,
|
|
max_iterations: usize,
|
|
usage_tracker: UsageTracker,
|
|
}
|
|
|
|
impl<C, T> ConversationRuntime<C, T>
|
|
where
|
|
C: ApiClient,
|
|
T: ToolExecutor,
|
|
{
|
|
#[must_use]
|
|
pub fn new(
|
|
session: Session,
|
|
api_client: C,
|
|
tool_executor: T,
|
|
permission_policy: PermissionPolicy,
|
|
system_prompt: Vec<String>,
|
|
) -> Self {
|
|
let usage_tracker = UsageTracker::from_session(&session);
|
|
Self {
|
|
session,
|
|
api_client,
|
|
tool_executor,
|
|
permission_policy,
|
|
system_prompt,
|
|
max_iterations: 16,
|
|
usage_tracker,
|
|
}
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
|
|
self.max_iterations = max_iterations;
|
|
self
|
|
}
|
|
|
|
pub fn run_turn(
|
|
&mut self,
|
|
user_input: impl Into<String>,
|
|
mut prompter: Option<&mut dyn PermissionPrompter>,
|
|
) -> Result<TurnSummary, RuntimeError> {
|
|
self.session
|
|
.messages
|
|
.push(ConversationMessage::user_text(user_input.into()));
|
|
|
|
let mut assistant_messages = Vec::new();
|
|
let mut tool_results = Vec::new();
|
|
let mut iterations = 0;
|
|
|
|
loop {
|
|
iterations += 1;
|
|
if iterations > self.max_iterations {
|
|
return Err(RuntimeError::new(
|
|
"conversation loop exceeded the maximum number of iterations",
|
|
));
|
|
}
|
|
|
|
let request = ApiRequest {
|
|
system_prompt: self.system_prompt.clone(),
|
|
messages: self.session.messages.clone(),
|
|
};
|
|
let events = self.api_client.stream(request)?;
|
|
let (assistant_message, usage) = build_assistant_message(events)?;
|
|
if let Some(usage) = usage {
|
|
self.usage_tracker.record(usage);
|
|
}
|
|
let pending_tool_uses = assistant_message
|
|
.blocks
|
|
.iter()
|
|
.filter_map(|block| match block {
|
|
ContentBlock::ToolUse { id, name, input } => {
|
|
Some((id.clone(), name.clone(), input.clone()))
|
|
}
|
|
_ => None,
|
|
})
|
|
.collect::<Vec<_>>();
|
|
|
|
self.session.messages.push(assistant_message.clone());
|
|
assistant_messages.push(assistant_message);
|
|
|
|
if pending_tool_uses.is_empty() {
|
|
break;
|
|
}
|
|
|
|
for (tool_use_id, tool_name, input) in pending_tool_uses {
|
|
let permission_outcome = if let Some(prompt) = prompter.as_mut() {
|
|
self.permission_policy
|
|
.authorize(&tool_name, &input, Some(*prompt))
|
|
} else {
|
|
self.permission_policy.authorize(&tool_name, &input, None)
|
|
};
|
|
|
|
let result_message = match permission_outcome {
|
|
PermissionOutcome::Allow => {
|
|
match self.tool_executor.execute(&tool_name, &input) {
|
|
Ok(output) => ConversationMessage::tool_result(
|
|
tool_use_id,
|
|
tool_name,
|
|
output,
|
|
false,
|
|
),
|
|
Err(error) => ConversationMessage::tool_result(
|
|
tool_use_id,
|
|
tool_name,
|
|
error.to_string(),
|
|
true,
|
|
),
|
|
}
|
|
}
|
|
PermissionOutcome::Deny { reason } => {
|
|
ConversationMessage::tool_result(tool_use_id, tool_name, reason, true)
|
|
}
|
|
};
|
|
self.session.messages.push(result_message.clone());
|
|
tool_results.push(result_message);
|
|
}
|
|
}
|
|
|
|
Ok(TurnSummary {
|
|
assistant_messages,
|
|
tool_results,
|
|
iterations,
|
|
usage: self.usage_tracker.cumulative_usage(),
|
|
})
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn compact(&self, config: CompactionConfig) -> CompactionResult {
|
|
compact_session(&self.session, config)
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn estimated_tokens(&self) -> usize {
|
|
estimate_session_tokens(&self.session)
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn usage(&self) -> &UsageTracker {
|
|
&self.usage_tracker
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn session(&self) -> &Session {
|
|
&self.session
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn into_session(self) -> Session {
|
|
self.session
|
|
}
|
|
}
|
|
|
|
fn build_assistant_message(
|
|
events: Vec<AssistantEvent>,
|
|
) -> Result<(ConversationMessage, Option<TokenUsage>), RuntimeError> {
|
|
let mut text = String::new();
|
|
let mut blocks = Vec::new();
|
|
let mut finished = false;
|
|
let mut usage = None;
|
|
|
|
for event in events {
|
|
match event {
|
|
AssistantEvent::TextDelta(delta) => text.push_str(&delta),
|
|
AssistantEvent::ToolUse { id, name, input } => {
|
|
flush_text_block(&mut text, &mut blocks);
|
|
blocks.push(ContentBlock::ToolUse { id, name, input });
|
|
}
|
|
AssistantEvent::Usage(value) => usage = Some(value),
|
|
AssistantEvent::MessageStop => {
|
|
finished = true;
|
|
}
|
|
}
|
|
}
|
|
|
|
flush_text_block(&mut text, &mut blocks);
|
|
|
|
if !finished {
|
|
return Err(RuntimeError::new(
|
|
"assistant stream ended without a message stop event",
|
|
));
|
|
}
|
|
if blocks.is_empty() {
|
|
return Err(RuntimeError::new("assistant stream produced no content"));
|
|
}
|
|
|
|
Ok((
|
|
ConversationMessage::assistant_with_usage(blocks, usage),
|
|
usage,
|
|
))
|
|
}
|
|
|
|
fn flush_text_block(text: &mut String, blocks: &mut Vec<ContentBlock>) {
|
|
if !text.is_empty() {
|
|
blocks.push(ContentBlock::Text {
|
|
text: std::mem::take(text),
|
|
});
|
|
}
|
|
}
|
|
|
|
type ToolHandler = Box<dyn FnMut(&str) -> Result<String, ToolError>>;
|
|
|
|
#[derive(Default)]
|
|
pub struct StaticToolExecutor {
|
|
handlers: BTreeMap<String, ToolHandler>,
|
|
}
|
|
|
|
impl StaticToolExecutor {
|
|
#[must_use]
|
|
pub fn new() -> Self {
|
|
Self::default()
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn register(
|
|
mut self,
|
|
tool_name: impl Into<String>,
|
|
handler: impl FnMut(&str) -> Result<String, ToolError> + 'static,
|
|
) -> Self {
|
|
self.handlers.insert(tool_name.into(), Box::new(handler));
|
|
self
|
|
}
|
|
}
|
|
|
|
impl ToolExecutor for StaticToolExecutor {
|
|
fn execute(&mut self, tool_name: &str, input: &str) -> Result<String, ToolError> {
|
|
self.handlers
|
|
.get_mut(tool_name)
|
|
.ok_or_else(|| ToolError::new(format!("unknown tool: {tool_name}")))?(input)
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::{
        ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, RuntimeError,
        StaticToolExecutor,
    };
    use crate::compact::CompactionConfig;
    use crate::permissions::{
        PermissionMode, PermissionPolicy, PermissionPromptDecision, PermissionPrompter,
        PermissionRequest,
    };
    use crate::prompt::{ProjectContext, SystemPromptBuilder};
    use crate::session::{ContentBlock, MessageRole, Session};
    use crate::usage::TokenUsage;
    use std::path::PathBuf;

    /// API client that replays a fixed two-step script: first a tool request,
    /// then a final answer. Any further call is an error.
    struct ScriptedApiClient {
        call_count: usize,
    }

    impl ApiClient for ScriptedApiClient {
        fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
            self.call_count += 1;
            match self.call_count {
                1 => {
                    // The first request must already carry the user's message.
                    let has_user_message = request
                        .messages
                        .iter()
                        .any(|message| message.role == MessageRole::User);
                    assert!(has_user_message);

                    Ok(vec![
                        AssistantEvent::TextDelta("Let me calculate that.".to_string()),
                        AssistantEvent::ToolUse {
                            id: "tool-1".to_string(),
                            name: "add".to_string(),
                            input: "2,2".to_string(),
                        },
                        AssistantEvent::Usage(TokenUsage {
                            input_tokens: 20,
                            output_tokens: 6,
                            cache_creation_input_tokens: 1,
                            cache_read_input_tokens: 2,
                        }),
                        AssistantEvent::MessageStop,
                    ])
                }
                2 => {
                    // The second request must end with the tool result.
                    let last_message = request
                        .messages
                        .last()
                        .expect("tool result should be present");
                    assert_eq!(last_message.role, MessageRole::Tool);

                    Ok(vec![
                        AssistantEvent::TextDelta("The answer is 4.".to_string()),
                        AssistantEvent::Usage(TokenUsage {
                            input_tokens: 24,
                            output_tokens: 4,
                            cache_creation_input_tokens: 1,
                            cache_read_input_tokens: 3,
                        }),
                        AssistantEvent::MessageStop,
                    ])
                }
                _ => Err(RuntimeError::new("unexpected extra API call")),
            }
        }
    }

    /// Prompter that approves the request after checking it is for `add`.
    struct PromptAllowOnce;

    impl PermissionPrompter for PromptAllowOnce {
        fn decide(&mut self, request: &PermissionRequest) -> PermissionPromptDecision {
            assert_eq!(request.tool_name, "add");
            PermissionPromptDecision::Allow
        }
    }

    /// API client that always answers with the text "done" and no tool use.
    struct DoneApi;

    impl ApiClient for DoneApi {
        fn stream(&mut self, _request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
            Ok(vec![
                AssistantEvent::TextDelta("done".to_string()),
                AssistantEvent::MessageStop,
            ])
        }
    }

    #[test]
    fn runs_user_to_tool_to_result_loop_end_to_end_and_tracks_usage() {
        let tool_executor = StaticToolExecutor::new().register("add", |input| {
            let total = input
                .split(',')
                .map(|part| part.parse::<i32>().expect("input must be valid integer"))
                .sum::<i32>();
            Ok(total.to_string())
        });
        let project_context = ProjectContext {
            cwd: PathBuf::from("/tmp/project"),
            current_date: "2026-03-31".to_string(),
            git_status: None,
            instruction_files: Vec::new(),
        };
        let system_prompt = SystemPromptBuilder::new()
            .with_project_context(project_context)
            .with_os("linux", "6.8")
            .build();
        let mut runtime = ConversationRuntime::new(
            Session::new(),
            ScriptedApiClient { call_count: 0 },
            tool_executor,
            PermissionPolicy::new(PermissionMode::Prompt),
            system_prompt,
        );

        let mut prompter = PromptAllowOnce;
        let summary = runtime
            .run_turn("what is 2 + 2?", Some(&mut prompter))
            .expect("conversation loop should succeed");

        assert_eq!(summary.iterations, 2);
        assert_eq!(summary.assistant_messages.len(), 2);
        assert_eq!(summary.tool_results.len(), 1);
        // user, assistant (text + tool use), tool result, assistant (final text)
        let session_len = runtime.session().messages.len();
        assert_eq!(session_len, 4);
        assert_eq!(summary.usage.output_tokens, 10);
        assert!(matches!(
            runtime.session().messages[1].blocks[1],
            ContentBlock::ToolUse { .. }
        ));
        assert!(matches!(
            runtime.session().messages[2].blocks[0],
            ContentBlock::ToolResult {
                is_error: false,
                ..
            }
        ));
    }

    #[test]
    fn records_denied_tool_results_when_prompt_rejects() {
        /// Prompter that refuses every request with the reason "not now".
        struct RejectPrompter;

        impl PermissionPrompter for RejectPrompter {
            fn decide(&mut self, _request: &PermissionRequest) -> PermissionPromptDecision {
                PermissionPromptDecision::Deny {
                    reason: "not now".to_string(),
                }
            }
        }

        /// Requests the `blocked` tool until a tool result shows up in the
        /// history, then replies with plain text.
        struct SingleCallApiClient;

        impl ApiClient for SingleCallApiClient {
            fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
                let tool_already_answered = request
                    .messages
                    .iter()
                    .any(|message| message.role == MessageRole::Tool);

                if tool_already_answered {
                    return Ok(vec![
                        AssistantEvent::TextDelta("I could not use the tool.".to_string()),
                        AssistantEvent::MessageStop,
                    ]);
                }

                Ok(vec![
                    AssistantEvent::ToolUse {
                        id: "tool-1".to_string(),
                        name: "blocked".to_string(),
                        input: "secret".to_string(),
                    },
                    AssistantEvent::MessageStop,
                ])
            }
        }

        let mut runtime = ConversationRuntime::new(
            Session::new(),
            SingleCallApiClient,
            StaticToolExecutor::new(),
            PermissionPolicy::new(PermissionMode::Prompt),
            vec!["system".to_string()],
        );

        let mut prompter = RejectPrompter;
        let summary = runtime
            .run_turn("use the tool", Some(&mut prompter))
            .expect("conversation should continue after denied tool");

        assert_eq!(summary.tool_results.len(), 1);
        assert!(matches!(
            &summary.tool_results[0].blocks[0],
            ContentBlock::ToolResult { is_error: true, output, .. } if output == "not now"
        ));
    }

    #[test]
    fn reconstructs_usage_tracker_from_restored_session() {
        // A session that already holds one assistant message with usage attached.
        let earlier_usage = TokenUsage {
            input_tokens: 11,
            output_tokens: 7,
            cache_creation_input_tokens: 2,
            cache_read_input_tokens: 1,
        };
        let earlier_message = crate::session::ConversationMessage::assistant_with_usage(
            vec![ContentBlock::Text {
                text: "earlier".to_string(),
            }],
            Some(earlier_usage),
        );
        let mut session = Session::new();
        session.messages.push(earlier_message);

        let runtime = ConversationRuntime::new(
            session,
            DoneApi,
            StaticToolExecutor::new(),
            PermissionPolicy::new(PermissionMode::Allow),
            vec!["system".to_string()],
        );

        assert_eq!(runtime.usage().turns(), 1);
        assert_eq!(runtime.usage().cumulative_usage().total_tokens(), 21);
    }

    #[test]
    fn compacts_session_after_turns() {
        let mut runtime = ConversationRuntime::new(
            Session::new(),
            DoneApi,
            StaticToolExecutor::new(),
            PermissionPolicy::new(PermissionMode::Allow),
            vec!["system".to_string()],
        );
        runtime.run_turn("a", None).expect("turn a");
        runtime.run_turn("b", None).expect("turn b");
        runtime.run_turn("c", None).expect("turn c");

        let config = CompactionConfig {
            preserve_recent_messages: 2,
            max_estimated_tokens: 1,
        };
        let result = runtime.compact(config);

        assert!(result.summary.contains("Conversation summary"));
        assert_eq!(
            result.compacted_session.messages[0].role,
            MessageRole::System
        );
    }
}
|