994 lines
33 KiB
Rust
994 lines
33 KiB
Rust
use std::collections::BTreeMap;
|
|
use std::fmt::{Display, Formatter};
|
|
|
|
use serde_json::{Map, Value};
|
|
use telemetry::SessionTracer;
|
|
|
|
use crate::compact::{
|
|
compact_session, estimate_session_tokens, CompactionConfig, CompactionResult,
|
|
};
|
|
use crate::config::RuntimeFeatureConfig;
|
|
use crate::hooks::{HookRunResult, HookRunner};
|
|
use crate::permissions::{PermissionOutcome, PermissionPolicy, PermissionPrompter};
|
|
use crate::session::{ContentBlock, ConversationMessage, Session};
|
|
use crate::usage::{TokenUsage, UsageTracker};
|
|
|
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
pub struct ApiRequest {
|
|
pub system_prompt: Vec<String>,
|
|
pub messages: Vec<ConversationMessage>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
pub enum AssistantEvent {
|
|
TextDelta(String),
|
|
ToolUse {
|
|
id: String,
|
|
name: String,
|
|
input: String,
|
|
},
|
|
Usage(TokenUsage),
|
|
MessageStop,
|
|
}
|
|
|
|
pub trait ApiClient {
|
|
fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError>;
|
|
}
|
|
|
|
pub trait ToolExecutor {
|
|
fn execute(&mut self, tool_name: &str, input: &str) -> Result<String, ToolError>;
|
|
}
|
|
|
|
/// Error reported by a [`ToolExecutor`]; carries only a human-readable message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolError {
    message: String,
}

impl ToolError {
    /// Creates an error from anything convertible into a `String`.
    #[must_use]
    pub fn new(message: impl Into<String>) -> Self {
        Self {
            message: message.into(),
        }
    }
}

impl Display for ToolError {
    /// Writes the stored message verbatim.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

impl std::error::Error for ToolError {}
|
|
|
|
/// Error returned by the conversation runtime and by [`ApiClient`]
/// implementations; carries only a human-readable message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RuntimeError {
    message: String,
}

impl RuntimeError {
    /// Creates an error from anything convertible into a `String`.
    #[must_use]
    pub fn new(message: impl Into<String>) -> Self {
        Self {
            message: message.into(),
        }
    }
}

impl Display for RuntimeError {
    /// Writes the stored message verbatim.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

impl std::error::Error for RuntimeError {}
|
|
|
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
pub struct TurnSummary {
|
|
pub assistant_messages: Vec<ConversationMessage>,
|
|
pub tool_results: Vec<ConversationMessage>,
|
|
pub iterations: usize,
|
|
pub usage: TokenUsage,
|
|
}
|
|
|
|
pub struct ConversationRuntime<C, T> {
|
|
session: Session,
|
|
api_client: C,
|
|
tool_executor: T,
|
|
permission_policy: PermissionPolicy,
|
|
system_prompt: Vec<String>,
|
|
max_iterations: usize,
|
|
usage_tracker: UsageTracker,
|
|
hook_runner: HookRunner,
|
|
session_tracer: Option<SessionTracer>,
|
|
}
|
|
|
|
impl<C, T> ConversationRuntime<C, T>
|
|
where
|
|
C: ApiClient,
|
|
T: ToolExecutor,
|
|
{
|
|
#[must_use]
|
|
pub fn new(
|
|
session: Session,
|
|
api_client: C,
|
|
tool_executor: T,
|
|
permission_policy: PermissionPolicy,
|
|
system_prompt: Vec<String>,
|
|
) -> Self {
|
|
Self::new_with_features(
|
|
session,
|
|
api_client,
|
|
tool_executor,
|
|
permission_policy,
|
|
system_prompt,
|
|
&RuntimeFeatureConfig::default(),
|
|
)
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn new_with_features(
|
|
session: Session,
|
|
api_client: C,
|
|
tool_executor: T,
|
|
permission_policy: PermissionPolicy,
|
|
system_prompt: Vec<String>,
|
|
feature_config: &RuntimeFeatureConfig,
|
|
) -> Self {
|
|
let usage_tracker = UsageTracker::from_session(&session);
|
|
Self {
|
|
session,
|
|
api_client,
|
|
tool_executor,
|
|
permission_policy,
|
|
system_prompt,
|
|
max_iterations: usize::MAX,
|
|
usage_tracker,
|
|
hook_runner: HookRunner::from_feature_config(feature_config),
|
|
session_tracer: None,
|
|
}
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
|
|
self.max_iterations = max_iterations;
|
|
self
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn with_session_tracer(mut self, session_tracer: SessionTracer) -> Self {
|
|
self.session_tracer = Some(session_tracer);
|
|
self
|
|
}
|
|
|
|
#[allow(clippy::too_many_lines)]
|
|
pub fn run_turn(
|
|
&mut self,
|
|
user_input: impl Into<String>,
|
|
mut prompter: Option<&mut dyn PermissionPrompter>,
|
|
) -> Result<TurnSummary, RuntimeError> {
|
|
let user_input = user_input.into();
|
|
self.record_turn_started(&user_input);
|
|
self.session
|
|
.messages
|
|
.push(ConversationMessage::user_text(user_input));
|
|
|
|
let mut assistant_messages = Vec::new();
|
|
let mut tool_results = Vec::new();
|
|
let mut iterations = 0;
|
|
|
|
loop {
|
|
iterations += 1;
|
|
if iterations > self.max_iterations {
|
|
let error = RuntimeError::new(
|
|
"conversation loop exceeded the maximum number of iterations",
|
|
);
|
|
self.record_turn_failed(iterations, &error);
|
|
return Err(error);
|
|
}
|
|
|
|
let request = ApiRequest {
|
|
system_prompt: self.system_prompt.clone(),
|
|
messages: self.session.messages.clone(),
|
|
};
|
|
let events = match self.api_client.stream(request) {
|
|
Ok(events) => events,
|
|
Err(error) => {
|
|
self.record_turn_failed(iterations, &error);
|
|
return Err(error);
|
|
}
|
|
};
|
|
let (assistant_message, usage) = match build_assistant_message(events) {
|
|
Ok(result) => result,
|
|
Err(error) => {
|
|
self.record_turn_failed(iterations, &error);
|
|
return Err(error);
|
|
}
|
|
};
|
|
if let Some(usage) = usage {
|
|
self.usage_tracker.record(usage);
|
|
}
|
|
let pending_tool_uses = assistant_message
|
|
.blocks
|
|
.iter()
|
|
.filter_map(|block| match block {
|
|
ContentBlock::ToolUse { id, name, input } => {
|
|
Some((id.clone(), name.clone(), input.clone()))
|
|
}
|
|
_ => None,
|
|
})
|
|
.collect::<Vec<_>>();
|
|
self.record_assistant_iteration(
|
|
iterations,
|
|
&assistant_message,
|
|
pending_tool_uses.len(),
|
|
);
|
|
|
|
self.session.messages.push(assistant_message.clone());
|
|
assistant_messages.push(assistant_message);
|
|
|
|
if pending_tool_uses.is_empty() {
|
|
break;
|
|
}
|
|
|
|
for (tool_use_id, tool_name, input) in pending_tool_uses {
|
|
self.record_tool_started(iterations, &tool_name);
|
|
let permission_outcome = if let Some(prompt) = prompter.as_mut() {
|
|
self.permission_policy
|
|
.authorize(&tool_name, &input, Some(*prompt))
|
|
} else {
|
|
self.permission_policy.authorize(&tool_name, &input, None)
|
|
};
|
|
|
|
let result_message = match permission_outcome {
|
|
PermissionOutcome::Allow => {
|
|
let pre_hook_result = self.hook_runner.run_pre_tool_use(&tool_name, &input);
|
|
if pre_hook_result.is_denied() {
|
|
let deny_message = format!("PreToolUse hook denied tool `{tool_name}`");
|
|
ConversationMessage::tool_result(
|
|
tool_use_id,
|
|
tool_name,
|
|
format_hook_message(&pre_hook_result, &deny_message),
|
|
true,
|
|
)
|
|
} else {
|
|
let (mut output, mut is_error) =
|
|
match self.tool_executor.execute(&tool_name, &input) {
|
|
Ok(output) => (output, false),
|
|
Err(error) => (error.to_string(), true),
|
|
};
|
|
output = merge_hook_feedback(pre_hook_result.messages(), output, false);
|
|
|
|
let post_hook_result = self
|
|
.hook_runner
|
|
.run_post_tool_use(&tool_name, &input, &output, is_error);
|
|
if post_hook_result.is_denied() {
|
|
is_error = true;
|
|
}
|
|
output = merge_hook_feedback(
|
|
post_hook_result.messages(),
|
|
output,
|
|
post_hook_result.is_denied(),
|
|
);
|
|
|
|
ConversationMessage::tool_result(
|
|
tool_use_id,
|
|
tool_name,
|
|
output,
|
|
is_error,
|
|
)
|
|
}
|
|
}
|
|
PermissionOutcome::Deny { reason } => {
|
|
ConversationMessage::tool_result(tool_use_id, tool_name, reason, true)
|
|
}
|
|
};
|
|
self.record_tool_finished(iterations, &result_message);
|
|
self.session.messages.push(result_message.clone());
|
|
tool_results.push(result_message);
|
|
}
|
|
}
|
|
|
|
let summary = TurnSummary {
|
|
assistant_messages,
|
|
tool_results,
|
|
iterations,
|
|
usage: self.usage_tracker.cumulative_usage(),
|
|
};
|
|
self.record_turn_completed(&summary);
|
|
Ok(summary)
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn compact(&self, config: CompactionConfig) -> CompactionResult {
|
|
compact_session(&self.session, config)
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn estimated_tokens(&self) -> usize {
|
|
estimate_session_tokens(&self.session)
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn usage(&self) -> &UsageTracker {
|
|
&self.usage_tracker
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn session(&self) -> &Session {
|
|
&self.session
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn into_session(self) -> Session {
|
|
self.session
|
|
}
|
|
|
|
fn record_turn_started(&self, user_input: &str) {
|
|
if let Some(tracer) = &self.session_tracer {
|
|
let mut attributes = Map::new();
|
|
attributes.insert(
|
|
"message_count_before".to_string(),
|
|
Value::from(u64::try_from(self.session.messages.len()).unwrap_or(u64::MAX)),
|
|
);
|
|
attributes.insert(
|
|
"input_chars".to_string(),
|
|
Value::from(u64::try_from(user_input.chars().count()).unwrap_or(u64::MAX)),
|
|
);
|
|
tracer.record("turn_started", attributes);
|
|
}
|
|
}
|
|
|
|
fn record_assistant_iteration(
|
|
&self,
|
|
iteration: usize,
|
|
assistant_message: &ConversationMessage,
|
|
pending_tool_count: usize,
|
|
) {
|
|
if let Some(tracer) = &self.session_tracer {
|
|
let mut attributes = Map::new();
|
|
attributes.insert(
|
|
"iteration".to_string(),
|
|
Value::from(u64::try_from(iteration).unwrap_or(u64::MAX)),
|
|
);
|
|
attributes.insert(
|
|
"block_count".to_string(),
|
|
Value::from(u64::try_from(assistant_message.blocks.len()).unwrap_or(u64::MAX)),
|
|
);
|
|
attributes.insert(
|
|
"pending_tool_count".to_string(),
|
|
Value::from(u64::try_from(pending_tool_count).unwrap_or(u64::MAX)),
|
|
);
|
|
tracer.record("assistant_iteration_completed", attributes);
|
|
}
|
|
}
|
|
|
|
fn record_tool_started(&self, iteration: usize, tool_name: &str) {
|
|
if let Some(tracer) = &self.session_tracer {
|
|
let mut attributes = Map::new();
|
|
attributes.insert(
|
|
"iteration".to_string(),
|
|
Value::from(u64::try_from(iteration).unwrap_or(u64::MAX)),
|
|
);
|
|
attributes.insert(
|
|
"tool_name".to_string(),
|
|
Value::String(tool_name.to_string()),
|
|
);
|
|
tracer.record("tool_execution_started", attributes);
|
|
}
|
|
}
|
|
|
|
fn record_tool_finished(&self, iteration: usize, result_message: &ConversationMessage) {
|
|
let Some(tracer) = &self.session_tracer else {
|
|
return;
|
|
};
|
|
let Some(ContentBlock::ToolResult {
|
|
tool_name,
|
|
is_error,
|
|
output,
|
|
..
|
|
}) = result_message.blocks.first()
|
|
else {
|
|
return;
|
|
};
|
|
let mut attributes = Map::new();
|
|
attributes.insert(
|
|
"iteration".to_string(),
|
|
Value::from(u64::try_from(iteration).unwrap_or(u64::MAX)),
|
|
);
|
|
attributes.insert("tool_name".to_string(), Value::String(tool_name.clone()));
|
|
attributes.insert("is_error".to_string(), Value::Bool(*is_error));
|
|
attributes.insert(
|
|
"output_chars".to_string(),
|
|
Value::from(u64::try_from(output.chars().count()).unwrap_or(u64::MAX)),
|
|
);
|
|
tracer.record("tool_execution_finished", attributes);
|
|
}
|
|
|
|
fn record_turn_completed(&self, summary: &TurnSummary) {
|
|
if let Some(tracer) = &self.session_tracer {
|
|
let mut attributes = Map::new();
|
|
attributes.insert(
|
|
"assistant_message_count".to_string(),
|
|
Value::from(u64::try_from(summary.assistant_messages.len()).unwrap_or(u64::MAX)),
|
|
);
|
|
attributes.insert(
|
|
"tool_result_count".to_string(),
|
|
Value::from(u64::try_from(summary.tool_results.len()).unwrap_or(u64::MAX)),
|
|
);
|
|
attributes.insert(
|
|
"iterations".to_string(),
|
|
Value::from(u64::try_from(summary.iterations).unwrap_or(u64::MAX)),
|
|
);
|
|
attributes.insert(
|
|
"total_input_tokens".to_string(),
|
|
Value::from(summary.usage.input_tokens),
|
|
);
|
|
attributes.insert(
|
|
"total_output_tokens".to_string(),
|
|
Value::from(summary.usage.output_tokens),
|
|
);
|
|
tracer.record("turn_completed", attributes);
|
|
}
|
|
}
|
|
|
|
fn record_turn_failed(&self, iteration: usize, error: &RuntimeError) {
|
|
if let Some(tracer) = &self.session_tracer {
|
|
let mut attributes = Map::new();
|
|
attributes.insert(
|
|
"iteration".to_string(),
|
|
Value::from(u64::try_from(iteration).unwrap_or(u64::MAX)),
|
|
);
|
|
attributes.insert("error".to_string(), Value::String(error.to_string()));
|
|
tracer.record("turn_failed", attributes);
|
|
}
|
|
}
|
|
}
|
|
|
|
fn build_assistant_message(
|
|
events: Vec<AssistantEvent>,
|
|
) -> Result<(ConversationMessage, Option<TokenUsage>), RuntimeError> {
|
|
let mut text = String::new();
|
|
let mut blocks = Vec::new();
|
|
let mut finished = false;
|
|
let mut usage = None;
|
|
|
|
for event in events {
|
|
match event {
|
|
AssistantEvent::TextDelta(delta) => text.push_str(&delta),
|
|
AssistantEvent::ToolUse { id, name, input } => {
|
|
flush_text_block(&mut text, &mut blocks);
|
|
blocks.push(ContentBlock::ToolUse { id, name, input });
|
|
}
|
|
AssistantEvent::Usage(value) => usage = Some(value),
|
|
AssistantEvent::MessageStop => {
|
|
finished = true;
|
|
}
|
|
}
|
|
}
|
|
|
|
flush_text_block(&mut text, &mut blocks);
|
|
|
|
if !finished {
|
|
return Err(RuntimeError::new(
|
|
"assistant stream ended without a message stop event",
|
|
));
|
|
}
|
|
if blocks.is_empty() {
|
|
return Err(RuntimeError::new("assistant stream produced no content"));
|
|
}
|
|
|
|
Ok((
|
|
ConversationMessage::assistant_with_usage(blocks, usage),
|
|
usage,
|
|
))
|
|
}
|
|
|
|
fn flush_text_block(text: &mut String, blocks: &mut Vec<ContentBlock>) {
|
|
if !text.is_empty() {
|
|
blocks.push(ContentBlock::Text {
|
|
text: std::mem::take(text),
|
|
});
|
|
}
|
|
}
|
|
|
|
fn format_hook_message(result: &HookRunResult, fallback: &str) -> String {
|
|
if result.messages().is_empty() {
|
|
fallback.to_string()
|
|
} else {
|
|
result.messages().join("\n")
|
|
}
|
|
}
|
|
|
|
/// Appends hook feedback to a tool's output.
///
/// With no `messages`, `output` is returned unchanged. Otherwise the result
/// is the output (omitted when it is empty or whitespace-only), a blank line,
/// and a section headed `Hook feedback:` — or `Hook feedback (denied):` when
/// `denied` is set — followed by the messages, one per line.
fn merge_hook_feedback(messages: &[String], output: String, denied: bool) -> String {
    if messages.is_empty() {
        return output;
    }

    let label = if denied {
        "Hook feedback (denied)"
    } else {
        "Hook feedback"
    };
    let feedback = format!("{label}:\n{}", messages.join("\n"));

    if output.trim().is_empty() {
        feedback
    } else {
        format!("{output}\n\n{feedback}")
    }
}
|
|
|
|
/// Boxed callback stored by `StaticToolExecutor`: it receives the raw tool
/// input and returns the tool output or a `ToolError`.
type ToolHandler = Box<dyn FnMut(&str) -> Result<String, ToolError>>;
|
|
|
|
#[derive(Default)]
|
|
pub struct StaticToolExecutor {
|
|
handlers: BTreeMap<String, ToolHandler>,
|
|
}
|
|
|
|
impl StaticToolExecutor {
|
|
#[must_use]
|
|
pub fn new() -> Self {
|
|
Self::default()
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn register(
|
|
mut self,
|
|
tool_name: impl Into<String>,
|
|
handler: impl FnMut(&str) -> Result<String, ToolError> + 'static,
|
|
) -> Self {
|
|
self.handlers.insert(tool_name.into(), Box::new(handler));
|
|
self
|
|
}
|
|
}
|
|
|
|
impl ToolExecutor for StaticToolExecutor {
|
|
fn execute(&mut self, tool_name: &str, input: &str) -> Result<String, ToolError> {
|
|
self.handlers
|
|
.get_mut(tool_name)
|
|
.ok_or_else(|| ToolError::new(format!("unknown tool: {tool_name}")))?(input)
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::{
|
|
ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, RuntimeError,
|
|
StaticToolExecutor,
|
|
};
|
|
use crate::compact::CompactionConfig;
|
|
use crate::config::{RuntimeFeatureConfig, RuntimeHookConfig};
|
|
use crate::permissions::{
|
|
PermissionMode, PermissionPolicy, PermissionPromptDecision, PermissionPrompter,
|
|
PermissionRequest,
|
|
};
|
|
use crate::prompt::{ProjectContext, SystemPromptBuilder};
|
|
use crate::session::{ContentBlock, MessageRole, Session};
|
|
use crate::usage::TokenUsage;
|
|
use std::path::PathBuf;
|
|
use std::sync::Arc;
|
|
use telemetry::{MemoryTelemetrySink, SessionTracer, TelemetryEvent};
|
|
|
|
struct ScriptedApiClient {
|
|
call_count: usize,
|
|
}
|
|
|
|
impl ApiClient for ScriptedApiClient {
|
|
fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
|
|
self.call_count += 1;
|
|
match self.call_count {
|
|
1 => {
|
|
assert!(request
|
|
.messages
|
|
.iter()
|
|
.any(|message| message.role == MessageRole::User));
|
|
Ok(vec![
|
|
AssistantEvent::TextDelta("Let me calculate that.".to_string()),
|
|
AssistantEvent::ToolUse {
|
|
id: "tool-1".to_string(),
|
|
name: "add".to_string(),
|
|
input: "2,2".to_string(),
|
|
},
|
|
AssistantEvent::Usage(TokenUsage {
|
|
input_tokens: 20,
|
|
output_tokens: 6,
|
|
cache_creation_input_tokens: 1,
|
|
cache_read_input_tokens: 2,
|
|
}),
|
|
AssistantEvent::MessageStop,
|
|
])
|
|
}
|
|
2 => {
|
|
let last_message = request
|
|
.messages
|
|
.last()
|
|
.expect("tool result should be present");
|
|
assert_eq!(last_message.role, MessageRole::Tool);
|
|
Ok(vec![
|
|
AssistantEvent::TextDelta("The answer is 4.".to_string()),
|
|
AssistantEvent::Usage(TokenUsage {
|
|
input_tokens: 24,
|
|
output_tokens: 4,
|
|
cache_creation_input_tokens: 1,
|
|
cache_read_input_tokens: 3,
|
|
}),
|
|
AssistantEvent::MessageStop,
|
|
])
|
|
}
|
|
_ => Err(RuntimeError::new("unexpected extra API call")),
|
|
}
|
|
}
|
|
}
|
|
|
|
struct PromptAllowOnce;
|
|
|
|
impl PermissionPrompter for PromptAllowOnce {
|
|
fn decide(&mut self, request: &PermissionRequest) -> PermissionPromptDecision {
|
|
assert_eq!(request.tool_name, "add");
|
|
PermissionPromptDecision::Allow
|
|
}
|
|
}
|
|
|
|
/// Full loop: user message → assistant tool call → tool result → final
/// answer, checking message layout and accumulated output tokens.
#[test]
fn runs_user_to_tool_to_result_loop_end_to_end_and_tracks_usage() {
    let api_client = ScriptedApiClient { call_count: 0 };
    let tool_executor = StaticToolExecutor::new().register("add", |input| {
        let total = input
            .split(',')
            .map(|part| part.parse::<i32>().expect("input must be valid integer"))
            .sum::<i32>();
        Ok(total.to_string())
    });
    let permission_policy = PermissionPolicy::new(PermissionMode::WorkspaceWrite);
    let system_prompt = SystemPromptBuilder::new()
        .with_project_context(ProjectContext {
            cwd: PathBuf::from("/tmp/project"),
            current_date: "2026-03-31".to_string(),
            git_status: None,
            git_diff: None,
            instruction_files: Vec::new(),
        })
        .with_os("linux", "6.8")
        .build();
    let mut runtime = ConversationRuntime::new(
        Session::new(),
        api_client,
        tool_executor,
        permission_policy,
        system_prompt,
    );

    let summary = runtime
        .run_turn("what is 2 + 2?", Some(&mut PromptAllowOnce))
        .expect("conversation loop should succeed");

    assert_eq!(summary.iterations, 2);
    assert_eq!(summary.assistant_messages.len(), 2);
    assert_eq!(summary.tool_results.len(), 1);
    // user, assistant (text + tool use), tool result, assistant (answer)
    assert_eq!(runtime.session().messages.len(), 4);
    // 6 + 4 output tokens across the two scripted responses
    assert_eq!(summary.usage.output_tokens, 10);
    assert!(matches!(
        runtime.session().messages[1].blocks[1],
        ContentBlock::ToolUse { .. }
    ));
    assert!(matches!(
        runtime.session().messages[2].blocks[0],
        ContentBlock::ToolResult {
            is_error: false,
            ..
        }
    ));
}
|
|
|
|
/// With a tracer attached, a tool-using turn emits each of the five
/// lifecycle trace events at least once.
#[test]
fn records_runtime_session_trace_events() {
    let sink = Arc::new(MemoryTelemetrySink::default());
    let tracer = SessionTracer::new("session-runtime", sink.clone());
    let mut runtime = ConversationRuntime::new(
        Session::new(),
        ScriptedApiClient { call_count: 0 },
        StaticToolExecutor::new().register("add", |_input| Ok("4".to_string())),
        PermissionPolicy::new(PermissionMode::WorkspaceWrite),
        vec!["system".to_string()],
    )
    .with_session_tracer(tracer);

    runtime
        .run_turn("what is 2 + 2?", Some(&mut PromptAllowOnce))
        .expect("conversation loop should succeed");

    let events = sink.events();
    let trace_names = events
        .iter()
        .filter_map(|event| match event {
            TelemetryEvent::SessionTrace(trace) => Some(trace.name.as_str()),
            _ => None,
        })
        .collect::<Vec<_>>();

    assert!(trace_names.contains(&"turn_started"));
    assert!(trace_names.contains(&"assistant_iteration_completed"));
    assert!(trace_names.contains(&"tool_execution_started"));
    assert!(trace_names.contains(&"tool_execution_finished"));
    assert!(trace_names.contains(&"turn_completed"));
}
|
|
|
|
/// When the prompter rejects a tool call, the turn still completes and the
/// tool result is an error whose output is the prompter's reason.
#[test]
fn records_denied_tool_results_when_prompt_rejects() {
    struct RejectPrompter;
    impl PermissionPrompter for RejectPrompter {
        fn decide(&mut self, _request: &PermissionRequest) -> PermissionPromptDecision {
            PermissionPromptDecision::Deny {
                reason: "not now".to_string(),
            }
        }
    }

    // Requests the tool until a tool result appears in the history, then
    // answers with plain text.
    struct SingleCallApiClient;
    impl ApiClient for SingleCallApiClient {
        fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
            if request
                .messages
                .iter()
                .any(|message| message.role == MessageRole::Tool)
            {
                return Ok(vec![
                    AssistantEvent::TextDelta("I could not use the tool.".to_string()),
                    AssistantEvent::MessageStop,
                ]);
            }
            Ok(vec![
                AssistantEvent::ToolUse {
                    id: "tool-1".to_string(),
                    name: "blocked".to_string(),
                    input: "secret".to_string(),
                },
                AssistantEvent::MessageStop,
            ])
        }
    }

    let mut runtime = ConversationRuntime::new(
        Session::new(),
        SingleCallApiClient,
        StaticToolExecutor::new(),
        PermissionPolicy::new(PermissionMode::WorkspaceWrite),
        vec!["system".to_string()],
    );

    let summary = runtime
        .run_turn("use the tool", Some(&mut RejectPrompter))
        .expect("conversation should continue after denied tool");

    assert_eq!(summary.tool_results.len(), 1);
    assert!(matches!(
        &summary.tool_results[0].blocks[0],
        ContentBlock::ToolResult { is_error: true, output, .. } if output == "not now"
    ));
}
|
|
|
|
/// A pre-tool-use hook that exits with status 2 prevents the tool from
/// running (the handler panics if called) and yields an error tool result.
#[test]
fn denies_tool_use_when_pre_tool_hook_blocks() {
    // Requests the tool until a tool result appears in the history, then
    // answers with plain text.
    struct SingleCallApiClient;
    impl ApiClient for SingleCallApiClient {
        fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
            if request
                .messages
                .iter()
                .any(|message| message.role == MessageRole::Tool)
            {
                return Ok(vec![
                    AssistantEvent::TextDelta("blocked".to_string()),
                    AssistantEvent::MessageStop,
                ]);
            }
            Ok(vec![
                AssistantEvent::ToolUse {
                    id: "tool-1".to_string(),
                    name: "blocked".to_string(),
                    input: r#"{"path":"secret.txt"}"#.to_string(),
                },
                AssistantEvent::MessageStop,
            ])
        }
    }

    let mut runtime = ConversationRuntime::new_with_features(
        Session::new(),
        SingleCallApiClient,
        StaticToolExecutor::new().register("blocked", |_input| {
            panic!("tool should not execute when hook denies")
        }),
        PermissionPolicy::new(PermissionMode::DangerFullAccess),
        vec!["system".to_string()],
        &RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new(
            vec![shell_snippet("printf 'blocked by hook'; exit 2")],
            Vec::new(),
        )),
    );

    let summary = runtime
        .run_turn("use the tool", None)
        .expect("conversation should continue after hook denial");

    assert_eq!(summary.tool_results.len(), 1);
    let ContentBlock::ToolResult {
        is_error, output, ..
    } = &summary.tool_results[0].blocks[0]
    else {
        panic!("expected tool result block");
    };
    assert!(
        *is_error,
        "hook denial should produce an error result: {output}"
    );
    assert!(
        output.contains("denied tool") || output.contains("blocked by hook"),
        "unexpected hook denial output: {output:?}"
    );
}
|
|
|
|
/// Non-denying pre and post hooks leave the tool result successful and have
/// their output appended to the tool's own output.
#[test]
fn appends_post_tool_hook_feedback_to_tool_result() {
    // First call requests `add`; second call (after the tool result) finishes.
    struct TwoCallApiClient {
        calls: usize,
    }

    impl ApiClient for TwoCallApiClient {
        fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
            self.calls += 1;
            match self.calls {
                1 => Ok(vec![
                    AssistantEvent::ToolUse {
                        id: "tool-1".to_string(),
                        name: "add".to_string(),
                        input: r#"{"lhs":2,"rhs":2}"#.to_string(),
                    },
                    AssistantEvent::MessageStop,
                ]),
                2 => {
                    assert!(request
                        .messages
                        .iter()
                        .any(|message| message.role == MessageRole::Tool));
                    Ok(vec![
                        AssistantEvent::TextDelta("done".to_string()),
                        AssistantEvent::MessageStop,
                    ])
                }
                _ => Err(RuntimeError::new("unexpected extra API call")),
            }
        }
    }

    let mut runtime = ConversationRuntime::new_with_features(
        Session::new(),
        TwoCallApiClient { calls: 0 },
        StaticToolExecutor::new().register("add", |_input| Ok("4".to_string())),
        PermissionPolicy::new(PermissionMode::DangerFullAccess),
        vec!["system".to_string()],
        &RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new(
            vec![shell_snippet("printf 'pre hook ran'")],
            vec![shell_snippet("printf 'post hook ran'")],
        )),
    );

    let summary = runtime
        .run_turn("use add", None)
        .expect("tool loop succeeds");

    assert_eq!(summary.tool_results.len(), 1);
    let ContentBlock::ToolResult {
        is_error, output, ..
    } = &summary.tool_results[0].blocks[0]
    else {
        panic!("expected tool result block");
    };
    assert!(
        !*is_error,
        "post hook should preserve non-error result: {output:?}"
    );
    assert!(
        output.contains('4'),
        "tool output missing value: {output:?}"
    );
    assert!(
        output.contains("pre hook ran"),
        "tool output missing pre hook feedback: {output:?}"
    );
    assert!(
        output.contains("post hook ran"),
        "tool output missing post hook feedback: {output:?}"
    );
}
|
|
|
|
/// A runtime built over a session that already contains an assistant message
/// with usage starts with that usage counted (one turn, 21 total tokens).
#[test]
fn reconstructs_usage_tracker_from_restored_session() {
    struct SimpleApi;
    impl ApiClient for SimpleApi {
        fn stream(
            &mut self,
            _request: ApiRequest,
        ) -> Result<Vec<AssistantEvent>, RuntimeError> {
            Ok(vec![
                AssistantEvent::TextDelta("done".to_string()),
                AssistantEvent::MessageStop,
            ])
        }
    }

    let mut session = Session::new();
    session
        .messages
        .push(crate::session::ConversationMessage::assistant_with_usage(
            vec![ContentBlock::Text {
                text: "earlier".to_string(),
            }],
            Some(TokenUsage {
                input_tokens: 11,
                output_tokens: 7,
                cache_creation_input_tokens: 2,
                cache_read_input_tokens: 1,
            }),
        ));

    let runtime = ConversationRuntime::new(
        session,
        SimpleApi,
        StaticToolExecutor::new(),
        PermissionPolicy::new(PermissionMode::DangerFullAccess),
        vec!["system".to_string()],
    );

    assert_eq!(runtime.usage().turns(), 1);
    // 11 + 7 + 2 + 1
    assert_eq!(runtime.usage().cumulative_usage().total_tokens(), 21);
}
|
|
|
|
/// After three turns, compacting with a tiny token budget yields a summary
/// and a compacted session whose first message has the system role.
#[test]
fn compacts_session_after_turns() {
    struct SimpleApi;
    impl ApiClient for SimpleApi {
        fn stream(
            &mut self,
            _request: ApiRequest,
        ) -> Result<Vec<AssistantEvent>, RuntimeError> {
            Ok(vec![
                AssistantEvent::TextDelta("done".to_string()),
                AssistantEvent::MessageStop,
            ])
        }
    }

    let mut runtime = ConversationRuntime::new(
        Session::new(),
        SimpleApi,
        StaticToolExecutor::new(),
        PermissionPolicy::new(PermissionMode::DangerFullAccess),
        vec!["system".to_string()],
    );
    runtime.run_turn("a", None).expect("turn a");
    runtime.run_turn("b", None).expect("turn b");
    runtime.run_turn("c", None).expect("turn c");

    let result = runtime.compact(CompactionConfig {
        preserve_recent_messages: 2,
        max_estimated_tokens: 1,
    });
    assert!(result.summary.contains("Conversation summary"));
    assert_eq!(
        result.compacted_session.messages[0].role,
        MessageRole::System
    );
}
|
|
|
|
/// Adapts a hook script written with POSIX single quotes to the host shell.
///
/// On Windows every `'` is replaced with `"`; on all other targets the
/// script is returned unchanged. `cfg!` is used instead of two `#[cfg]`
/// definitions so both branches are type-checked on every platform.
fn shell_snippet(script: &str) -> String {
    if cfg!(windows) {
        script.replace('\'', "\"")
    } else {
        script.to_string()
    }
}
|
|
}
|