From 0db9660727b6eeebd67c83370edb2f22bbab0d79 Mon Sep 17 00:00:00 2001 From: Yeachan-Heo Date: Wed, 1 Apr 2026 06:50:18 +0000 Subject: [PATCH] feat: plugin subsystem progress --- rust/Cargo.lock | 1 + rust/crates/commands/src/lib.rs | 143 +++++++- rust/crates/plugins/src/hooks.rs | 395 +++++++++++++++++++++ rust/crates/plugins/src/lib.rs | 432 +++++++++++++++++++---- rust/crates/runtime/src/conversation.rs | 68 +++- rust/crates/runtime/src/hooks.rs | 92 ----- rust/crates/rusty-claude-cli/src/main.rs | 141 ++++---- rust/crates/tools/Cargo.toml | 1 + rust/crates/tools/src/lib.rs | 191 ++++++++++ 9 files changed, 1189 insertions(+), 275 deletions(-) create mode 100644 rust/crates/plugins/src/hooks.rs diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 41e2d35..a182255 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1557,6 +1557,7 @@ name = "tools" version = "0.1.0" dependencies = [ "api", + "plugins", "reqwest", "runtime", "serde", diff --git a/rust/crates/commands/src/lib.rs b/rust/crates/commands/src/lib.rs index 8e7ef9d..3caa277 100644 --- a/rust/crates/commands/src/lib.rs +++ b/rust/crates/commands/src/lib.rs @@ -1,4 +1,4 @@ -use plugins::{PluginError, PluginManager, PluginSummary}; +use plugins::{PluginError, PluginKind, PluginManager, PluginSummary}; use runtime::{compact_session, CompactionConfig, Session}; #[derive(Debug, Clone, PartialEq, Eq)] @@ -370,7 +370,7 @@ pub fn handle_plugins_slash_command( ) -> Result { match action { None | Some("list") => Ok(PluginsCommandResult { - message: render_plugins_report(&manager.list_plugins()?), + message: render_plugins_report(&manager.list_installed_plugins()?), reload_runtime: false, }), Some("install") => { @@ -382,7 +382,7 @@ pub fn handle_plugins_slash_command( }; let install = manager.install(target)?; let plugin = manager - .list_plugins()? + .list_installed_plugins()? .into_iter() .find(|plugin| plugin.metadata.id == install.plugin_id); Ok(PluginsCommandResult { @@ -393,14 +393,16 @@ pub fn handle_plugins_slash_command( Some("enable") => { let Some(target) = target else { return Ok(PluginsCommandResult { - message: "Usage: /plugins enable ".to_string(), + message: "Usage: /plugins enable ".to_string(), reload_runtime: false, }); }; - manager.enable(target)?; + let plugin = resolve_plugin_target(manager, target)?; + manager.enable(&plugin.metadata.id)?; Ok(PluginsCommandResult { message: format!( - "Plugins\n Result enabled {target}\n Status enabled" + "Plugins\n Result enabled {}\n Name {}\n Version {}\n Status enabled", + plugin.metadata.id, plugin.metadata.name, plugin.metadata.version ), reload_runtime: true, }) @@ -408,14 +410,16 @@ pub fn handle_plugins_slash_command( Some("disable") => { let Some(target) = target else { return Ok(PluginsCommandResult { - message: "Usage: /plugins disable ".to_string(), + message: "Usage: /plugins disable ".to_string(), reload_runtime: false, }); }; - manager.disable(target)?; + let plugin = resolve_plugin_target(manager, target)?; + manager.disable(&plugin.metadata.id)?; Ok(PluginsCommandResult { message: format!( - "Plugins\n Result disabled {target}\n Status disabled" + "Plugins\n Result disabled {}\n Name {}\n Version {}\n Status disabled", + plugin.metadata.id, plugin.metadata.name, plugin.metadata.version ), reload_runtime: true, }) @@ -442,7 +446,7 @@ pub fn handle_plugins_slash_command( }; let update = manager.update(target)?; let plugin = manager - .list_plugins()? + .list_installed_plugins()? .into_iter() .find(|plugin| plugin.metadata.id == update.plugin_id); Ok(PluginsCommandResult { @@ -474,18 +478,23 @@ pub fn handle_plugins_slash_command( pub fn render_plugins_report(plugins: &[PluginSummary]) -> String { let mut lines = vec!["Plugins".to_string()]; if plugins.is_empty() { - lines.push(" No plugins discovered.".to_string()); + lines.push(" No plugins installed.".to_string()); return lines.join("\n"); } for plugin in plugins { + let kind = match plugin.metadata.kind { + PluginKind::Builtin => "builtin", + PluginKind::Bundled => "bundled", + PluginKind::External => "external", + }; let enabled = if plugin.enabled { "enabled" } else { "disabled" }; lines.push(format!( - " {name:<20} v{version:<10} {enabled}", - name = plugin.metadata.name, + " {id:<24} {kind:<8} {enabled:<8} v{version}", + id = plugin.metadata.id, version = plugin.metadata.version, )); } @@ -502,6 +511,26 @@ fn render_plugin_install_report(plugin_id: &str, plugin: Option<&PluginSummary>) ) } +fn resolve_plugin_target( + manager: &PluginManager, + target: &str, +) -> Result { + let mut matches = manager + .list_installed_plugins()? + .into_iter() + .filter(|plugin| plugin.metadata.id == target || plugin.metadata.name == target) + .collect::>(); + match matches.len() { + 1 => Ok(matches.remove(0)), + 0 => Err(PluginError::NotFound(format!( + "plugin `{target}` is not installed or discoverable" + ))), + _ => Err(PluginError::InvalidManifest(format!( + "plugin name `{target}` is ambiguous; use the full plugin id" + ))), + } +} + #[must_use] pub fn handle_slash_command( input: &str, @@ -560,7 +589,7 @@ mod tests { render_slash_command_help, resume_supported_slash_commands, slash_command_specs, SlashCommand, }; - use plugins::{PluginManager, PluginManagerConfig, PluginMetadata, PluginSummary}; + use plugins::{PluginKind, PluginManager, PluginManagerConfig, PluginMetadata, PluginSummary}; use runtime::{CompactionConfig, ContentBlock, ConversationMessage, MessageRole, Session}; use std::fs; use std::path::{Path, PathBuf}; @@ -585,6 +614,18 @@ mod tests { .expect("write manifest"); } + fn write_bundled_plugin(root: &Path, name: &str, version: &str, default_enabled: bool) { + fs::create_dir_all(root.join(".claude-plugin")).expect("manifest dir"); + fs::write( + root.join(".claude-plugin").join("plugin.json"), + format!( + "{{\n \"name\": \"{name}\",\n \"version\": \"{version}\",\n \"description\": \"bundled commands plugin\",\n \"defaultEnabled\": {}\n}}", + if default_enabled { "true" } else { "false" } + ), + ) + .expect("write bundled manifest"); + } + #[test] fn parses_supported_slash_commands() { assert_eq!(SlashCommand::parse("/help"), Some(SlashCommand::Help)); @@ -839,7 +880,7 @@ mod tests { name: "demo".to_string(), version: "1.2.3".to_string(), description: "demo plugin".to_string(), - kind: plugins::PluginKind::External, + kind: PluginKind::External, source: "demo".to_string(), default_enabled: false, root: None, @@ -852,7 +893,7 @@ mod tests { name: "sample".to_string(), version: "0.9.0".to_string(), description: "sample plugin".to_string(), - kind: plugins::PluginKind::External, + kind: PluginKind::External, source: "sample".to_string(), default_enabled: false, root: None, @@ -861,10 +902,10 @@ mod tests { }, ]); - assert!(rendered.contains("demo")); + assert!(rendered.contains("demo@external")); assert!(rendered.contains("v1.2.3")); assert!(rendered.contains("enabled")); - assert!(rendered.contains("sample")); + assert!(rendered.contains("sample@external")); assert!(rendered.contains("v0.9.0")); assert!(rendered.contains("disabled")); } @@ -891,11 +932,75 @@ mod tests { let list = handle_plugins_slash_command(Some("list"), None, &mut manager) .expect("list command should succeed"); assert!(!list.reload_runtime); - assert!(list.message.contains("demo")); + assert!(list.message.contains("demo@external")); assert!(list.message.contains("v1.0.0")); assert!(list.message.contains("enabled")); let _ = fs::remove_dir_all(config_home); let _ = fs::remove_dir_all(source_root); } + + #[test] + fn enables_and_disables_plugin_by_name() { + let config_home = temp_dir("toggle-home"); + let source_root = temp_dir("toggle-source"); + write_external_plugin(&source_root, "demo", "1.0.0"); + + let mut manager = PluginManager::new(PluginManagerConfig::new(&config_home)); + handle_plugins_slash_command( + Some("install"), + Some(source_root.to_str().expect("utf8 path")), + &mut manager, + ) + .expect("install command should succeed"); + + let disable = handle_plugins_slash_command(Some("disable"), Some("demo"), &mut manager) + .expect("disable command should succeed"); + assert!(disable.reload_runtime); + assert!(disable.message.contains("disabled demo@external")); + assert!(disable.message.contains("Name demo")); + assert!(disable.message.contains("Status disabled")); + + let list = handle_plugins_slash_command(Some("list"), None, &mut manager) + .expect("list command should succeed"); + assert!(list.message.contains("demo@external")); + assert!(list.message.contains("disabled")); + + let enable = handle_plugins_slash_command(Some("enable"), Some("demo"), &mut manager) + .expect("enable command should succeed"); + assert!(enable.reload_runtime); + assert!(enable.message.contains("enabled demo@external")); + assert!(enable.message.contains("Name demo")); + assert!(enable.message.contains("Status enabled")); + + let list = handle_plugins_slash_command(Some("list"), None, &mut manager) + .expect("list command should succeed"); + assert!(list.message.contains("demo@external")); + assert!(list.message.contains("enabled")); + + let _ = fs::remove_dir_all(config_home); + let _ = fs::remove_dir_all(source_root); + } + + #[test] + fn lists_auto_installed_bundled_plugins_with_status() { + let config_home = temp_dir("bundled-home"); + let bundled_root = temp_dir("bundled-root"); + let bundled_plugin = bundled_root.join("starter"); + write_bundled_plugin(&bundled_plugin, "starter", "0.1.0", false); + + let mut config = PluginManagerConfig::new(&config_home); + config.bundled_root = Some(bundled_root.clone()); + let mut manager = PluginManager::new(config); + + let list = handle_plugins_slash_command(Some("list"), None, &mut manager) + .expect("list command should succeed"); + assert!(!list.reload_runtime); + assert!(list.message.contains("starter@bundled")); + assert!(list.message.contains("bundled")); + assert!(list.message.contains("disabled")); + + let _ = fs::remove_dir_all(config_home); + let _ = fs::remove_dir_all(bundled_root); + } } diff --git a/rust/crates/plugins/src/hooks.rs b/rust/crates/plugins/src/hooks.rs new file mode 100644 index 0000000..feeb762 --- /dev/null +++ b/rust/crates/plugins/src/hooks.rs @@ -0,0 +1,395 @@ +use std::ffi::OsStr; +use std::path::Path; +use std::process::Command; + +use serde_json::json; + +use crate::{PluginError, PluginHooks, PluginRegistry}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum HookEvent { + PreToolUse, + PostToolUse, +} + +impl HookEvent { + fn as_str(self) -> &'static str { + match self { + Self::PreToolUse => "PreToolUse", + Self::PostToolUse => "PostToolUse", + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct HookRunResult { + denied: bool, + messages: Vec, +} + +impl HookRunResult { + #[must_use] + pub fn allow(messages: Vec) -> Self { + Self { + denied: false, + messages, + } + } + + #[must_use] + pub fn is_denied(&self) -> bool { + self.denied + } + + #[must_use] + pub fn messages(&self) -> &[String] { + &self.messages + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct HookRunner { + hooks: PluginHooks, +} + +impl HookRunner { + #[must_use] + pub fn new(hooks: PluginHooks) -> Self { + Self { hooks } + } + + pub fn from_registry(plugin_registry: &PluginRegistry) -> Result { + Ok(Self::new(plugin_registry.aggregated_hooks()?)) + } + + #[must_use] + pub fn run_pre_tool_use(&self, tool_name: &str, tool_input: &str) -> HookRunResult { + self.run_commands( + HookEvent::PreToolUse, + &self.hooks.pre_tool_use, + tool_name, + tool_input, + None, + false, + ) + } + + #[must_use] + pub fn run_post_tool_use( + &self, + tool_name: &str, + tool_input: &str, + tool_output: &str, + is_error: bool, + ) -> HookRunResult { + self.run_commands( + HookEvent::PostToolUse, + &self.hooks.post_tool_use, + tool_name, + tool_input, + Some(tool_output), + is_error, + ) + } + + fn run_commands( + &self, + event: HookEvent, + commands: &[String], + tool_name: &str, + tool_input: &str, + tool_output: Option<&str>, + is_error: bool, + ) -> HookRunResult { + if commands.is_empty() { + return HookRunResult::allow(Vec::new()); + } + + let payload = json!({ + "hook_event_name": event.as_str(), + "tool_name": tool_name, + "tool_input": parse_tool_input(tool_input), + "tool_input_json": tool_input, + "tool_output": tool_output, + "tool_result_is_error": is_error, + }) + .to_string(); + + let mut messages = Vec::new(); + + for command in commands { + match self.run_command( + command, + event, + tool_name, + tool_input, + tool_output, + is_error, + &payload, + ) { + HookCommandOutcome::Allow { message } => { + if let Some(message) = message { + messages.push(message); + } + } + HookCommandOutcome::Deny { message } => { + messages.push(message.unwrap_or_else(|| { + format!("{} hook denied tool `{tool_name}`", event.as_str()) + })); + return HookRunResult { + denied: true, + messages, + }; + } + HookCommandOutcome::Warn { message } => messages.push(message), + } + } + + HookRunResult::allow(messages) + } + + #[allow(clippy::too_many_arguments)] + fn run_command( + &self, + command: &str, + event: HookEvent, + tool_name: &str, + tool_input: &str, + tool_output: Option<&str>, + is_error: bool, + payload: &str, + ) -> HookCommandOutcome { + let mut child = shell_command(command); + child.stdin(std::process::Stdio::piped()); + child.stdout(std::process::Stdio::piped()); + child.stderr(std::process::Stdio::piped()); + child.env("HOOK_EVENT", event.as_str()); + child.env("HOOK_TOOL_NAME", tool_name); + child.env("HOOK_TOOL_INPUT", tool_input); + child.env("HOOK_TOOL_IS_ERROR", if is_error { "1" } else { "0" }); + if let Some(tool_output) = tool_output { + child.env("HOOK_TOOL_OUTPUT", tool_output); + } + + match child.output_with_stdin(payload.as_bytes()) { + Ok(output) => { + let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string(); + let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string(); + let message = (!stdout.is_empty()).then_some(stdout); + match output.status.code() { + Some(0) => HookCommandOutcome::Allow { message }, + Some(2) => HookCommandOutcome::Deny { message }, + Some(code) => HookCommandOutcome::Warn { + message: format_hook_warning( + command, + code, + message.as_deref(), + stderr.as_str(), + ), + }, + None => HookCommandOutcome::Warn { + message: format!( + "{} hook `{command}` terminated by signal while handling `{tool_name}`", + event.as_str() + ), + }, + } + } + Err(error) => HookCommandOutcome::Warn { + message: format!( + "{} hook `{command}` failed to start for `{tool_name}`: {error}", + event.as_str() + ), + }, + } + } +} + +enum HookCommandOutcome { + Allow { message: Option }, + Deny { message: Option }, + Warn { message: String }, +} + +fn parse_tool_input(tool_input: &str) -> serde_json::Value { + serde_json::from_str(tool_input).unwrap_or_else(|_| json!({ "raw": tool_input })) +} + +fn format_hook_warning(command: &str, code: i32, stdout: Option<&str>, stderr: &str) -> String { + let mut message = + format!("Hook `{command}` exited with status {code}; allowing tool execution to continue"); + if let Some(stdout) = stdout.filter(|stdout| !stdout.is_empty()) { + message.push_str(": "); + message.push_str(stdout); + } else if !stderr.is_empty() { + message.push_str(": "); + message.push_str(stderr); + } + message +} + +fn shell_command(command: &str) -> CommandWithStdin { + #[cfg(windows)] + let command_builder = { + let mut command_builder = Command::new("cmd"); + command_builder.arg("/C").arg(command); + CommandWithStdin::new(command_builder) + }; + + #[cfg(not(windows))] + let command_builder = if Path::new(command).exists() { + let mut command_builder = Command::new("sh"); + command_builder.arg(command); + CommandWithStdin::new(command_builder) + } else { + let mut command_builder = Command::new("sh"); + command_builder.arg("-lc").arg(command); + CommandWithStdin::new(command_builder) + }; + + command_builder +} + +struct CommandWithStdin { + command: Command, +} + +impl CommandWithStdin { + fn new(command: Command) -> Self { + Self { command } + } + + fn stdin(&mut self, cfg: std::process::Stdio) -> &mut Self { + self.command.stdin(cfg); + self + } + + fn stdout(&mut self, cfg: std::process::Stdio) -> &mut Self { + self.command.stdout(cfg); + self + } + + fn stderr(&mut self, cfg: std::process::Stdio) -> &mut Self { + self.command.stderr(cfg); + self + } + + fn env(&mut self, key: K, value: V) -> &mut Self + where + K: AsRef, + V: AsRef, + { + self.command.env(key, value); + self + } + + fn output_with_stdin(&mut self, stdin: &[u8]) -> std::io::Result { + let mut child = self.command.spawn()?; + if let Some(mut child_stdin) = child.stdin.take() { + use std::io::Write as _; + child_stdin.write_all(stdin)?; + } + child.wait_with_output() + } +} + +#[cfg(test)] +mod tests { + use super::{HookRunResult, HookRunner}; + use crate::{PluginManager, PluginManagerConfig}; + use std::fs; + use std::path::{Path, PathBuf}; + use std::time::{SystemTime, UNIX_EPOCH}; + + fn temp_dir(label: &str) -> PathBuf { + let nanos = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time should be after epoch") + .as_nanos(); + std::env::temp_dir().join(format!("plugins-hook-runner-{label}-{nanos}")) + } + + fn write_hook_plugin(root: &Path, name: &str, pre_message: &str, post_message: &str) { + fs::create_dir_all(root.join(".claude-plugin")).expect("manifest dir"); + fs::create_dir_all(root.join("hooks")).expect("hooks dir"); + fs::write( + root.join("hooks").join("pre.sh"), + format!("#!/bin/sh\nprintf '%s\\n' '{pre_message}'\n"), + ) + .expect("write pre hook"); + fs::write( + root.join("hooks").join("post.sh"), + format!("#!/bin/sh\nprintf '%s\\n' '{post_message}'\n"), + ) + .expect("write post hook"); + fs::write( + root.join(".claude-plugin").join("plugin.json"), + format!( + "{{\n \"name\": \"{name}\",\n \"version\": \"1.0.0\",\n \"description\": \"hook plugin\",\n \"hooks\": {{\n \"PreToolUse\": [\"./hooks/pre.sh\"],\n \"PostToolUse\": [\"./hooks/post.sh\"]\n }}\n}}" + ), + ) + .expect("write plugin manifest"); + } + + #[test] + fn collects_and_runs_hooks_from_enabled_plugins() { + let config_home = temp_dir("config"); + let first_source_root = temp_dir("source-a"); + let second_source_root = temp_dir("source-b"); + write_hook_plugin( + &first_source_root, + "first", + "plugin pre one", + "plugin post one", + ); + write_hook_plugin( + &second_source_root, + "second", + "plugin pre two", + "plugin post two", + ); + + let mut manager = PluginManager::new(PluginManagerConfig::new(&config_home)); + manager + .install(first_source_root.to_str().expect("utf8 path")) + .expect("first plugin install should succeed"); + manager + .install(second_source_root.to_str().expect("utf8 path")) + .expect("second plugin install should succeed"); + let registry = manager.plugin_registry().expect("registry should build"); + + let runner = HookRunner::from_registry(®istry).expect("plugin hooks should load"); + + assert_eq!( + runner.run_pre_tool_use("Read", r#"{"path":"README.md"}"#), + HookRunResult::allow(vec![ + "plugin pre one".to_string(), + "plugin pre two".to_string(), + ]) + ); + assert_eq!( + runner.run_post_tool_use("Read", r#"{"path":"README.md"}"#, "ok", false), + HookRunResult::allow(vec![ + "plugin post one".to_string(), + "plugin post two".to_string(), + ]) + ); + + let _ = fs::remove_dir_all(config_home); + let _ = fs::remove_dir_all(first_source_root); + let _ = fs::remove_dir_all(second_source_root); + } + + #[test] + fn pre_tool_use_denies_when_plugin_hook_exits_two() { + let runner = HookRunner::new(crate::PluginHooks { + pre_tool_use: vec!["printf 'blocked by plugin'; exit 2".to_string()], + post_tool_use: Vec::new(), + }); + + let result = runner.run_pre_tool_use("Bash", r#"{"command":"pwd"}"#); + + assert!(result.is_denied()); + assert_eq!(result.messages(), &["blocked by plugin".to_string()]); + } +} diff --git a/rust/crates/plugins/src/lib.rs b/rust/crates/plugins/src/lib.rs index e539add..68ba2c4 100644 --- a/rust/crates/plugins/src/lib.rs +++ b/rust/crates/plugins/src/lib.rs @@ -1,3 +1,5 @@ +mod hooks; + use std::collections::BTreeMap; use std::fmt::{Display, Formatter}; use std::fs; @@ -8,6 +10,8 @@ use std::time::{SystemTime, UNIX_EPOCH}; use serde::{Deserialize, Serialize}; use serde_json::{Map, Value}; +pub use hooks::{HookEvent, HookRunResult, HookRunner}; + const EXTERNAL_MARKETPLACE: &str = "external"; const BUILTIN_MARKETPLACE: &str = "builtin"; const BUNDLED_MARKETPLACE: &str = "bundled"; @@ -15,7 +19,6 @@ const SETTINGS_FILE_NAME: &str = "settings.json"; const REGISTRY_FILE_NAME: &str = "installed.json"; const MANIFEST_FILE_NAME: &str = "plugin.json"; const MANIFEST_RELATIVE_PATH: &str = ".claude-plugin/plugin.json"; -const PACKAGE_MANIFEST_RELATIVE_PATH: &str = MANIFEST_RELATIVE_PATH; #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "lowercase")] @@ -35,6 +38,17 @@ impl Display for PluginKind { } } +impl PluginKind { + #[must_use] + fn marketplace(self) -> &'static str { + match self { + Self::Builtin => BUILTIN_MARKETPLACE, + Self::Bundled => BUNDLED_MARKETPLACE, + Self::External => EXTERNAL_MARKETPLACE, + } + } +} + #[derive(Debug, Clone, PartialEq, Eq)] pub struct PluginMetadata { pub id: String, @@ -244,6 +258,8 @@ pub enum PluginInstallSource { #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct InstalledPluginRecord { + #[serde(default = "default_plugin_kind")] + pub kind: PluginKind, pub id: String, pub name: String, pub version: String, @@ -260,6 +276,10 @@ pub struct InstalledPluginRegistry { pub plugins: BTreeMap, } +fn default_plugin_kind() -> PluginKind { + PluginKind::External +} + #[derive(Debug, Clone, PartialEq)] pub struct BuiltinPlugin { metadata: PluginMetadata, @@ -750,10 +770,15 @@ impl PluginManager { Ok(self.plugin_registry()?.summaries()) } + pub fn list_installed_plugins(&self) -> Result, PluginError> { + Ok(self.installed_plugin_registry()?.summaries()) + } + pub fn discover_plugins(&self) -> Result, PluginError> { + self.sync_bundled_plugins()?; let mut plugins = builtin_plugins(); - plugins.extend(self.discover_bundled_plugins()?); - plugins.extend(self.discover_external_plugins()?); + plugins.extend(self.discover_installed_plugins()?); + plugins.extend(self.discover_external_directory_plugins(&plugins)?); Ok(plugins) } @@ -761,6 +786,10 @@ impl PluginManager { self.plugin_registry()?.aggregated_hooks() } + pub fn aggregated_tools(&self) -> Result, PluginError> { + self.plugin_registry()?.aggregated_tools() + } + pub fn validate_plugin_source(&self, source: &str) -> Result { let path = resolve_local_source(source)?; load_plugin_from_directory(&path) @@ -785,6 +814,7 @@ impl PluginManager { let now = unix_time_ms(); let record = InstalledPluginRecord { + kind: PluginKind::External, id: plugin_id.clone(), name: manifest.name, version: manifest.version.clone(), @@ -831,6 +861,12 @@ impl PluginManager { let record = registry.plugins.remove(plugin_id).ok_or_else(|| { PluginError::NotFound(format!("plugin `{plugin_id}` is not installed")) })?; + if record.kind == PluginKind::Bundled { + registry.plugins.insert(plugin_id.to_string(), record); + return Err(PluginError::CommandFailed(format!( + "plugin `{plugin_id}` is bundled and managed automatically; disable it instead" + ))); + } if record.install_path.exists() { fs::remove_dir_all(&record.install_path)?; } @@ -878,40 +914,27 @@ impl PluginManager { }) } - fn discover_bundled_plugins(&self) -> Result, PluginError> { - discover_plugin_dirs( - &self - .config - .bundled_root - .clone() - .unwrap_or_else(Self::bundled_root), - )? - .into_iter() - .map(|root| { - load_plugin_definition( - &root, - PluginKind::Bundled, - format!("{BUNDLED_MARKETPLACE}:{}", root.display()), - BUNDLED_MARKETPLACE, - ) - }) - .collect() - } - - fn discover_external_plugins(&self) -> Result, PluginError> { + fn discover_installed_plugins(&self) -> Result, PluginError> { let registry = self.load_registry()?; - let mut plugins = registry + registry .plugins .values() .map(|record| { load_plugin_definition( &record.install_path, - PluginKind::External, + record.kind, describe_install_source(&record.source), - EXTERNAL_MARKETPLACE, + record.kind.marketplace(), ) }) - .collect::, _>>()?; + .collect() + } + + fn discover_external_directory_plugins( + &self, + existing_plugins: &[PluginDefinition], + ) -> Result, PluginError> { + let mut plugins = Vec::new(); for directory in &self.config.external_dirs { for root in discover_plugin_dirs(directory)? { @@ -921,8 +944,9 @@ impl PluginManager { root.display().to_string(), EXTERNAL_MARKETPLACE, )?; - if plugins + if existing_plugins .iter() + .chain(plugins.iter()) .all(|existing| existing.metadata().id != plugin.metadata().id) { plugins.push(plugin); @@ -933,6 +957,84 @@ impl PluginManager { Ok(plugins) } + fn installed_plugin_registry(&self) -> Result { + self.sync_bundled_plugins()?; + Ok(PluginRegistry::new( + self.discover_installed_plugins()? + .into_iter() + .map(|plugin| { + let enabled = self.is_enabled(plugin.metadata()); + RegisteredPlugin::new(plugin, enabled) + }) + .collect(), + )) + } + + fn sync_bundled_plugins(&self) -> Result<(), PluginError> { + let bundled_root = self + .config + .bundled_root + .clone() + .unwrap_or_else(Self::bundled_root); + let bundled_plugins = discover_plugin_dirs(&bundled_root)?; + if bundled_plugins.is_empty() { + return Ok(()); + } + + let mut registry = self.load_registry()?; + let mut changed = false; + let install_root = self.install_root(); + + for source_root in bundled_plugins { + let manifest = load_validated_package_manifest_from_root(&source_root)?; + let plugin_id = plugin_id(&manifest.name, BUNDLED_MARKETPLACE); + let install_path = install_root.join(sanitize_plugin_id(&plugin_id)); + let now = unix_time_ms(); + let existing_record = registry.plugins.get(&plugin_id); + let needs_sync = existing_record.map_or(true, |record| { + record.kind != PluginKind::Bundled + || record.version != manifest.version + || record.name != manifest.name + || record.description != manifest.description + || record.install_path != install_path + || !record.install_path.exists() + }); + + if !needs_sync { + continue; + } + + if install_path.exists() { + fs::remove_dir_all(&install_path)?; + } + copy_dir_all(&source_root, &install_path)?; + + let installed_at_unix_ms = + existing_record.map_or(now, |record| record.installed_at_unix_ms); + registry.plugins.insert( + plugin_id.clone(), + InstalledPluginRecord { + kind: PluginKind::Bundled, + id: plugin_id, + name: manifest.name, + version: manifest.version, + description: manifest.description, + install_path, + source: PluginInstallSource::LocalPath { path: source_root }, + installed_at_unix_ms, + updated_at_unix_ms: now, + }, + ); + changed = true; + } + + if changed { + self.store_registry(®istry)?; + } + + Ok(()) + } + fn is_enabled(&self, metadata: &PluginMetadata) -> bool { self.config .enabled_plugins @@ -1089,11 +1191,15 @@ fn validate_plugin_manifest(root: &Path, manifest: &PluginManifest) -> Result<() validate_named_strings(&manifest.permissions, "permission")?; validate_hook_paths(Some(root), &manifest.hooks)?; validate_named_commands(root, &manifest.tools, "tool")?; + validate_tool_manifest_entries(&manifest.tools)?; validate_named_commands(root, &manifest.commands, "command")?; Ok(()) } -fn validate_package_manifest(root: &Path, manifest: &PluginPackageManifest) -> Result<(), PluginError> { +fn validate_package_manifest( + root: &Path, + manifest: &PluginPackageManifest, +) -> Result<(), PluginError> { if manifest.name.trim().is_empty() { return Err(PluginError::InvalidManifest( "plugin manifest name cannot be empty".to_string(), @@ -1110,6 +1216,7 @@ fn validate_package_manifest(root: &Path, manifest: &PluginPackageManifest) -> R )); } validate_named_commands(root, &manifest.tools, "tool")?; + validate_tool_manifest_entries(&manifest.tools)?; Ok(()) } @@ -1204,6 +1311,27 @@ fn validate_named_commands( Ok(()) } +fn validate_tool_manifest_entries(entries: &[PluginToolManifest]) -> Result<(), PluginError> { + for entry in entries { + if !entry.input_schema.is_object() { + return Err(PluginError::InvalidManifest(format!( + "plugin tool `{}` inputSchema must be a JSON object", + entry.name + ))); + } + if !matches!( + entry.required_permission.as_str(), + "read-only" | "workspace-write" | "danger-full-access" + ) { + return Err(PluginError::InvalidManifest(format!( + "plugin tool `{}` requiredPermission must be read-only, workspace-write, or danger-full-access", + entry.name + ))); + } + } + Ok(()) +} + trait NamedCommand { fn name(&self) -> &str; fn description(&self) -> &str; @@ -1568,75 +1696,225 @@ mod tests { std::env::temp_dir().join(format!("plugins-{label}-{}", unix_time_ms())) } - fn write_external_plugin(root: &Path, name: &str, version: &str) { - fs::create_dir_all(root.join(".claude-plugin")).expect("manifest dir"); - fs::create_dir_all(root.join("hooks")).expect("hooks dir"); - fs::write( - root.join("hooks").join("pre.sh"), + fn write_file(path: &Path, contents: &str) { + if let Some(parent) = path.parent() { + fs::create_dir_all(parent).expect("parent dir"); + } + fs::write(path, contents).expect("write file"); + } + + fn write_loader_plugin(root: &Path) { + write_file( + root.join("hooks").join("pre.sh").as_path(), "#!/bin/sh\nprintf 'pre'\n", - ) - .expect("write pre hook"); - fs::write( - root.join("hooks").join("post.sh"), + ); + write_file( + root.join("tools").join("echo-tool.sh").as_path(), + "#!/bin/sh\ncat\n", + ); + write_file( + root.join("commands").join("sync.sh").as_path(), + "#!/bin/sh\nprintf 'sync'\n", + ); + write_file( + root.join(MANIFEST_FILE_NAME).as_path(), + r#"{ + "name": "loader-demo", + "version": "1.2.3", + "description": "Manifest loader test plugin", + "permissions": ["read", "write"], + "hooks": { + "PreToolUse": ["./hooks/pre.sh"] + }, + "tools": [ + { + "name": "echo_tool", + "description": "Echoes JSON input", + "inputSchema": { + "type": "object" + }, + "command": "./tools/echo-tool.sh", + "requiredPermission": "workspace-write" + } + ], + "commands": [ + { + "name": "sync", + "description": "Sync command", + "command": "./commands/sync.sh" + } + ] +}"#, + ); + } + + fn write_external_plugin(root: &Path, name: &str, version: &str) { + write_file( + root.join("hooks").join("pre.sh").as_path(), + "#!/bin/sh\nprintf 'pre'\n", + ); + write_file( + root.join("hooks").join("post.sh").as_path(), "#!/bin/sh\nprintf 'post'\n", - ) - .expect("write post hook"); - fs::write( - root.join(MANIFEST_RELATIVE_PATH), + ); + write_file( + root.join(MANIFEST_RELATIVE_PATH).as_path(), format!( "{{\n \"name\": \"{name}\",\n \"version\": \"{version}\",\n \"description\": \"test plugin\",\n \"hooks\": {{\n \"PreToolUse\": [\"./hooks/pre.sh\"],\n \"PostToolUse\": [\"./hooks/post.sh\"]\n }}\n}}" - ), - ) - .expect("write manifest"); + ) + .as_str(), + ); } fn write_broken_plugin(root: &Path, name: &str) { - fs::create_dir_all(root.join(".claude-plugin")).expect("manifest dir"); - fs::write( - root.join(MANIFEST_RELATIVE_PATH), + write_file( + root.join(MANIFEST_RELATIVE_PATH).as_path(), format!( "{{\n \"name\": \"{name}\",\n \"version\": \"1.0.0\",\n \"description\": \"broken plugin\",\n \"hooks\": {{\n \"PreToolUse\": [\"./hooks/missing.sh\"]\n }}\n}}" - ), - ) - .expect("write broken manifest"); + ) + .as_str(), + ); } fn write_lifecycle_plugin(root: &Path, name: &str, version: &str) -> PathBuf { - fs::create_dir_all(root.join(".claude-plugin")).expect("manifest dir"); - fs::create_dir_all(root.join("lifecycle")).expect("lifecycle dir"); let log_path = root.join("lifecycle.log"); - fs::write( - root.join("lifecycle").join("init.sh"), + write_file( + root.join("lifecycle").join("init.sh").as_path(), "#!/bin/sh\nprintf 'init\\n' >> lifecycle.log\n", - ) - .expect("write init hook"); - fs::write( - root.join("lifecycle").join("shutdown.sh"), + ); + write_file( + root.join("lifecycle").join("shutdown.sh").as_path(), "#!/bin/sh\nprintf 'shutdown\\n' >> lifecycle.log\n", - ) - .expect("write shutdown hook"); - fs::write( - root.join(MANIFEST_RELATIVE_PATH), + ); + write_file( + root.join(MANIFEST_RELATIVE_PATH).as_path(), format!( "{{\n \"name\": \"{name}\",\n \"version\": \"{version}\",\n \"description\": \"lifecycle plugin\",\n \"lifecycle\": {{\n \"Init\": [\"./lifecycle/init.sh\"],\n \"Shutdown\": [\"./lifecycle/shutdown.sh\"]\n }}\n}}" - ), - ) - .expect("write manifest"); + ) + .as_str(), + ); log_path } #[test] - fn validates_manifest_shape() { - let error = validate_manifest(&PluginManifest { - name: String::new(), - version: "1.0.0".to_string(), - description: "desc".to_string(), - default_enabled: false, - hooks: PluginHooks::default(), - lifecycle: PluginLifecycle::default(), - }) - .expect_err("empty name should fail"); + fn load_plugin_from_directory_validates_required_fields() { + let root = temp_dir("manifest-required"); + write_file( + root.join(MANIFEST_FILE_NAME).as_path(), + r#"{"name":"","version":"1.0.0","description":"desc"}"#, + ); + + let error = load_plugin_from_directory(&root).expect_err("empty name should fail"); assert!(error.to_string().contains("name cannot be empty")); + + let _ = fs::remove_dir_all(root); + } + + #[test] + fn load_plugin_from_directory_reads_root_manifest_and_validates_entries() { + let root = temp_dir("manifest-root"); + write_loader_plugin(&root); + + let manifest = load_plugin_from_directory(&root).expect("manifest should load"); + assert_eq!(manifest.name, "loader-demo"); + assert_eq!(manifest.version, "1.2.3"); + assert_eq!(manifest.permissions, vec!["read", "write"]); + assert_eq!(manifest.hooks.pre_tool_use, vec!["./hooks/pre.sh"]); + assert_eq!(manifest.tools.len(), 1); + assert_eq!(manifest.tools[0].name, "echo_tool"); + assert_eq!(manifest.commands.len(), 1); + assert_eq!(manifest.commands[0].name, "sync"); + + let _ = fs::remove_dir_all(root); + } + + #[test] + fn load_plugin_from_directory_supports_packaged_manifest_path() { + let root = temp_dir("manifest-packaged"); + write_external_plugin(&root, "packaged-demo", "1.0.0"); + + let manifest = load_plugin_from_directory(&root).expect("packaged manifest should load"); + assert_eq!(manifest.name, "packaged-demo"); + assert!(manifest.tools.is_empty()); + assert!(manifest.commands.is_empty()); + + let _ = fs::remove_dir_all(root); + } + + #[test] + fn load_plugin_from_directory_defaults_optional_fields() { + let root = temp_dir("manifest-defaults"); + write_file( + root.join(MANIFEST_FILE_NAME).as_path(), + r#"{ + "name": "minimal", + "version": "0.1.0", + "description": "Minimal manifest" +}"#, + ); + + let manifest = load_plugin_from_directory(&root).expect("minimal manifest should load"); + assert!(manifest.permissions.is_empty()); + assert!(manifest.hooks.is_empty()); + assert!(manifest.tools.is_empty()); + assert!(manifest.commands.is_empty()); + + let _ = fs::remove_dir_all(root); + } + + #[test] + fn load_plugin_from_directory_rejects_duplicate_permissions_and_commands() { + let root = temp_dir("manifest-duplicates"); + write_file( + root.join("commands").join("sync.sh").as_path(), + "#!/bin/sh\nprintf 'sync'\n", + ); + write_file( + root.join(MANIFEST_FILE_NAME).as_path(), + r#"{ + "name": "duplicate-manifest", + "version": "1.0.0", + "description": "Duplicate validation", + "permissions": ["read", "read"], + "commands": [ + {"name": "sync", "description": "Sync one", "command": "./commands/sync.sh"}, + {"name": "sync", "description": "Sync two", "command": "./commands/sync.sh"} + ] +}"#, + ); + + let error = load_plugin_from_directory(&root).expect_err("duplicates should fail"); + assert!(error + .to_string() + .contains("permission `read` is duplicated")); + + let _ = fs::remove_dir_all(root); + } + + #[test] + fn load_plugin_from_directory_rejects_missing_tool_or_command_paths() { + let root = temp_dir("manifest-paths"); + write_file( + root.join(MANIFEST_FILE_NAME).as_path(), + r#"{ + "name": "missing-paths", + "version": "1.0.0", + "description": "Missing path validation", + "tools": [ + { + "name": "tool_one", + "description": "Missing tool script", + "inputSchema": {"type": "object"}, + "command": "./tools/missing.sh" + } + ] +}"#, + ); + + let error = load_plugin_from_directory(&root).expect_err("missing paths should fail"); + assert!(error.to_string().contains("does not exist")); + + let _ = fs::remove_dir_all(root); } #[test] diff --git a/rust/crates/runtime/src/conversation.rs b/rust/crates/runtime/src/conversation.rs index c66cd13..7e79f9a 100644 --- a/rust/crates/runtime/src/conversation.rs +++ b/rust/crates/runtime/src/conversation.rs @@ -1,13 +1,13 @@ use std::collections::BTreeMap; use std::fmt::{Display, Formatter}; -use plugins::PluginRegistry; +use plugins::{HookRunner as PluginHookRunner, PluginRegistry}; use crate::compact::{ compact_session, estimate_session_tokens, CompactionConfig, CompactionResult, }; use crate::config::RuntimeFeatureConfig; -use crate::hooks::{HookRunResult, HookRunner}; +use crate::hooks::HookRunner; use crate::permissions::{PermissionOutcome, PermissionPolicy, PermissionPrompter}; use crate::session::{ContentBlock, ConversationMessage, Session}; use crate::usage::{TokenUsage, UsageTracker}; @@ -109,6 +109,7 @@ pub struct ConversationRuntime { usage_tracker: UsageTracker, hook_runner: HookRunner, auto_compaction_input_tokens_threshold: u32, + plugin_hook_runner: Option, plugin_registry: Option, plugins_shutdown: bool, } @@ -172,6 +173,7 @@ where usage_tracker, hook_runner: HookRunner::from_feature_config(&feature_config), auto_compaction_input_tokens_threshold: auto_compaction_threshold_from_env(), + plugin_hook_runner: None, plugin_registry: None, plugins_shutdown: false, } @@ -187,11 +189,8 @@ where feature_config: RuntimeFeatureConfig, plugin_registry: PluginRegistry, ) -> Result { - let hook_runner = - HookRunner::from_feature_config_and_plugins(&feature_config, &plugin_registry) - .map_err(|error| { - RuntimeError::new(format!("plugin hook registration failed: {error}")) - })?; + let plugin_hook_runner = PluginHookRunner::from_registry(&plugin_registry) + .map_err(|error| RuntimeError::new(format!("plugin hook registration failed: {error}")))?; plugin_registry .initialize() .map_err(|error| RuntimeError::new(format!("plugin initialization failed: {error}")))?; @@ -203,7 +202,7 @@ where system_prompt, feature_config, ); - runtime.hook_runner = hook_runner; + runtime.plugin_hook_runner = Some(plugin_hook_runner); runtime.plugin_registry = Some(plugin_registry); Ok(runtime) } @@ -284,16 +283,36 @@ where ConversationMessage::tool_result( tool_use_id, tool_name, - format_hook_message(&pre_hook_result, &deny_message), + format_hook_message(pre_hook_result.messages(), &deny_message), true, ) } else { + let plugin_pre_hook_result = + self.run_plugin_pre_tool_use(&tool_name, &input); + if plugin_pre_hook_result.is_denied() { + let deny_message = + format!("PreToolUse hook denied tool `{tool_name}`"); + ConversationMessage::tool_result( + tool_use_id, + tool_name, + format_hook_message( + plugin_pre_hook_result.messages(), + &deny_message, + ), + true, + ) + } else { let (mut output, mut is_error) = match self.tool_executor.execute(&tool_name, &input) { Ok(output) => (output, false), Err(error) => (error.to_string(), true), }; output = merge_hook_feedback(pre_hook_result.messages(), output, false); + output = merge_hook_feedback( + plugin_pre_hook_result.messages(), + output, + false, + ); let post_hook_result = self .hook_runner @@ -306,6 +325,16 @@ where output, post_hook_result.is_denied(), ); + let plugin_post_hook_result = + self.run_plugin_post_tool_use(&tool_name, &input, &output, is_error); + if plugin_post_hook_result.is_denied() { + is_error = true; + } + output = merge_hook_feedback( + plugin_post_hook_result.messages(), + output, + plugin_post_hook_result.is_denied(), + ); ConversationMessage::tool_result( tool_use_id, @@ -313,6 +342,7 @@ where output, is_error, ) + } } } PermissionOutcome::Deny { reason } => { @@ -365,6 +395,26 @@ where self.shutdown_registered_plugins() } + fn run_plugin_pre_tool_use(&self, tool_name: &str, input: &str) -> plugins::HookRunResult { + self.plugin_hook_runner.as_ref().map_or_else( + || plugins::HookRunResult::allow(Vec::new()), + |runner| runner.run_pre_tool_use(tool_name, input), + ) + } + + fn run_plugin_post_tool_use( + &self, + tool_name: &str, + input: &str, + output: &str, + is_error: bool, + ) -> plugins::HookRunResult { + self.plugin_hook_runner.as_ref().map_or_else( + || plugins::HookRunResult::allow(Vec::new()), + |runner| runner.run_post_tool_use(tool_name, input, output, is_error), + ) + } + fn maybe_auto_compact(&mut self) -> Option { if self.usage_tracker.cumulative_usage().input_tokens < self.auto_compaction_input_tokens_threshold diff --git a/rust/crates/runtime/src/hooks.rs b/rust/crates/runtime/src/hooks.rs index 3e3e8f1..4aff002 100644 --- a/rust/crates/runtime/src/hooks.rs +++ b/rust/crates/runtime/src/hooks.rs @@ -2,7 +2,6 @@ use std::ffi::OsStr; use std::path::Path; use std::process::Command; -use plugins::{PluginError, PluginRegistry}; use serde_json::json; use crate::config::{RuntimeFeatureConfig, RuntimeHookConfig}; @@ -64,19 +63,6 @@ impl HookRunner { Self::new(feature_config.hooks().clone()) } - pub fn from_feature_config_and_plugins( - feature_config: &RuntimeFeatureConfig, - plugin_registry: &PluginRegistry, - ) -> Result { - let mut config = feature_config.hooks().clone(); - let plugin_hooks = plugin_registry.aggregated_hooks()?; - config.extend(&RuntimeHookConfig::new( - plugin_hooks.pre_tool_use, - plugin_hooks.post_tool_use, - )); - Ok(Self::new(config)) - } - #[must_use] pub fn run_pre_tool_use(&self, tool_name: &str, tool_input: &str) -> HookRunResult { self.run_commands( @@ -313,50 +299,6 @@ impl CommandWithStdin { mod tests { use super::{HookRunResult, HookRunner}; use crate::config::{RuntimeFeatureConfig, RuntimeHookConfig}; - use plugins::{PluginManager, PluginManagerConfig}; - use std::fs; - #[cfg(unix)] - use std::os::unix::fs::PermissionsExt; - use std::path::{Path, PathBuf}; - use std::time::{SystemTime, UNIX_EPOCH}; - - fn temp_dir(label: &str) -> PathBuf { - let nanos = SystemTime::now() - .duration_since(UNIX_EPOCH) - .expect("time should be after epoch") - .as_nanos(); - std::env::temp_dir().join(format!("hook-runner-{label}-{nanos}")) - } - - fn write_hook_plugin(root: &Path, name: &str) { - fs::create_dir_all(root.join(".claude-plugin")).expect("manifest dir"); - fs::create_dir_all(root.join("hooks")).expect("hooks dir"); - fs::write( - root.join("hooks").join("pre.sh"), - "#!/bin/sh\nprintf 'plugin pre'\n", - ) - .expect("write pre hook"); - fs::write( - root.join("hooks").join("post.sh"), - "#!/bin/sh\nprintf 'plugin post'\n", - ) - .expect("write post hook"); - #[cfg(unix)] - { - let exec_mode = fs::Permissions::from_mode(0o755); - fs::set_permissions(root.join("hooks").join("pre.sh"), exec_mode.clone()) - .expect("chmod pre hook"); - fs::set_permissions(root.join("hooks").join("post.sh"), exec_mode) - .expect("chmod post hook"); - } - fs::write( - root.join(".claude-plugin").join("plugin.json"), - format!( - "{{\n \"name\": \"{name}\",\n \"version\": \"1.0.0\",\n \"description\": \"hook plugin\",\n \"hooks\": {{\n \"PreToolUse\": [\"./hooks/pre.sh\"],\n \"PostToolUse\": [\"./hooks/post.sh\"]\n }}\n}}" - ), - ) - .expect("write plugin manifest"); - } #[test] fn allows_exit_code_zero_and_captures_stdout() { @@ -401,40 +343,6 @@ mod tests { .any(|message| message.contains("allowing tool execution to continue"))); } - #[test] - fn collects_hooks_from_enabled_plugins() { - let config_home = temp_dir("config"); - let source_root = temp_dir("source"); - write_hook_plugin(&source_root, "hooked"); - - let mut manager = PluginManager::new(PluginManagerConfig::new(&config_home)); - manager - .install(source_root.to_str().expect("utf8 path")) - .expect("install should succeed"); - let registry = manager.plugin_registry().expect("registry should build"); - - let runner = HookRunner::from_feature_config_and_plugins( - &RuntimeFeatureConfig::default(), - ®istry, - ) - .expect("plugin hooks should load"); - - let pre_result = runner.run_pre_tool_use("Read", r#"{"path":"README.md"}"#); - let post_result = runner.run_post_tool_use("Read", r#"{"path":"README.md"}"#, "ok", false); - - assert_eq!( - pre_result, - HookRunResult::allow(vec!["plugin pre".to_string()]) - ); - assert_eq!( - post_result, - HookRunResult::allow(vec!["plugin post".to_string()]) - ); - - let _ = fs::remove_dir_all(config_home); - let _ = fs::remove_dir_all(source_root); - } - #[cfg(windows)] fn shell_snippet(script: &str) -> String { script.replace('\'', "\"") diff --git a/rust/crates/rusty-claude-cli/src/main.rs b/rust/crates/rusty-claude-cli/src/main.rs index a16aa2e..2442aae 100644 --- a/rust/crates/rusty-claude-cli/src/main.rs +++ b/rust/crates/rusty-claude-cli/src/main.rs @@ -35,7 +35,7 @@ use runtime::{ Session, TokenUsage, ToolError, ToolExecutor, UsageTracker, }; use serde_json::json; -use tools::{execute_tool, mvp_tool_specs, ToolSpec}; +use tools::GlobalToolRegistry; const DEFAULT_MODEL: &str = "claude-opus-4-6"; fn max_tokens_for_model(model: &str) -> u32 { @@ -301,51 +301,20 @@ fn resolve_model_alias(model: &str) -> &str { } fn normalize_allowed_tools(values: &[String]) -> Result, String> { - if values.is_empty() { - return Ok(None); - } - - let canonical_names = mvp_tool_specs() - .into_iter() - .map(|spec| spec.name.to_string()) - .collect::>(); - let mut name_map = canonical_names - .iter() - .map(|name| (normalize_tool_name(name), name.clone())) - .collect::>(); - - for (alias, canonical) in [ - ("read", "read_file"), - ("write", "write_file"), - ("edit", "edit_file"), - ("glob", "glob_search"), - ("grep", "grep_search"), - ] { - name_map.insert(alias.to_string(), canonical.to_string()); - } - - let mut allowed = AllowedToolSet::new(); - for value in values { - for token in value - .split(|ch: char| ch == ',' || ch.is_whitespace()) - .filter(|token| !token.is_empty()) - { - let normalized = normalize_tool_name(token); - let canonical = name_map.get(&normalized).ok_or_else(|| { - format!( - "unsupported tool in --allowedTools: {token} (expected one of: {})", - canonical_names.join(", ") - ) - })?; - allowed.insert(canonical.clone()); - } - } - - Ok(Some(allowed)) + current_tool_registry() + .unwrap_or_else(|_| GlobalToolRegistry::builtin()) + .normalize_allowed_tools(values) } -fn normalize_tool_name(value: &str) -> String { - value.trim().replace('-', "_").to_ascii_lowercase() +fn current_tool_registry() -> Result { + let cwd = env::current_dir().map_err(|error| error.to_string())?; + let loader = ConfigLoader::default_for(&cwd); + let runtime_config = loader.load().map_err(|error| error.to_string())?; + let plugin_manager = build_plugin_manager(&cwd, &loader, &runtime_config); + let plugin_tools = plugin_manager + .aggregated_tools() + .map_err(|error| error.to_string())?; + GlobalToolRegistry::with_plugin_tools(plugin_tools) } fn parse_permission_mode_arg(value: &str) -> Result { @@ -375,11 +344,11 @@ fn default_permission_mode() -> PermissionMode { .map_or(PermissionMode::DangerFullAccess, permission_mode_from_label) } -fn filter_tool_specs(allowed_tools: Option<&AllowedToolSet>) -> Vec { - mvp_tool_specs() - .into_iter() - .filter(|spec| allowed_tools.is_none_or(|allowed| allowed.contains(spec.name))) - .collect() +fn filter_tool_specs( + tool_registry: &GlobalToolRegistry, + allowed_tools: Option<&AllowedToolSet>, +) -> Vec { + tool_registry.definitions(allowed_tools) } fn parse_system_prompt_args(args: &[String]) -> Result { @@ -2347,14 +2316,25 @@ fn build_system_prompt() -> Result, Box> { )?) } -fn build_runtime_plugin_state( -) -> Result<(runtime::RuntimeFeatureConfig, PluginRegistry), Box> { +fn build_runtime_plugin_state() -> Result< + ( + runtime::RuntimeFeatureConfig, + PluginRegistry, + GlobalToolRegistry, + ), + Box, +> { let cwd = env::current_dir()?; let loader = ConfigLoader::default_for(&cwd); let runtime_config = loader.load()?; let plugin_manager = build_plugin_manager(&cwd, &loader, &runtime_config); let plugin_registry = plugin_manager.plugin_registry()?; - Ok((runtime_config.feature_config().clone(), plugin_registry)) + let tool_registry = GlobalToolRegistry::with_plugin_tools(plugin_registry.aggregated_tools()?)?; + Ok(( + runtime_config.feature_config().clone(), + plugin_registry, + tool_registry, + )) } fn build_plugin_manager( @@ -2404,12 +2384,18 @@ fn build_runtime( permission_mode: PermissionMode, ) -> Result, Box> { - let (feature_config, plugin_registry) = build_runtime_plugin_state()?; + let (feature_config, plugin_registry, tool_registry) = build_runtime_plugin_state()?; Ok(ConversationRuntime::new_with_plugins( session, - AnthropicRuntimeClient::new(model, enable_tools, emit_output, allowed_tools.clone())?, - CliToolExecutor::new(allowed_tools, emit_output), - permission_policy(permission_mode), + AnthropicRuntimeClient::new( + model, + enable_tools, + emit_output, + allowed_tools.clone(), + tool_registry.clone(), + )?, + CliToolExecutor::new(allowed_tools.clone(), emit_output, tool_registry.clone()), + permission_policy(permission_mode, &tool_registry), system_prompt, feature_config, plugin_registry, @@ -2469,6 +2455,7 @@ struct AnthropicRuntimeClient { enable_tools: bool, emit_output: bool, allowed_tools: Option, + tool_registry: GlobalToolRegistry, } impl AnthropicRuntimeClient { @@ -2477,6 +2464,7 @@ impl AnthropicRuntimeClient { enable_tools: bool, emit_output: bool, allowed_tools: Option, + tool_registry: GlobalToolRegistry, ) -> Result> { Ok(Self { runtime: tokio::runtime::Runtime::new()?, @@ -2486,6 +2474,7 @@ impl AnthropicRuntimeClient { enable_tools, emit_output, allowed_tools, + tool_registry, }) } } @@ -2508,16 +2497,9 @@ impl ApiClient for AnthropicRuntimeClient { max_tokens: max_tokens_for_model(&self.model), messages: convert_messages(&request.messages), system: (!request.system_prompt.is_empty()).then(|| request.system_prompt.join("\n\n")), - tools: self.enable_tools.then(|| { - filter_tool_specs(self.allowed_tools.as_ref()) - .into_iter() - .map(|spec| ToolDefinition { - name: spec.name.to_string(), - description: Some(spec.description.to_string()), - input_schema: spec.input_schema, - }) - .collect() - }), + tools: self + .enable_tools + .then(|| filter_tool_specs(&self.tool_registry, self.allowed_tools.as_ref())), tool_choice: self.enable_tools.then_some(ToolChoice::Auto), stream: true, }; @@ -3108,14 +3090,20 @@ struct CliToolExecutor { renderer: TerminalRenderer, emit_output: bool, allowed_tools: Option, + tool_registry: GlobalToolRegistry, } impl CliToolExecutor { - fn new(allowed_tools: Option, emit_output: bool) -> Self { + fn new( + allowed_tools: Option, + emit_output: bool, + tool_registry: GlobalToolRegistry, + ) -> Self { Self { renderer: TerminalRenderer::new(), emit_output, allowed_tools, + tool_registry, } } } @@ -3133,7 +3121,7 @@ impl ToolExecutor for CliToolExecutor { } let value = serde_json::from_str(input) .map_err(|error| ToolError::new(format!("invalid tool input JSON: {error}")))?; - match execute_tool(tool_name, &value) { + match self.tool_registry.execute(tool_name, &value) { Ok(output) => { if self.emit_output { let markdown = format_tool_result(tool_name, &output, false); @@ -3156,16 +3144,13 @@ impl ToolExecutor for CliToolExecutor { } } -fn permission_policy(mode: PermissionMode) -> PermissionPolicy { - tool_permission_specs() - .into_iter() - .fold(PermissionPolicy::new(mode), |policy, spec| { - policy.with_tool_requirement(spec.name, spec.required_permission) - }) -} - -fn tool_permission_specs() -> Vec { - mvp_tool_specs() +fn permission_policy(mode: PermissionMode, tool_registry: &GlobalToolRegistry) -> PermissionPolicy { + tool_registry.permission_specs(None).into_iter().fold( + PermissionPolicy::new(mode), + |policy, (name, required_permission)| { + policy.with_tool_requirement(name, required_permission) + }, + ) } fn convert_messages(messages: &[ConversationMessage]) -> Vec { diff --git a/rust/crates/tools/Cargo.toml b/rust/crates/tools/Cargo.toml index dfa003d..9ecbb06 100644 --- a/rust/crates/tools/Cargo.toml +++ b/rust/crates/tools/Cargo.toml @@ -7,6 +7,7 @@ publish.workspace = true [dependencies] api = { path = "../api" } +plugins = { path = "../plugins" } runtime = { path = "../runtime" } reqwest = { version = "0.12", default-features = false, features = ["blocking", "rustls-tls"] } serde = { version = "1", features = ["derive"] } diff --git a/rust/crates/tools/src/lib.rs b/rust/crates/tools/src/lib.rs index 4071c9b..79294de 100644 --- a/rust/crates/tools/src/lib.rs +++ b/rust/crates/tools/src/lib.rs @@ -8,6 +8,7 @@ use api::{ MessageRequest, MessageResponse, OutputContentBlock, StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, ToolResultContentBlock, }; +use plugins::PluginTool; use reqwest::blocking::Client; use runtime::{ edit_file, execute_bash, glob_search, grep_search, load_system_prompt, read_file, write_file, @@ -55,6 +56,196 @@ pub struct ToolSpec { pub required_permission: PermissionMode, } +#[derive(Debug, Clone, PartialEq)] +pub struct RegisteredTool { + pub definition: ToolDefinition, + pub required_permission: PermissionMode, + handler: RegisteredToolHandler, +} + +#[derive(Debug, Clone, PartialEq)] +enum RegisteredToolHandler { + Builtin, + Plugin(PluginTool), +} + +#[derive(Debug, Clone, PartialEq)] +pub struct GlobalToolRegistry { + entries: Vec, +} + +impl GlobalToolRegistry { + #[must_use] + pub fn builtin() -> Self { + Self { + entries: mvp_tool_specs() + .into_iter() + .map(|spec| RegisteredTool { + definition: ToolDefinition { + name: spec.name.to_string(), + description: Some(spec.description.to_string()), + input_schema: spec.input_schema, + }, + required_permission: spec.required_permission, + handler: RegisteredToolHandler::Builtin, + }) + .collect(), + } + } + + pub fn with_plugin_tools(plugin_tools: Vec) -> Result { + let mut registry = Self::builtin(); + let mut seen = registry + .entries + .iter() + .map(|entry| { + ( + normalize_registry_tool_name(&entry.definition.name), + entry.definition.name.clone(), + ) + }) + .collect::>(); + + for tool in plugin_tools { + let normalized = normalize_registry_tool_name(&tool.definition().name); + if let Some(existing) = seen.get(&normalized) { + return Err(format!( + "plugin tool `{}` from `{}` conflicts with already-registered tool `{existing}`", + tool.definition().name, + tool.plugin_id() + )); + } + seen.insert(normalized, tool.definition().name.clone()); + registry.entries.push(RegisteredTool { + definition: ToolDefinition { + name: tool.definition().name.clone(), + description: tool.definition().description.clone(), + input_schema: tool.definition().input_schema.clone(), + }, + required_permission: permission_mode_from_plugin_tool(tool.required_permission())?, + handler: RegisteredToolHandler::Plugin(tool), + }); + } + + Ok(registry) + } + + #[must_use] + pub fn entries(&self) -> &[RegisteredTool] { + &self.entries + } + + #[must_use] + pub fn definitions(&self, allowed_tools: Option<&BTreeSet>) -> Vec { + self.entries + .iter() + .filter(|entry| { + allowed_tools.is_none_or(|allowed| allowed.contains(entry.definition.name.as_str())) + }) + .map(|entry| entry.definition.clone()) + .collect() + } + + #[must_use] + pub fn permission_specs( + &self, + allowed_tools: Option<&BTreeSet>, + ) -> Vec<(String, PermissionMode)> { + self.entries + .iter() + .filter(|entry| { + allowed_tools.is_none_or(|allowed| allowed.contains(entry.definition.name.as_str())) + }) + .map(|entry| (entry.definition.name.clone(), entry.required_permission)) + .collect() + } + + pub fn normalize_allowed_tools( + &self, + values: &[String], + ) -> Result>, String> { + if values.is_empty() { + return Ok(None); + } + + let canonical_names = self + .entries + .iter() + .map(|entry| entry.definition.name.clone()) + .collect::>(); + let mut name_map = canonical_names + .iter() + .map(|name| (normalize_registry_tool_name(name), name.clone())) + .collect::>(); + + for (alias, canonical) in [ + ("read", "read_file"), + ("write", "write_file"), + ("edit", "edit_file"), + ("glob", "glob_search"), + ("grep", "grep_search"), + ] { + if canonical_names.iter().any(|name| name == canonical) { + name_map.insert(alias.to_string(), canonical.to_string()); + } + } + + let mut allowed = BTreeSet::new(); + for value in values { + for token in value + .split(|ch: char| ch == ',' || ch.is_whitespace()) + .filter(|token| !token.is_empty()) + { + let normalized = normalize_registry_tool_name(token); + let canonical = name_map.get(&normalized).ok_or_else(|| { + format!( + "unsupported tool in --allowedTools: {token} (expected one of: {})", + canonical_names.join(", ") + ) + })?; + allowed.insert(canonical.clone()); + } + } + + Ok(Some(allowed)) + } + + pub fn execute(&self, name: &str, input: &Value) -> Result { + let entry = self + .entries + .iter() + .find(|entry| entry.definition.name == name) + .ok_or_else(|| format!("unsupported tool: {name}"))?; + match &entry.handler { + RegisteredToolHandler::Builtin => execute_tool(name, input), + RegisteredToolHandler::Plugin(tool) => { + tool.execute(input).map_err(|error| error.to_string()) + } + } + } +} + +impl Default for GlobalToolRegistry { + fn default() -> Self { + Self::builtin() + } +} + +fn normalize_registry_tool_name(value: &str) -> String { + value.trim().replace('-', "_").to_ascii_lowercase() +} + +fn permission_mode_from_plugin_tool(value: &str) -> Result { + match value { + "read-only" => Ok(PermissionMode::ReadOnly), + "workspace-write" => Ok(PermissionMode::WorkspaceWrite), + "danger-full-access" => Ok(PermissionMode::DangerFullAccess), + other => Err(format!( + "unsupported plugin tool permission `{other}` (expected read-only, workspace-write, or danger-full-access)" + )), + } +} + #[must_use] #[allow(clippy::too_many_lines)] pub fn mvp_tool_specs() -> Vec {