diff --git a/rust/crates/rusty-claude-cli/src/main.rs b/rust/crates/rusty-claude-cli/src/main.rs index 5d60f92..477a473 100644 --- a/rust/crates/rusty-claude-cli/src/main.rs +++ b/rust/crates/rusty-claude-cli/src/main.rs @@ -5,15 +5,16 @@ use std::collections::{BTreeMap, BTreeSet}; use std::env; use std::fs; use std::io::{self, Read, Write}; -use std::net::TcpListener; +use std::net::{TcpListener, TcpStream, ToSocketAddrs}; use std::path::{Path, PathBuf}; use std::process::Command; -use std::time::{SystemTime, UNIX_EPOCH}; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; use api::{ - resolve_startup_auth_source, AnthropicClient, AuthSource, ContentBlockDelta, InputContentBlock, - InputMessage, MessageRequest, MessageResponse, OutputContentBlock, - StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, ToolResultContentBlock, + oauth_token_is_expired, resolve_startup_auth_source, AnthropicClient, ApiError, AuthSource, + ContentBlockDelta, InputContentBlock, InputMessage, MessageRequest, MessageResponse, + OutputContentBlock, StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, + ToolResultContentBlock, }; use commands::{ @@ -22,10 +23,11 @@ use commands::{ use compat_harness::{extract_manifest, UpstreamPaths}; use render::{Spinner, TerminalRenderer}; use runtime::{ - clear_oauth_credentials, format_usd, generate_pkce_pair, generate_state, load_system_prompt, - parse_oauth_callback_request_target, pricing_for_model, save_oauth_credentials, ApiClient, + clear_oauth_credentials, generate_pkce_pair, generate_state, load_oauth_credentials, + load_system_prompt, parse_oauth_callback_request_target, save_oauth_credentials, ApiClient, ApiRequest, AssistantEvent, CompactionConfig, ConfigLoader, ConfigSource, ContentBlock, - ConversationMessage, ConversationRuntime, MessageRole, OAuthAuthorizationRequest, + ConversationMessage, ConversationRuntime, McpClientBootstrap, McpClientTransport, + McpServerConfig, McpStdioProcess, MessageRole, OAuthAuthorizationRequest, OAuthTokenExchangeRequest, PermissionMode, PermissionPolicy, ProjectContext, RuntimeError, Session, TokenUsage, ToolError, ToolExecutor, UsageTracker, }; @@ -36,7 +38,6 @@ const DEFAULT_MODEL: &str = "claude-sonnet-4-20250514"; const DEFAULT_MAX_TOKENS: u32 = 32; const DEFAULT_DATE: &str = "2026-03-31"; const DEFAULT_OAUTH_CALLBACK_PORT: u16 = 4545; -const COST_WARNING_FRACTION: f64 = 0.8; const VERSION: &str = env!("CARGO_PKG_VERSION"); const BUILD_TARGET: Option<&str> = option_env!("TARGET"); const GIT_SHA: Option<&str> = option_env!("GIT_SHA"); @@ -71,23 +72,22 @@ fn run() -> Result<(), Box> { output_format, allowed_tools, permission_mode, - max_cost_usd, - } => LiveCli::new(model, false, allowed_tools, permission_mode, max_cost_usd)? + } => LiveCli::new(model, false, allowed_tools, permission_mode)? .run_turn_with_output(&prompt, output_format)?, CliAction::Login => run_login()?, CliAction::Logout => run_logout()?, + CliAction::Doctor => run_doctor()?, CliAction::Repl { model, allowed_tools, permission_mode, - max_cost_usd, - } => run_repl(model, allowed_tools, permission_mode, max_cost_usd)?, + } => run_repl(model, allowed_tools, permission_mode)?, CliAction::Help => print_help(), } Ok(()) } -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Eq)] enum CliAction { DumpManifests, BootstrapPlan, @@ -106,15 +106,14 @@ enum CliAction { output_format: CliOutputFormat, allowed_tools: Option, permission_mode: PermissionMode, - max_cost_usd: Option, }, Login, Logout, + Doctor, Repl { model: String, allowed_tools: Option, permission_mode: PermissionMode, - max_cost_usd: Option, }, // prompt-mode formatting is only supported for non-interactive runs Help, @@ -144,7 +143,6 @@ fn parse_args(args: &[String]) -> Result { let mut output_format = CliOutputFormat::Text; let mut permission_mode = default_permission_mode(); let mut wants_version = false; - let mut max_cost_usd: Option = None; let mut allowed_tool_values = Vec::new(); let mut rest = Vec::new(); let mut index = 0; @@ -180,13 +178,6 @@ fn parse_args(args: &[String]) -> Result { permission_mode = parse_permission_mode_arg(value)?; index += 2; } - "--max-cost" => { - let value = args - .get(index + 1) - .ok_or_else(|| "missing value for --max-cost".to_string())?; - max_cost_usd = Some(parse_max_cost_arg(value)?); - index += 2; - } flag if flag.starts_with("--output-format=") => { output_format = CliOutputFormat::parse(&flag[16..])?; index += 1; @@ -195,10 +186,6 @@ fn parse_args(args: &[String]) -> Result { permission_mode = parse_permission_mode_arg(&flag[18..])?; index += 1; } - flag if flag.starts_with("--max-cost=") => { - max_cost_usd = Some(parse_max_cost_arg(&flag[11..])?); - index += 1; - } "--allowedTools" | "--allowed-tools" => { let value = args .get(index + 1) @@ -232,7 +219,6 @@ fn parse_args(args: &[String]) -> Result { model, allowed_tools, permission_mode, - max_cost_usd, }); } if matches!(rest.first().map(String::as_str), Some("--help" | "-h")) { @@ -248,6 +234,7 @@ fn parse_args(args: &[String]) -> Result { "system-prompt" => parse_system_prompt_args(&rest[1..]), "login" => Ok(CliAction::Login), "logout" => Ok(CliAction::Logout), + "doctor" => Ok(CliAction::Doctor), "prompt" => { let prompt = rest[1..].join(" "); if prompt.trim().is_empty() { @@ -259,7 +246,6 @@ fn parse_args(args: &[String]) -> Result { output_format, allowed_tools, permission_mode, - max_cost_usd, }) } other if !other.starts_with('/') => Ok(CliAction::Prompt { @@ -268,7 +254,6 @@ fn parse_args(args: &[String]) -> Result { output_format, allowed_tools, permission_mode, - max_cost_usd, }), other => Err(format!("unknown subcommand: {other}")), } @@ -332,18 +317,6 @@ fn parse_permission_mode_arg(value: &str) -> Result { .map(permission_mode_from_label) } -fn parse_max_cost_arg(value: &str) -> Result { - let parsed = value - .parse::() - .map_err(|_| format!("invalid value for --max-cost: {value}"))?; - if !parsed.is_finite() || parsed <= 0.0 { - return Err(format!( - "--max-cost must be a positive finite USD amount: {value}" - )); - } - Ok(parsed) -} - fn permission_mode_from_label(mode: &str) -> PermissionMode { match mode { "read-only" => PermissionMode::ReadOnly, @@ -552,6 +525,627 @@ fn wait_for_oauth_callback( Ok(callback) } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum DiagnosticLevel { + Ok, + Warn, + Fail, +} + +impl DiagnosticLevel { + const fn label(self) -> &'static str { + match self { + Self::Ok => "OK", + Self::Warn => "WARN", + Self::Fail => "FAIL", + } + } + + const fn is_failure(self) -> bool { + matches!(self, Self::Fail) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct DiagnosticCheck { + name: &'static str, + level: DiagnosticLevel, + summary: String, + details: Vec, +} + +impl DiagnosticCheck { + fn new(name: &'static str, level: DiagnosticLevel, summary: impl Into) -> Self { + Self { + name, + level, + summary: summary.into(), + details: Vec::new(), + } + } + + fn with_details(mut self, details: Vec) -> Self { + self.details = details; + self + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum OAuthDiagnosticStatus { + Missing, + Valid, + ExpiredRefreshable, + ExpiredNoRefresh, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct ConfigFileCheck { + path: PathBuf, + exists: bool, + valid: bool, + note: String, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct DoctorReport { + checks: Vec, +} + +impl DoctorReport { + fn has_failures(&self) -> bool { + self.checks.iter().any(|check| check.level.is_failure()) + } + + fn render(&self) -> String { + let mut lines = vec!["Doctor diagnostics".to_string()]; + let ok_count = self + .checks + .iter() + .filter(|check| check.level == DiagnosticLevel::Ok) + .count(); + let warn_count = self + .checks + .iter() + .filter(|check| check.level == DiagnosticLevel::Warn) + .count(); + let fail_count = self + .checks + .iter() + .filter(|check| check.level == DiagnosticLevel::Fail) + .count(); + lines.push(format!( + "Summary\n OK {ok_count}\n Warnings {warn_count}\n Failures {fail_count}" + )); + lines.extend(self.checks.iter().map(render_diagnostic_check)); + lines.join("\n\n") + } +} + +fn render_diagnostic_check(check: &DiagnosticCheck) -> String { + let mut section = vec![format!( + "{}\n Status {}\n Summary {}", + check.name, + check.level.label(), + check.summary + )]; + if !check.details.is_empty() { + section.push(" Details".to_string()); + section.extend(check.details.iter().map(|detail| format!(" - {detail}"))); + } + section.join("\n") +} + +fn run_doctor() -> Result<(), Box> { + let cwd = env::current_dir()?; + let config_loader = ConfigLoader::default_for(&cwd); + let config = config_loader.load(); + let report = DoctorReport { + checks: vec![ + check_api_key_validity(config.as_ref().ok()), + check_oauth_token_status(config.as_ref().ok()), + check_config_files(&config_loader, config.as_ref()), + check_git_availability(&cwd), + check_mcp_server_health(config.as_ref().ok()), + check_network_connectivity(), + check_system_info(&cwd, config.as_ref().ok()), + ], + }; + println!("{}", report.render()); + if report.has_failures() { + return Err("doctor found failing checks".into()); + } + Ok(()) +} + +fn check_api_key_validity(config: Option<&runtime::RuntimeConfig>) -> DiagnosticCheck { + let api_key = match env::var("ANTHROPIC_API_KEY") { + Ok(value) if !value.trim().is_empty() => value, + Ok(_) | Err(env::VarError::NotPresent) => { + return DiagnosticCheck::new( + "API key validity", + DiagnosticLevel::Warn, + "ANTHROPIC_API_KEY is not set", + ); + } + Err(error) => { + return DiagnosticCheck::new( + "API key validity", + DiagnosticLevel::Fail, + format!("failed to read ANTHROPIC_API_KEY: {error}"), + ); + } + }; + + let request = MessageRequest { + model: config + .and_then(runtime::RuntimeConfig::model) + .unwrap_or(DEFAULT_MODEL) + .to_string(), + max_tokens: 1, + messages: vec![InputMessage { + role: "user".to_string(), + content: vec![InputContentBlock::Text { + text: "Reply with OK.".to_string(), + }], + }], + system: None, + tools: None, + tool_choice: None, + stream: false, + }; + let runtime = match tokio::runtime::Runtime::new() { + Ok(runtime) => runtime, + Err(error) => { + return DiagnosticCheck::new( + "API key validity", + DiagnosticLevel::Fail, + format!("failed to create async runtime: {error}"), + ); + } + }; + match runtime + .block_on(AnthropicClient::from_auth(AuthSource::ApiKey(api_key)).send_message(&request)) + { + Ok(response) => DiagnosticCheck::new( + "API key validity", + DiagnosticLevel::Ok, + "Anthropic API accepted the configured API key", + ) + .with_details(vec![format!( + "request_id={} input_tokens={} output_tokens={}", + response.request_id.unwrap_or_else(|| "".to_string()), + response.usage.input_tokens, + response.usage.output_tokens + )]), + Err(ApiError::Api { status, .. }) if status.as_u16() == 401 || status.as_u16() == 403 => { + DiagnosticCheck::new( + "API key validity", + DiagnosticLevel::Fail, + format!("Anthropic API rejected the API key with HTTP {status}"), + ) + } + Err(error) => DiagnosticCheck::new( + "API key validity", + DiagnosticLevel::Warn, + format!("unable to conclusively validate the API key: {error}"), + ), + } +} + +fn classify_oauth_status() -> Result<(OAuthDiagnosticStatus, Vec), io::Error> { + let Some(token_set) = load_oauth_credentials()? else { + return Ok((OAuthDiagnosticStatus::Missing, vec![])); + }; + let token = api::OAuthTokenSet { + access_token: token_set.access_token.clone(), + refresh_token: token_set.refresh_token.clone(), + expires_at: token_set.expires_at, + scopes: token_set.scopes.clone(), + }; + let details = vec![format!( + "expires_at={} refresh_token={} scopes={}", + token + .expires_at + .map_or_else(|| "".to_string(), |value| value.to_string()), + if token.refresh_token.is_some() { + "present" + } else { + "absent" + }, + if token.scopes.is_empty() { + "".to_string() + } else { + token.scopes.join(",") + } + )]; + let status = if oauth_token_is_expired(&token) { + if token.refresh_token.is_some() { + OAuthDiagnosticStatus::ExpiredRefreshable + } else { + OAuthDiagnosticStatus::ExpiredNoRefresh + } + } else { + OAuthDiagnosticStatus::Valid + }; + Ok((status, details)) +} + +fn check_oauth_token_status(config: Option<&runtime::RuntimeConfig>) -> DiagnosticCheck { + match classify_oauth_status() { + Ok((OAuthDiagnosticStatus::Missing, _)) => DiagnosticCheck::new( + "OAuth token status", + DiagnosticLevel::Warn, + "no saved OAuth credentials found", + ), + Ok((OAuthDiagnosticStatus::Valid, details)) => DiagnosticCheck::new( + "OAuth token status", + DiagnosticLevel::Ok, + "saved OAuth token is present and not expired", + ) + .with_details(details), + Ok((OAuthDiagnosticStatus::ExpiredRefreshable, mut details)) => { + let refresh_ready = config.and_then(runtime::RuntimeConfig::oauth).is_some(); + details.push(if refresh_ready { + "runtime OAuth config is present for refresh".to_string() + } else { + "runtime OAuth config is missing for refresh".to_string() + }); + DiagnosticCheck::new( + "OAuth token status", + if refresh_ready { + DiagnosticLevel::Warn + } else { + DiagnosticLevel::Fail + }, + "saved OAuth token is expired but includes a refresh token", + ) + .with_details(details) + } + Ok((OAuthDiagnosticStatus::ExpiredNoRefresh, details)) => DiagnosticCheck::new( + "OAuth token status", + DiagnosticLevel::Fail, + "saved OAuth token is expired and cannot refresh", + ) + .with_details(details), + Err(error) => DiagnosticCheck::new( + "OAuth token status", + DiagnosticLevel::Fail, + format!("failed to read saved OAuth credentials: {error}"), + ), + } +} + +fn validate_config_file(path: &Path) -> ConfigFileCheck { + match fs::read_to_string(path) { + Ok(contents) => { + if contents.trim().is_empty() { + return ConfigFileCheck { + path: path.to_path_buf(), + exists: true, + valid: true, + note: "exists but is empty".to_string(), + }; + } + match serde_json::from_str::(&contents) { + Ok(serde_json::Value::Object(_)) => ConfigFileCheck { + path: path.to_path_buf(), + exists: true, + valid: true, + note: "valid JSON object".to_string(), + }, + Ok(_) => ConfigFileCheck { + path: path.to_path_buf(), + exists: true, + valid: false, + note: "top-level JSON value is not an object".to_string(), + }, + Err(error) => ConfigFileCheck { + path: path.to_path_buf(), + exists: true, + valid: false, + note: format!("invalid JSON: {error}"), + }, + } + } + Err(error) if error.kind() == io::ErrorKind::NotFound => ConfigFileCheck { + path: path.to_path_buf(), + exists: false, + valid: true, + note: "not present".to_string(), + }, + Err(error) => ConfigFileCheck { + path: path.to_path_buf(), + exists: true, + valid: false, + note: format!("unreadable: {error}"), + }, + } +} + +fn check_config_files( + config_loader: &ConfigLoader, + config: Result<&runtime::RuntimeConfig, &runtime::ConfigError>, +) -> DiagnosticCheck { + let file_checks = config_loader + .discover() + .into_iter() + .map(|entry| validate_config_file(&entry.path)) + .collect::>(); + let existing_count = file_checks.iter().filter(|check| check.exists).count(); + let invalid_count = file_checks + .iter() + .filter(|check| check.exists && !check.valid) + .count(); + let mut details = file_checks + .iter() + .map(|check| format!("{} => {}", check.path.display(), check.note)) + .collect::>(); + match config { + Ok(runtime_config) => details.push(format!( + "merged load succeeded with {} loaded file(s)", + runtime_config.loaded_entries().len() + )), + Err(error) => details.push(format!("merged load failed: {error}")), + } + DiagnosticCheck::new( + "Config files", + if invalid_count > 0 || config.is_err() { + DiagnosticLevel::Fail + } else if existing_count == 0 { + DiagnosticLevel::Warn + } else { + DiagnosticLevel::Ok + }, + format!( + "discovered {} candidate file(s), {} existing, {} invalid", + file_checks.len(), + existing_count, + invalid_count + ), + ) + .with_details(details) +} + +fn check_git_availability(cwd: &Path) -> DiagnosticCheck { + match Command::new("git").arg("--version").output() { + Ok(version_output) if version_output.status.success() => { + let version = String::from_utf8_lossy(&version_output.stdout) + .trim() + .to_string(); + match Command::new("git") + .args(["rev-parse", "--show-toplevel"]) + .current_dir(cwd) + .output() + { + Ok(root_output) if root_output.status.success() => DiagnosticCheck::new( + "Git availability", + DiagnosticLevel::Ok, + "git is installed and the current directory is inside a repository", + ) + .with_details(vec![ + version, + format!( + "repo_root={}", + String::from_utf8_lossy(&root_output.stdout).trim() + ), + ]), + Ok(_) => DiagnosticCheck::new( + "Git availability", + DiagnosticLevel::Warn, + "git is installed but the current directory is not a repository", + ) + .with_details(vec![version]), + Err(error) => DiagnosticCheck::new( + "Git availability", + DiagnosticLevel::Warn, + format!("git is installed but repo detection failed: {error}"), + ) + .with_details(vec![version]), + } + } + Ok(output) => DiagnosticCheck::new( + "Git availability", + DiagnosticLevel::Fail, + format!("git --version exited with status {}", output.status), + ), + Err(error) => DiagnosticCheck::new( + "Git availability", + DiagnosticLevel::Fail, + format!("failed to execute git: {error}"), + ), + } +} + +fn check_one_mcp_server( + name: &str, + server: &runtime::ScopedMcpServerConfig, +) -> (DiagnosticLevel, String) { + match &server.config { + McpServerConfig::Stdio(_) => { + let bootstrap = McpClientBootstrap::from_scoped_config(name, server); + let runtime = match tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + { + Ok(runtime) => runtime, + Err(error) => { + return ( + DiagnosticLevel::Fail, + format!("{name}: runtime error: {error}"), + ) + } + }; + let detail = runtime.block_on(async { + match tokio::time::timeout(Duration::from_secs(3), async { + let mut process = McpStdioProcess::spawn(match &bootstrap.transport { + McpClientTransport::Stdio(transport) => transport, + _ => unreachable!("stdio bootstrap expected"), + })?; + let result = process + .initialize( + runtime::JsonRpcId::Number(1), + runtime::McpInitializeParams { + protocol_version: "2025-03-26".to_string(), + capabilities: serde_json::Value::Object(serde_json::Map::new()), + client_info: runtime::McpInitializeClientInfo { + name: "doctor".to_string(), + version: VERSION.to_string(), + }, + }, + ) + .await; + let _ = process.terminate().await; + result + }) + .await + { + Ok(Ok(response)) => { + if let Some(error) = response.error { + ( + DiagnosticLevel::Fail, + format!( + "{name}: initialize JSON-RPC error {} ({})", + error.message, error.code + ), + ) + } else if let Some(result) = response.result { + ( + DiagnosticLevel::Ok, + format!( + "{name}: ok (server {} {})", + result.server_info.name, result.server_info.version + ), + ) + } else { + ( + DiagnosticLevel::Fail, + format!("{name}: initialize returned no result"), + ) + } + } + Ok(Err(error)) => ( + DiagnosticLevel::Fail, + format!("{name}: spawn/initialize failed: {error}"), + ), + Err(_) => ( + DiagnosticLevel::Fail, + format!("{name}: timed out during initialize"), + ), + } + }); + detail + } + other => ( + DiagnosticLevel::Warn, + format!( + "{name}: transport {:?} configured (active health probe not implemented)", + other.transport() + ), + ), + } +} + +fn check_mcp_server_health(config: Option<&runtime::RuntimeConfig>) -> DiagnosticCheck { + let Some(config) = config else { + return DiagnosticCheck::new( + "MCP server health", + DiagnosticLevel::Warn, + "runtime config could not be loaded, so MCP servers were not inspected", + ); + }; + let servers = config.mcp().servers(); + if servers.is_empty() { + return DiagnosticCheck::new( + "MCP server health", + DiagnosticLevel::Warn, + "no MCP servers are configured", + ); + } + let results = servers + .iter() + .map(|(name, server)| check_one_mcp_server(name, server)) + .collect::>(); + let level = if results + .iter() + .any(|(level, _)| *level == DiagnosticLevel::Fail) + { + DiagnosticLevel::Fail + } else if results + .iter() + .any(|(level, _)| *level == DiagnosticLevel::Warn) + { + DiagnosticLevel::Warn + } else { + DiagnosticLevel::Ok + }; + DiagnosticCheck::new( + "MCP server health", + level, + format!("checked {} configured MCP server(s)", servers.len()), + ) + .with_details(results.into_iter().map(|(_, detail)| detail).collect()) +} + +fn check_network_connectivity() -> DiagnosticCheck { + let address = match ("api.anthropic.com", 443).to_socket_addrs() { + Ok(mut addrs) => match addrs.next() { + Some(addr) => addr, + None => { + return DiagnosticCheck::new( + "Network connectivity", + DiagnosticLevel::Fail, + "DNS resolution returned no addresses for api.anthropic.com", + ); + } + }, + Err(error) => { + return DiagnosticCheck::new( + "Network connectivity", + DiagnosticLevel::Fail, + format!("failed to resolve api.anthropic.com: {error}"), + ); + } + }; + match TcpStream::connect_timeout(&address, Duration::from_secs(5)) { + Ok(stream) => { + let _ = stream.shutdown(std::net::Shutdown::Both); + DiagnosticCheck::new( + "Network connectivity", + DiagnosticLevel::Ok, + format!("connected to {address}"), + ) + } + Err(error) => DiagnosticCheck::new( + "Network connectivity", + DiagnosticLevel::Fail, + format!("failed to connect to {address}: {error}"), + ), + } +} + +fn check_system_info(cwd: &Path, config: Option<&runtime::RuntimeConfig>) -> DiagnosticCheck { + let mut details = vec![ + format!("os={} arch={}", env::consts::OS, env::consts::ARCH), + format!("cwd={}", cwd.display()), + format!("cli_version={VERSION}"), + format!("build_target={}", BUILD_TARGET.unwrap_or("")), + format!("git_sha={}", GIT_SHA.unwrap_or("")), + ]; + if let Some(config) = config { + details.push(format!( + "resolved_model={} loaded_config_files={}", + config.model().unwrap_or(DEFAULT_MODEL), + config.loaded_entries().len() + )); + } + DiagnosticCheck::new( + "System info", + DiagnosticLevel::Ok, + "captured local runtime and build metadata", + ) + .with_details(details) +} + fn print_system_prompt(cwd: PathBuf, date: String) { match load_system_prompt(cwd, date, env::consts::OS, "unknown") { Ok(sections) => println!("{}", sections.join("\n\n")), @@ -710,78 +1304,22 @@ fn format_permissions_switch_report(previous: &str, next: &str) -> String { ) } -fn format_cost_report(model: &str, usage: TokenUsage, max_cost_usd: Option) -> String { - let estimate = usage_cost_estimate(model, usage); +fn format_cost_report(usage: TokenUsage) -> String { format!( "Cost - Model {model} Input tokens {} Output tokens {} Cache create {} Cache read {} - Total tokens {} - Input cost {} - Output cost {} - Cache create usd {} - Cache read usd {} - Estimated cost {} - Budget {}", + Total tokens {}", usage.input_tokens, usage.output_tokens, usage.cache_creation_input_tokens, usage.cache_read_input_tokens, usage.total_tokens(), - format_usd(estimate.input_cost_usd), - format_usd(estimate.output_cost_usd), - format_usd(estimate.cache_creation_cost_usd), - format_usd(estimate.cache_read_cost_usd), - format_usd(estimate.total_cost_usd()), - format_budget_line(estimate.total_cost_usd(), max_cost_usd), ) } -fn usage_cost_estimate(model: &str, usage: TokenUsage) -> runtime::UsageCostEstimate { - pricing_for_model(model).map_or_else( - || usage.estimate_cost_usd(), - |pricing| usage.estimate_cost_usd_with_pricing(pricing), - ) -} - -fn usage_cost_total(model: &str, usage: TokenUsage) -> f64 { - usage_cost_estimate(model, usage).total_cost_usd() -} - -fn format_budget_line(cost_usd: f64, max_cost_usd: Option) -> String { - match max_cost_usd { - Some(limit) => format!("{} / {}", format_usd(cost_usd), format_usd(limit)), - None => format!("{} (unlimited)", format_usd(cost_usd)), - } -} - -fn budget_notice_message( - model: &str, - usage: TokenUsage, - max_cost_usd: Option, -) -> Option { - let limit = max_cost_usd?; - let cost = usage_cost_total(model, usage); - if cost >= limit { - Some(format!( - "cost budget exceeded: cumulative={} budget={}", - format_usd(cost), - format_usd(limit) - )) - } else if cost >= limit * COST_WARNING_FRACTION { - Some(format!( - "approaching cost budget: cumulative={} budget={}", - format_usd(cost), - format_usd(limit) - )) - } else { - None - } -} - fn format_resume_report(session_path: &str, message_count: usize, turns: u32) -> String { format!( "Session resumed @@ -925,7 +1463,6 @@ fn run_resume_command( }, default_permission_mode().as_str(), &status_context(Some(session_path))?, - None, )), }) } @@ -933,7 +1470,7 @@ fn run_resume_command( let usage = UsageTracker::from_session(session).cumulative_usage(); Ok(ResumeCommandOutcome { session: session.clone(), - message: Some(format_cost_report("restored-session", usage, None)), + message: Some(format_cost_report(usage)), }) } SlashCommand::Config { section } => Ok(ResumeCommandOutcome { @@ -980,9 +1517,8 @@ fn run_repl( model: String, allowed_tools: Option, permission_mode: PermissionMode, - max_cost_usd: Option, ) -> Result<(), Box> { - let mut cli = LiveCli::new(model, true, allowed_tools, permission_mode, max_cost_usd)?; + let mut cli = LiveCli::new(model, true, allowed_tools, permission_mode)?; let mut editor = input::LineEditor::new("› ", slash_command_completion_candidates()); println!("{}", cli.startup_banner()); @@ -1035,7 +1571,6 @@ struct LiveCli { model: String, allowed_tools: Option, permission_mode: PermissionMode, - max_cost_usd: Option, system_prompt: Vec, runtime: ConversationRuntime, session: SessionHandle, @@ -1047,7 +1582,6 @@ impl LiveCli { enable_tools: bool, allowed_tools: Option, permission_mode: PermissionMode, - max_cost_usd: Option, ) -> Result> { let system_prompt = build_system_prompt()?; let session = create_managed_session_handle()?; @@ -1063,7 +1597,6 @@ impl LiveCli { model, allowed_tools, permission_mode, - max_cost_usd, system_prompt, runtime, session, @@ -1074,10 +1607,9 @@ impl LiveCli { fn startup_banner(&self) -> String { format!( - "Rusty Claude CLI\n Model {}\n Permission mode {}\n Cost budget {}\n Working directory {}\n Session {}\n\nType /help for commands. Shift+Enter or Ctrl+J inserts a newline.", + "Rusty Claude CLI\n Model {}\n Permission mode {}\n Working directory {}\n Session {}\n\nType /help for commands. Shift+Enter or Ctrl+J inserts a newline.", self.model, self.permission_mode.as_str(), - self.max_cost_usd.map_or_else(|| "none".to_string(), format_usd), env::current_dir().map_or_else( |_| "".to_string(), |path| path.display().to_string(), @@ -1087,7 +1619,6 @@ impl LiveCli { } fn run_turn(&mut self, input: &str) -> Result<(), Box> { - self.enforce_budget_before_turn()?; let mut spinner = Spinner::new(); let mut stdout = io::stdout(); spinner.tick( @@ -1098,14 +1629,13 @@ impl LiveCli { let mut permission_prompter = CliPermissionPrompter::new(self.permission_mode); let result = self.runtime.run_turn(input, Some(&mut permission_prompter)); match result { - Ok(summary) => { + Ok(_) => { spinner.finish( "Claude response complete", TerminalRenderer::new().color_theme(), &mut stdout, )?; println!(); - self.print_budget_notice(summary.usage); self.persist_session()?; Ok(()) } @@ -1132,7 +1662,6 @@ impl LiveCli { } fn run_prompt_json(&mut self, input: &str) -> Result<(), Box> { - self.enforce_budget_before_turn()?; let client = AnthropicClient::from_auth(resolve_cli_auth_source()?); let request = MessageRequest { model: self.model.clone(), @@ -1159,27 +1688,17 @@ impl LiveCli { }) .collect::>() .join(""); - let usage = TokenUsage { - input_tokens: response.usage.input_tokens, - output_tokens: response.usage.output_tokens, - cache_creation_input_tokens: response.usage.cache_creation_input_tokens, - cache_read_input_tokens: response.usage.cache_read_input_tokens, - }; println!( "{}", json!({ "message": text, "model": self.model, "usage": { - "input_tokens": usage.input_tokens, - "output_tokens": usage.output_tokens, - "cache_creation_input_tokens": usage.cache_creation_input_tokens, - "cache_read_input_tokens": usage.cache_read_input_tokens, - }, - "cost_usd": usage_cost_total(&self.model, usage), - "cumulative_cost_usd": usage_cost_total(&self.model, usage), - "max_cost_usd": self.max_cost_usd, - "budget_warning": budget_notice_message(&self.model, usage, self.max_cost_usd), + "input_tokens": response.usage.input_tokens, + "output_tokens": response.usage.output_tokens, + "cache_creation_input_tokens": response.usage.cache_creation_input_tokens, + "cache_read_input_tokens": response.usage.cache_read_input_tokens, + } }) ); Ok(()) @@ -1249,28 +1768,6 @@ impl LiveCli { Ok(()) } - fn enforce_budget_before_turn(&self) -> Result<(), Box> { - let Some(limit) = self.max_cost_usd else { - return Ok(()); - }; - let cost = usage_cost_total(&self.model, self.runtime.usage().cumulative_usage()); - if cost >= limit { - return Err(format!( - "cost budget exceeded before starting turn: cumulative={} budget={}", - format_usd(cost), - format_usd(limit) - ) - .into()); - } - Ok(()) - } - - fn print_budget_notice(&self, usage: TokenUsage) { - if let Some(message) = budget_notice_message(&self.model, usage, self.max_cost_usd) { - eprintln!("warning: {message}"); - } - } - fn print_status(&self) { let cumulative = self.runtime.usage().cumulative_usage(); let latest = self.runtime.usage().current_turn_usage(); @@ -1287,7 +1784,6 @@ impl LiveCli { }, self.permission_mode.as_str(), &status_context(Some(&self.session.path)).expect("status context should load"), - self.max_cost_usd, ) ); } @@ -1405,10 +1901,7 @@ impl LiveCli { fn print_cost(&self) { let cumulative = self.runtime.usage().cumulative_usage(); - println!( - "{}", - format_cost_report(&self.model, cumulative, self.max_cost_usd) - ); + println!("{}", format_cost_report(cumulative)); } fn resume_session( @@ -1686,10 +2179,7 @@ fn format_status_report( usage: StatusUsage, permission_mode: &str, context: &StatusContext, - max_cost_usd: Option, ) -> String { - let latest_cost = usage_cost_total(model, usage.latest); - let cumulative_cost = usage_cost_total(model, usage.cumulative); [ format!( "Status @@ -1697,27 +2187,19 @@ fn format_status_report( Permission mode {permission_mode} Messages {} Turns {} - Estimated tokens {} - Cost budget {}", - usage.message_count, - usage.turns, - usage.estimated_tokens, - format_budget_line(cumulative_cost, max_cost_usd), + Estimated tokens {}", + usage.message_count, usage.turns, usage.estimated_tokens, ), format!( "Usage Latest total {} - Latest cost {} Cumulative input {} Cumulative output {} - Cumulative total {} - Cumulative cost {}", + Cumulative total {}", usage.latest.total_tokens(), - format_usd(latest_cost), usage.cumulative.input_tokens, usage.cumulative.output_tokens, usage.cumulative.total_tokens(), - format_usd(cumulative_cost), ), format!( "Workspace @@ -2489,9 +2971,9 @@ fn print_help() { println!("rusty-claude-cli v{VERSION}"); println!(); println!("Usage:"); - println!(" rusty-claude-cli [--model MODEL] [--max-cost USD] [--allowedTools TOOL[,TOOL...]]"); + println!(" rusty-claude-cli [--model MODEL] [--allowedTools TOOL[,TOOL...]]"); println!(" Start the interactive REPL"); - println!(" rusty-claude-cli [--model MODEL] [--max-cost USD] [--output-format text|json] prompt TEXT"); + println!(" rusty-claude-cli [--model MODEL] [--output-format text|json] prompt TEXT"); println!(" Send one prompt and exit"); println!(" rusty-claude-cli [--model MODEL] [--output-format text|json] TEXT"); println!(" Shorthand non-interactive prompt mode"); @@ -2502,12 +2984,12 @@ fn print_help() { println!(" rusty-claude-cli system-prompt [--cwd PATH] [--date YYYY-MM-DD]"); println!(" rusty-claude-cli login"); println!(" rusty-claude-cli logout"); + println!(" rusty-claude-cli doctor"); println!(); println!("Flags:"); println!(" --model MODEL Override the active model"); println!(" --output-format FORMAT Non-interactive output format: text or json"); println!(" --permission-mode MODE Set read-only, workspace-write, or danger-full-access"); - println!(" --max-cost USD Warn at 80% of budget and stop at/exceeding the budget"); println!(" --allowedTools TOOLS Restrict enabled tools (repeatable; comma-separated aliases supported)"); println!(" --version, -V Print version and build information locally"); println!(); @@ -2529,19 +3011,19 @@ fn print_help() { println!(" rusty-claude-cli --allowedTools read,glob \"summarize Cargo.toml\""); println!(" rusty-claude-cli --resume session.json /status /diff /export notes.txt"); println!(" rusty-claude-cli login"); + println!(" rusty-claude-cli doctor"); } #[cfg(test)] mod tests { use super::{ - budget_notice_message, filter_tool_specs, format_compact_report, format_cost_report, - format_init_report, format_model_report, format_model_switch_report, - format_permissions_report, format_permissions_switch_report, format_resume_report, - format_status_report, format_tool_call_start, format_tool_result, - normalize_permission_mode, parse_args, parse_git_status_metadata, render_config_report, - render_init_claude_md, render_memory_report, render_repl_help, - resume_supported_slash_commands, status_context, CliAction, CliOutputFormat, SlashCommand, - StatusUsage, DEFAULT_MODEL, + filter_tool_specs, format_compact_report, format_cost_report, format_init_report, + format_model_report, format_model_switch_report, format_permissions_report, + format_permissions_switch_report, format_resume_report, format_status_report, + format_tool_call_start, format_tool_result, normalize_permission_mode, parse_args, + parse_git_status_metadata, render_config_report, render_init_claude_md, + render_memory_report, render_repl_help, resume_supported_slash_commands, status_context, + CliAction, CliOutputFormat, SlashCommand, StatusUsage, DEFAULT_MODEL, }; use runtime::{ContentBlock, ConversationMessage, MessageRole, PermissionMode}; use std::path::{Path, PathBuf}; @@ -2554,7 +3036,6 @@ mod tests { model: DEFAULT_MODEL.to_string(), allowed_tools: None, permission_mode: PermissionMode::WorkspaceWrite, - max_cost_usd: None, } ); } @@ -2574,7 +3055,6 @@ mod tests { output_format: CliOutputFormat::Text, allowed_tools: None, permission_mode: PermissionMode::WorkspaceWrite, - max_cost_usd: None, } ); } @@ -2596,7 +3076,6 @@ mod tests { output_format: CliOutputFormat::Json, allowed_tools: None, permission_mode: PermissionMode::WorkspaceWrite, - max_cost_usd: None, } ); } @@ -2622,32 +3101,10 @@ mod tests { model: DEFAULT_MODEL.to_string(), allowed_tools: None, permission_mode: PermissionMode::ReadOnly, - max_cost_usd: None, } ); } - #[test] - fn parses_max_cost_flag() { - let args = vec!["--max-cost=1.25".to_string()]; - assert_eq!( - parse_args(&args).expect("args should parse"), - CliAction::Repl { - model: DEFAULT_MODEL.to_string(), - allowed_tools: None, - permission_mode: PermissionMode::WorkspaceWrite, - max_cost_usd: Some(1.25), - } - ); - } - - #[test] - fn rejects_invalid_max_cost_flag() { - let error = parse_args(&["--max-cost".to_string(), "0".to_string()]) - .expect_err("zero max cost should be rejected"); - assert!(error.contains("--max-cost must be a positive finite USD amount")); - } - #[test] fn parses_allowed_tools_flags_with_aliases_and_lists() { let args = vec![ @@ -2666,7 +3123,6 @@ mod tests { .collect() ), permission_mode: PermissionMode::WorkspaceWrite, - max_cost_usd: None, } ); } @@ -2697,7 +3153,7 @@ mod tests { } #[test] - fn parses_login_and_logout_subcommands() { + fn parses_login_logout_and_doctor_subcommands() { assert_eq!( parse_args(&["login".to_string()]).expect("login should parse"), CliAction::Login @@ -2706,6 +3162,10 @@ mod tests { parse_args(&["logout".to_string()]).expect("logout should parse"), CliAction::Logout ); + assert_eq!( + parse_args(&["doctor".to_string()]).expect("doctor should parse"), + CliAction::Doctor + ); } #[test] @@ -2824,24 +3284,18 @@ mod tests { #[test] fn cost_report_uses_sectioned_layout() { - let report = format_cost_report( - "claude-sonnet", - runtime::TokenUsage { - input_tokens: 20, - output_tokens: 8, - cache_creation_input_tokens: 3, - cache_read_input_tokens: 1, - }, - None, - ); + let report = format_cost_report(runtime::TokenUsage { + input_tokens: 20, + output_tokens: 8, + cache_creation_input_tokens: 3, + cache_read_input_tokens: 1, + }); assert!(report.contains("Cost")); assert!(report.contains("Input tokens 20")); assert!(report.contains("Output tokens 8")); assert!(report.contains("Cache create 3")); assert!(report.contains("Cache read 1")); assert!(report.contains("Total tokens 32")); - assert!(report.contains("Estimated cost")); - assert!(report.contains("Budget $0.0010 (unlimited)")); } #[test] @@ -2923,7 +3377,6 @@ mod tests { project_root: Some(PathBuf::from("/tmp")), git_branch: Some("main".to_string()), }, - Some(1.0), ); assert!(status.contains("Status")); assert!(status.contains("Model claude-sonnet")); @@ -2931,7 +3384,6 @@ mod tests { assert!(status.contains("Messages 7")); assert!(status.contains("Latest total 10")); assert!(status.contains("Cumulative total 31")); - assert!(status.contains("Cost budget $0.0009 / $1.0000")); assert!(status.contains("Cwd /tmp/project")); assert!(status.contains("Project root /tmp")); assert!(status.contains("Git branch main")); @@ -2940,22 +3392,6 @@ mod tests { assert!(status.contains("Memory files 4")); } - #[test] - fn budget_notice_warns_near_limit() { - let message = budget_notice_message( - "claude-sonnet", - runtime::TokenUsage { - input_tokens: 60_000, - output_tokens: 0, - cache_creation_input_tokens: 0, - cache_read_input_tokens: 0, - }, - Some(1.0), - ) - .expect("budget warning expected"); - assert!(message.contains("approaching cost budget")); - } - #[test] fn config_report_supports_section_views() { let report = render_config_report(Some("env")).expect("config report should render"); @@ -2993,8 +3429,8 @@ mod tests { fn status_context_reads_real_workspace_metadata() { let context = status_context(None).expect("status context should load"); assert!(context.cwd.is_absolute()); - assert!(context.discovered_config_files >= context.loaded_config_files); - assert!(context.discovered_config_files >= 1); + assert_eq!(context.discovered_config_files, 5); + assert!(context.loaded_config_files <= context.discovered_config_files); } #[test] @@ -3090,6 +3526,87 @@ mod tests { assert!(help.contains("Shift+Enter/Ctrl+J")); } + #[test] + fn oauth_status_classifies_missing_and_expired_tokens() { + let root = std::env::temp_dir().join(format!( + "doctor-oauth-status-{}", + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("time") + .as_nanos() + )); + std::fs::create_dir_all(&root).expect("temp dir"); + std::env::set_var("CLAUDE_CONFIG_HOME", &root); + + assert_eq!( + super::classify_oauth_status() + .expect("missing should classify") + .0, + super::OAuthDiagnosticStatus::Missing + ); + + runtime::save_oauth_credentials(&runtime::OAuthTokenSet { + access_token: "token".to_string(), + refresh_token: Some("refresh".to_string()), + expires_at: Some(1), + scopes: vec!["scope:a".to_string()], + }) + .expect("save oauth"); + assert_eq!( + super::classify_oauth_status() + .expect("expired should classify") + .0, + super::OAuthDiagnosticStatus::ExpiredRefreshable + ); + + runtime::clear_oauth_credentials().expect("clear oauth"); + std::fs::remove_dir_all(&root).expect("cleanup"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); + } + + #[test] + fn config_validation_flags_invalid_json() { + let root = std::env::temp_dir().join(format!( + "doctor-config-{}", + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("time") + .as_nanos() + )); + std::fs::create_dir_all(&root).expect("temp dir"); + let path = root.join("settings.json"); + std::fs::write(&path, "[]").expect("write invalid top-level"); + let check = super::validate_config_file(&path); + assert!(check.exists); + assert!(!check.valid); + assert!(check.note.contains("not an object")); + std::fs::remove_dir_all(&root).expect("cleanup"); + } + + #[test] + fn doctor_report_renders_requested_sections() { + let report = super::DoctorReport { + checks: vec![ + super::DiagnosticCheck::new( + "API key validity", + super::DiagnosticLevel::Ok, + "accepted", + ), + super::DiagnosticCheck::new( + "System info", + super::DiagnosticLevel::Warn, + "captured", + ) + .with_details(vec!["os=linux".to_string()]), + ], + }; + let rendered = report.render(); + assert!(rendered.contains("Doctor diagnostics")); + assert!(rendered.contains("API key validity")); + assert!(rendered.contains("System info")); + assert!(rendered.contains("Warnings 1")); + } + #[test] fn tool_rendering_helpers_compact_output() { let start = format_tool_call_start("read_file", r#"{"path":"src/main.rs"}"#);