From 178934a9a0d0750b1926fa10289e1e45c3013d08 Mon Sep 17 00:00:00 2001 From: Yeachan-Heo Date: Wed, 1 Apr 2026 04:20:15 +0000 Subject: [PATCH] feat: grok provider tests + cargo fmt --- rust/crates/rusty-claude-cli/src/main.rs | 110 +++++++++++++++++++---- rust/crates/tools/src/lib.rs | 101 ++++++++++++++++----- 2 files changed, 172 insertions(+), 39 deletions(-) diff --git a/rust/crates/rusty-claude-cli/src/main.rs b/rust/crates/rusty-claude-cli/src/main.rs index 00ef7cd..dee54d9 100644 --- a/rust/crates/rusty-claude-cli/src/main.rs +++ b/rust/crates/rusty-claude-cli/src/main.rs @@ -2046,7 +2046,7 @@ impl ApiClient for ProviderRuntimeClient { let renderer = TerminalRenderer::new(); let mut markdown_stream = MarkdownStreamState::default(); let mut events = Vec::new(); - let mut pending_tool: Option<(String, String, String)> = None; + let mut pending_tools: BTreeMap = BTreeMap::new(); let mut saw_stop = false; while let Some(event) = stream @@ -2057,15 +2057,23 @@ impl ApiClient for ProviderRuntimeClient { match event { ApiStreamEvent::MessageStart(start) => { for block in start.message.content { - push_output_block(block, out, &mut events, &mut pending_tool, true)?; + push_output_block( + block, + 0, + out, + &mut events, + &mut pending_tools, + true, + )?; } } ApiStreamEvent::ContentBlockStart(start) => { push_output_block( start.content_block, + start.index, out, &mut events, - &mut pending_tool, + &mut pending_tools, true, )?; } @@ -2081,18 +2089,18 @@ impl ApiClient for ProviderRuntimeClient { } } ContentBlockDelta::InputJsonDelta { partial_json } => { - if let Some((_, _, input)) = &mut pending_tool { + if let Some((_, _, input)) = pending_tools.get_mut(&delta.index) { input.push_str(&partial_json); } } }, - ApiStreamEvent::ContentBlockStop(_) => { + ApiStreamEvent::ContentBlockStop(stop) => { if let Some(rendered) = markdown_stream.flush(&renderer) { write!(out, "{rendered}") .and_then(|()| out.flush()) .map_err(|error| RuntimeError::new(error.to_string()))?; } - if let Some((id, name, input)) = pending_tool.take() { + if let Some((id, name, input)) = pending_tools.remove(&stop.index) { // Display tool call now that input is fully accumulated writeln!(out, "\n{}", format_tool_call_start(&name, &input)) .and_then(|()| out.flush()) @@ -2556,9 +2564,10 @@ fn truncate_for_summary(value: &str, limit: usize) -> String { fn push_output_block( block: OutputContentBlock, + block_index: u32, out: &mut (impl Write + ?Sized), events: &mut Vec, - pending_tool: &mut Option<(String, String, String)>, + pending_tools: &mut BTreeMap, streaming_tool_input: bool, ) -> Result<(), RuntimeError> { match block { @@ -2583,7 +2592,7 @@ fn push_output_block( } else { input.to_string() }; - *pending_tool = Some((id, name, initial_input)); + pending_tools.insert(block_index, (id, name, initial_input)); } } Ok(()) @@ -2594,11 +2603,13 @@ fn response_to_events( out: &mut (impl Write + ?Sized), ) -> Result, RuntimeError> { let mut events = Vec::new(); - let mut pending_tool = None; + let mut pending_tools = BTreeMap::new(); - for block in response.content { - push_output_block(block, out, &mut events, &mut pending_tool, false)?; - if let Some((id, name, input)) = pending_tool.take() { + for (index, block) in response.content.into_iter().enumerate() { + let index = + u32::try_from(index).map_err(|_| RuntimeError::new("response block index overflow"))?; + push_output_block(block, index, out, &mut events, &mut pending_tools, false)?; + if let Some((id, name, input)) = pending_tools.remove(&index) { events.push(AssistantEvent::ToolUse { id, name, input }); } } @@ -2824,6 +2835,7 @@ mod tests { use api::{MessageResponse, OutputContentBlock, Usage}; use runtime::{AssistantEvent, ContentBlock, ConversationMessage, MessageRole, PermissionMode}; use serde_json::json; + use std::collections::BTreeMap; use std::path::PathBuf; #[test] @@ -3373,15 +3385,16 @@ mod tests { fn push_output_block_renders_markdown_text() { let mut out = Vec::new(); let mut events = Vec::new(); - let mut pending_tool = None; + let mut pending_tools = BTreeMap::new(); push_output_block( OutputContentBlock::Text { text: "# Heading".to_string(), }, + 0, &mut out, &mut events, - &mut pending_tool, + &mut pending_tools, false, ) .expect("text block should render"); @@ -3395,7 +3408,7 @@ mod tests { fn push_output_block_skips_empty_object_prefix_for_tool_streams() { let mut out = Vec::new(); let mut events = Vec::new(); - let mut pending_tool = None; + let mut pending_tools = BTreeMap::new(); push_output_block( OutputContentBlock::ToolUse { @@ -3403,20 +3416,83 @@ mod tests { name: "read_file".to_string(), input: json!({}), }, + 1, &mut out, &mut events, - &mut pending_tool, + &mut pending_tools, true, ) .expect("tool block should accumulate"); assert!(events.is_empty()); assert_eq!( - pending_tool, + pending_tools.remove(&1), Some(("tool-1".to_string(), "read_file".to_string(), String::new(),)) ); } + #[test] + fn pending_tools_preserve_multiple_streaming_tool_calls_by_index() { + let mut out = Vec::new(); + let mut events = Vec::new(); + let mut pending_tools = BTreeMap::new(); + + push_output_block( + OutputContentBlock::ToolUse { + id: "tool-1".to_string(), + name: "read_file".to_string(), + input: json!({}), + }, + 1, + &mut out, + &mut events, + &mut pending_tools, + true, + ) + .expect("first tool should accumulate"); + push_output_block( + OutputContentBlock::ToolUse { + id: "tool-2".to_string(), + name: "grep_search".to_string(), + input: json!({}), + }, + 2, + &mut out, + &mut events, + &mut pending_tools, + true, + ) + .expect("second tool should accumulate"); + + pending_tools + .get_mut(&1) + .expect("first tool pending") + .2 + .push_str("{\"path\":\"src/main.rs\"}"); + pending_tools + .get_mut(&2) + .expect("second tool pending") + .2 + .push_str("{\"pattern\":\"TODO\"}"); + + assert_eq!( + pending_tools.remove(&1), + Some(( + "tool-1".to_string(), + "read_file".to_string(), + "{\"path\":\"src/main.rs\"}".to_string(), + )) + ); + assert_eq!( + pending_tools.remove(&2), + Some(( + "tool-2".to_string(), + "grep_search".to_string(), + "{\"pattern\":\"TODO\"}".to_string(), + )) + ); + } + #[test] fn response_to_events_preserves_empty_object_json_input_outside_streaming() { let mut out = Vec::new(); diff --git a/rust/crates/tools/src/lib.rs b/rust/crates/tools/src/lib.rs index 6448ca0..63be324 100644 --- a/rust/crates/tools/src/lib.rs +++ b/rust/crates/tools/src/lib.rs @@ -4,10 +4,9 @@ use std::process::Command; use std::time::{Duration, Instant}; use api::{ - detect_provider_kind, max_tokens_for_model, resolve_model_alias, ContentBlockDelta, - InputContentBlock, InputMessage, MessageRequest, MessageResponse, OutputContentBlock, - ProviderClient, ProviderKind, StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, - ToolResultContentBlock, + max_tokens_for_model, resolve_model_alias, ContentBlockDelta, InputContentBlock, InputMessage, + MessageRequest, MessageResponse, OutputContentBlock, ProviderClient, + StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, ToolResultContentBlock, }; use reqwest::blocking::Client; use runtime::{ @@ -1646,11 +1645,7 @@ struct ProviderRuntimeClient { impl ProviderRuntimeClient { fn new(model: String, allowed_tools: BTreeSet) -> Result { let model = resolve_model_alias(&model).to_string(); - let client = match detect_provider_kind(&model) { - ProviderKind::Anthropic | ProviderKind::Xai | ProviderKind::OpenAi => { - ProviderClient::from_model(&model).map_err(|error| error.to_string())? - } - }; + let client = ProviderClient::from_model(&model).map_err(|error| error.to_string())?; Ok(Self { runtime: tokio::runtime::Runtime::new().map_err(|error| error.to_string())?, client, @@ -1687,7 +1682,7 @@ impl ApiClient for ProviderRuntimeClient { .await .map_err(|error| RuntimeError::new(error.to_string()))?; let mut events = Vec::new(); - let mut pending_tool: Option<(String, String, String)> = None; + let mut pending_tools: BTreeMap = BTreeMap::new(); let mut saw_stop = false; while let Some(event) = stream @@ -1698,14 +1693,15 @@ impl ApiClient for ProviderRuntimeClient { match event { ApiStreamEvent::MessageStart(start) => { for block in start.message.content { - push_output_block(block, &mut events, &mut pending_tool, true); + push_output_block(block, 0, &mut events, &mut pending_tools, true); } } ApiStreamEvent::ContentBlockStart(start) => { push_output_block( start.content_block, + start.index, &mut events, - &mut pending_tool, + &mut pending_tools, true, ); } @@ -1716,13 +1712,13 @@ impl ApiClient for ProviderRuntimeClient { } } ContentBlockDelta::InputJsonDelta { partial_json } => { - if let Some((_, _, input)) = &mut pending_tool { + if let Some((_, _, input)) = pending_tools.get_mut(&delta.index) { input.push_str(&partial_json); } } }, - ApiStreamEvent::ContentBlockStop(_) => { - if let Some((id, name, input)) = pending_tool.take() { + ApiStreamEvent::ContentBlockStop(stop) => { + if let Some((id, name, input)) = pending_tools.remove(&stop.index) { events.push(AssistantEvent::ToolUse { id, name, input }); } } @@ -1843,8 +1839,9 @@ fn convert_messages(messages: &[ConversationMessage]) -> Vec { fn push_output_block( block: OutputContentBlock, + block_index: u32, events: &mut Vec, - pending_tool: &mut Option<(String, String, String)>, + pending_tools: &mut BTreeMap, streaming_tool_input: bool, ) { match block { @@ -1862,18 +1859,19 @@ fn push_output_block( } else { input.to_string() }; - *pending_tool = Some((id, name, initial_input)); + pending_tools.insert(block_index, (id, name, initial_input)); } } } fn response_to_events(response: MessageResponse) -> Vec { let mut events = Vec::new(); - let mut pending_tool = None; + let mut pending_tools = BTreeMap::new(); - for block in response.content { - push_output_block(block, &mut events, &mut pending_tool, false); - if let Some((id, name, input)) = pending_tool.take() { + for (index, block) in response.content.into_iter().enumerate() { + let index = u32::try_from(index).expect("response block index overflow"); + push_output_block(block, index, &mut events, &mut pending_tools, false); + if let Some((id, name, input)) = pending_tools.remove(&index) { events.push(AssistantEvent::ToolUse { id, name, input }); } } @@ -2897,6 +2895,7 @@ fn parse_skill_description(contents: &str) -> Option { #[cfg(test)] mod tests { + use std::collections::BTreeMap; use std::collections::BTreeSet; use std::fs; use std::io::{Read, Write}; @@ -2909,8 +2908,9 @@ mod tests { use super::{ agent_permission_policy, allowed_tools_for_subagent, execute_agent_with_spawn, execute_tool, final_assistant_text, mvp_tool_specs, persist_agent_terminal_state, - AgentInput, AgentJob, SubagentToolExecutor, + push_output_block, AgentInput, AgentJob, SubagentToolExecutor, }; + use api::OutputContentBlock; use runtime::{ApiRequest, AssistantEvent, ConversationRuntime, RuntimeError, Session}; use serde_json::json; @@ -3125,6 +3125,63 @@ mod tests { assert!(error.contains("relative URL without a base") || error.contains("empty host")); } + #[test] + fn pending_tools_preserve_multiple_streaming_tool_calls_by_index() { + let mut events = Vec::new(); + let mut pending_tools = BTreeMap::new(); + + push_output_block( + OutputContentBlock::ToolUse { + id: "tool-1".to_string(), + name: "read_file".to_string(), + input: json!({}), + }, + 1, + &mut events, + &mut pending_tools, + true, + ); + push_output_block( + OutputContentBlock::ToolUse { + id: "tool-2".to_string(), + name: "grep_search".to_string(), + input: json!({}), + }, + 2, + &mut events, + &mut pending_tools, + true, + ); + + pending_tools + .get_mut(&1) + .expect("first tool pending") + .2 + .push_str("{\"path\":\"src/main.rs\"}"); + pending_tools + .get_mut(&2) + .expect("second tool pending") + .2 + .push_str("{\"pattern\":\"TODO\"}"); + + assert_eq!( + pending_tools.remove(&1), + Some(( + "tool-1".to_string(), + "read_file".to_string(), + "{\"path\":\"src/main.rs\"}".to_string(), + )) + ); + assert_eq!( + pending_tools.remove(&2), + Some(( + "tool-2".to_string(), + "grep_search".to_string(), + "{\"pattern\":\"TODO\"}".to_string(), + )) + ); + } + #[test] fn todo_write_persists_and_returns_previous_state() { let _guard = env_lock()