Merge remote-tracking branch 'origin/rcc/runtime' into dev/rust
This commit is contained in:
72
rust/Cargo.lock
generated
72
rust/Cargo.lock
generated
@@ -54,6 +54,15 @@ version = "2.11.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af"
|
checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "block-buffer"
|
||||||
|
version = "0.10.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71"
|
||||||
|
dependencies = [
|
||||||
|
"generic-array",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "bumpalo"
|
name = "bumpalo"
|
||||||
version = "3.20.2"
|
version = "3.20.2"
|
||||||
@@ -104,6 +113,15 @@ dependencies = [
|
|||||||
"tools",
|
"tools",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "cpufeatures"
|
||||||
|
version = "0.2.17"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280"
|
||||||
|
dependencies = [
|
||||||
|
"libc",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "crc32fast"
|
name = "crc32fast"
|
||||||
version = "1.5.0"
|
version = "1.5.0"
|
||||||
@@ -138,6 +156,16 @@ dependencies = [
|
|||||||
"winapi",
|
"winapi",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "crypto-common"
|
||||||
|
version = "0.1.7"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a"
|
||||||
|
dependencies = [
|
||||||
|
"generic-array",
|
||||||
|
"typenum",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "deranged"
|
name = "deranged"
|
||||||
version = "0.5.8"
|
version = "0.5.8"
|
||||||
@@ -147,6 +175,16 @@ dependencies = [
|
|||||||
"powerfmt",
|
"powerfmt",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "digest"
|
||||||
|
version = "0.10.7"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292"
|
||||||
|
dependencies = [
|
||||||
|
"block-buffer",
|
||||||
|
"crypto-common",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "displaydoc"
|
name = "displaydoc"
|
||||||
version = "0.2.5"
|
version = "0.2.5"
|
||||||
@@ -238,6 +276,16 @@ dependencies = [
|
|||||||
"slab",
|
"slab",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "generic-array"
|
||||||
|
version = "0.14.7"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a"
|
||||||
|
dependencies = [
|
||||||
|
"typenum",
|
||||||
|
"version_check",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "getopts"
|
name = "getopts"
|
||||||
version = "0.2.24"
|
version = "0.2.24"
|
||||||
@@ -950,6 +998,7 @@ dependencies = [
|
|||||||
"regex",
|
"regex",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
"sha2",
|
||||||
"tokio",
|
"tokio",
|
||||||
"walkdir",
|
"walkdir",
|
||||||
]
|
]
|
||||||
@@ -1106,6 +1155,17 @@ dependencies = [
|
|||||||
"serde",
|
"serde",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "sha2"
|
||||||
|
version = "0.10.9"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if",
|
||||||
|
"cpufeatures",
|
||||||
|
"digest",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "shlex"
|
name = "shlex"
|
||||||
version = "1.3.0"
|
version = "1.3.0"
|
||||||
@@ -1427,6 +1487,12 @@ version = "0.2.5"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
|
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "typenum"
|
||||||
|
version = "1.19.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "unicase"
|
name = "unicase"
|
||||||
version = "2.9.0"
|
version = "2.9.0"
|
||||||
@@ -1469,6 +1535,12 @@ version = "1.0.4"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be"
|
checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "version_check"
|
||||||
|
version = "0.9.5"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "walkdir"
|
name = "walkdir"
|
||||||
version = "2.5.0"
|
version = "2.5.0"
|
||||||
|
|||||||
@@ -15,11 +15,90 @@ const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(200);
|
|||||||
const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(2);
|
const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(2);
|
||||||
const DEFAULT_MAX_RETRIES: u32 = 2;
|
const DEFAULT_MAX_RETRIES: u32 = 2;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub enum AuthSource {
|
||||||
|
None,
|
||||||
|
ApiKey(String),
|
||||||
|
BearerToken(String),
|
||||||
|
ApiKeyAndBearer {
|
||||||
|
api_key: String,
|
||||||
|
bearer_token: String,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AuthSource {
|
||||||
|
pub fn from_env() -> Result<Self, ApiError> {
|
||||||
|
let api_key = read_env_non_empty("ANTHROPIC_API_KEY")?;
|
||||||
|
let auth_token = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")?;
|
||||||
|
match (api_key, auth_token) {
|
||||||
|
(Some(api_key), Some(bearer_token)) => Ok(Self::ApiKeyAndBearer {
|
||||||
|
api_key,
|
||||||
|
bearer_token,
|
||||||
|
}),
|
||||||
|
(Some(api_key), None) => Ok(Self::ApiKey(api_key)),
|
||||||
|
(None, Some(bearer_token)) => Ok(Self::BearerToken(bearer_token)),
|
||||||
|
(None, None) => Err(ApiError::MissingApiKey),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn api_key(&self) -> Option<&str> {
|
||||||
|
match self {
|
||||||
|
Self::ApiKey(api_key) | Self::ApiKeyAndBearer { api_key, .. } => Some(api_key),
|
||||||
|
Self::None | Self::BearerToken(_) => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn bearer_token(&self) -> Option<&str> {
|
||||||
|
match self {
|
||||||
|
Self::BearerToken(token)
|
||||||
|
| Self::ApiKeyAndBearer {
|
||||||
|
bearer_token: token,
|
||||||
|
..
|
||||||
|
} => Some(token),
|
||||||
|
Self::None | Self::ApiKey(_) => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn masked_authorization_header(&self) -> &'static str {
|
||||||
|
if self.bearer_token().is_some() {
|
||||||
|
"Bearer [REDACTED]"
|
||||||
|
} else {
|
||||||
|
"<absent>"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn apply(&self, mut request_builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
|
||||||
|
if let Some(api_key) = self.api_key() {
|
||||||
|
request_builder = request_builder.header("x-api-key", api_key);
|
||||||
|
}
|
||||||
|
if let Some(token) = self.bearer_token() {
|
||||||
|
request_builder = request_builder.bearer_auth(token);
|
||||||
|
}
|
||||||
|
request_builder
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct OAuthTokenSet {
|
||||||
|
pub access_token: String,
|
||||||
|
pub refresh_token: Option<String>,
|
||||||
|
pub expires_at: Option<u64>,
|
||||||
|
pub scopes: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<OAuthTokenSet> for AuthSource {
|
||||||
|
fn from(value: OAuthTokenSet) -> Self {
|
||||||
|
Self::BearerToken(value.access_token)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct AnthropicClient {
|
pub struct AnthropicClient {
|
||||||
http: reqwest::Client,
|
http: reqwest::Client,
|
||||||
api_key: String,
|
auth: AuthSource,
|
||||||
auth_token: Option<String>,
|
|
||||||
base_url: String,
|
base_url: String,
|
||||||
max_retries: u32,
|
max_retries: u32,
|
||||||
initial_backoff: Duration,
|
initial_backoff: Duration,
|
||||||
@@ -31,8 +110,19 @@ impl AnthropicClient {
|
|||||||
pub fn new(api_key: impl Into<String>) -> Self {
|
pub fn new(api_key: impl Into<String>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
http: reqwest::Client::new(),
|
http: reqwest::Client::new(),
|
||||||
api_key: api_key.into(),
|
auth: AuthSource::ApiKey(api_key.into()),
|
||||||
auth_token: None,
|
base_url: DEFAULT_BASE_URL.to_string(),
|
||||||
|
max_retries: DEFAULT_MAX_RETRIES,
|
||||||
|
initial_backoff: DEFAULT_INITIAL_BACKOFF,
|
||||||
|
max_backoff: DEFAULT_MAX_BACKOFF,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn from_auth(auth: AuthSource) -> Self {
|
||||||
|
Self {
|
||||||
|
http: reqwest::Client::new(),
|
||||||
|
auth,
|
||||||
base_url: DEFAULT_BASE_URL.to_string(),
|
base_url: DEFAULT_BASE_URL.to_string(),
|
||||||
max_retries: DEFAULT_MAX_RETRIES,
|
max_retries: DEFAULT_MAX_RETRIES,
|
||||||
initial_backoff: DEFAULT_INITIAL_BACKOFF,
|
initial_backoff: DEFAULT_INITIAL_BACKOFF,
|
||||||
@@ -41,14 +131,37 @@ impl AnthropicClient {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn from_env() -> Result<Self, ApiError> {
|
pub fn from_env() -> Result<Self, ApiError> {
|
||||||
Ok(Self::new(read_api_key()?)
|
Ok(Self::from_auth(AuthSource::from_env()?).with_base_url(read_base_url()))
|
||||||
.with_auth_token(read_auth_token())
|
}
|
||||||
.with_base_url(read_base_url()))
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn with_auth_source(mut self, auth: AuthSource) -> Self {
|
||||||
|
self.auth = auth;
|
||||||
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn with_auth_token(mut self, auth_token: Option<String>) -> Self {
|
pub fn with_auth_token(mut self, auth_token: Option<String>) -> Self {
|
||||||
self.auth_token = auth_token.filter(|token| !token.is_empty());
|
match (
|
||||||
|
self.auth.api_key().map(ToOwned::to_owned),
|
||||||
|
auth_token.filter(|token| !token.is_empty()),
|
||||||
|
) {
|
||||||
|
(Some(api_key), Some(bearer_token)) => {
|
||||||
|
self.auth = AuthSource::ApiKeyAndBearer {
|
||||||
|
api_key,
|
||||||
|
bearer_token,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
(Some(api_key), None) => {
|
||||||
|
self.auth = AuthSource::ApiKey(api_key);
|
||||||
|
}
|
||||||
|
(None, Some(bearer_token)) => {
|
||||||
|
self.auth = AuthSource::BearerToken(bearer_token);
|
||||||
|
}
|
||||||
|
(None, None) => {
|
||||||
|
self.auth = AuthSource::None;
|
||||||
|
}
|
||||||
|
}
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -71,6 +184,11 @@ impl AnthropicClient {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn auth_source(&self) -> &AuthSource {
|
||||||
|
&self.auth
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn send_message(
|
pub async fn send_message(
|
||||||
&self,
|
&self,
|
||||||
request: &MessageRequest,
|
request: &MessageRequest,
|
||||||
@@ -151,28 +269,25 @@ impl AnthropicClient {
|
|||||||
let resolved_base_url = self.base_url.trim_end_matches('/');
|
let resolved_base_url = self.base_url.trim_end_matches('/');
|
||||||
eprintln!("[anthropic-client] resolved_base_url={resolved_base_url}");
|
eprintln!("[anthropic-client] resolved_base_url={resolved_base_url}");
|
||||||
eprintln!("[anthropic-client] request_url={request_url}");
|
eprintln!("[anthropic-client] request_url={request_url}");
|
||||||
let mut request_builder = self
|
let request_builder = self
|
||||||
.http
|
.http
|
||||||
.post(&request_url)
|
.post(&request_url)
|
||||||
.header("x-api-key", &self.api_key)
|
|
||||||
.header("anthropic-version", ANTHROPIC_VERSION)
|
.header("anthropic-version", ANTHROPIC_VERSION)
|
||||||
.header("content-type", "application/json");
|
.header("content-type", "application/json");
|
||||||
|
let mut request_builder = self.auth.apply(request_builder);
|
||||||
|
|
||||||
let auth_header = self
|
eprintln!(
|
||||||
.auth_token
|
"[anthropic-client] headers x-api-key={} authorization={} anthropic-version={ANTHROPIC_VERSION} content-type=application/json",
|
||||||
.as_ref()
|
if self.auth.api_key().is_some() {
|
||||||
.map_or("<absent>", |_| "Bearer [REDACTED]");
|
"[REDACTED]"
|
||||||
eprintln!("[anthropic-client] headers x-api-key=[REDACTED] authorization={auth_header} anthropic-version={ANTHROPIC_VERSION} content-type=application/json");
|
} else {
|
||||||
|
"<absent>"
|
||||||
|
},
|
||||||
|
self.auth.masked_authorization_header()
|
||||||
|
);
|
||||||
|
|
||||||
if let Some(auth_token) = &self.auth_token {
|
request_builder = request_builder.json(request);
|
||||||
request_builder = request_builder.bearer_auth(auth_token);
|
request_builder.send().await.map_err(ApiError::from)
|
||||||
}
|
|
||||||
|
|
||||||
request_builder
|
|
||||||
.json(request)
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.map_err(ApiError::from)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn backoff_for_attempt(&self, attempt: u32) -> Result<Duration, ApiError> {
|
fn backoff_for_attempt(&self, attempt: u32) -> Result<Duration, ApiError> {
|
||||||
@@ -189,24 +304,28 @@ impl AnthropicClient {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn read_api_key() -> Result<String, ApiError> {
|
fn read_env_non_empty(key: &str) -> Result<Option<String>, ApiError> {
|
||||||
match std::env::var("ANTHROPIC_API_KEY") {
|
match std::env::var(key) {
|
||||||
Ok(api_key) if !api_key.is_empty() => Ok(api_key),
|
Ok(value) if !value.is_empty() => Ok(Some(value)),
|
||||||
Ok(_) => Err(ApiError::MissingApiKey),
|
Ok(_) | Err(std::env::VarError::NotPresent) => Ok(None),
|
||||||
Err(std::env::VarError::NotPresent) => match std::env::var("ANTHROPIC_AUTH_TOKEN") {
|
|
||||||
Ok(api_key) if !api_key.is_empty() => Ok(api_key),
|
|
||||||
Ok(_) | Err(std::env::VarError::NotPresent) => Err(ApiError::MissingApiKey),
|
|
||||||
Err(error) => Err(ApiError::from(error)),
|
|
||||||
},
|
|
||||||
Err(error) => Err(ApiError::from(error)),
|
Err(error) => Err(ApiError::from(error)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
fn read_api_key() -> Result<String, ApiError> {
|
||||||
|
let auth = AuthSource::from_env()?;
|
||||||
|
auth.api_key()
|
||||||
|
.or_else(|| auth.bearer_token())
|
||||||
|
.map(ToOwned::to_owned)
|
||||||
|
.ok_or(ApiError::MissingApiKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
fn read_auth_token() -> Option<String> {
|
fn read_auth_token() -> Option<String> {
|
||||||
match std::env::var("ANTHROPIC_AUTH_TOKEN") {
|
read_env_non_empty("ANTHROPIC_AUTH_TOKEN")
|
||||||
Ok(token) if !token.is_empty() => Some(token),
|
.ok()
|
||||||
_ => None,
|
.and_then(std::convert::identity)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn read_base_url() -> String {
|
fn read_base_url() -> String {
|
||||||
@@ -308,14 +427,14 @@ mod tests {
|
|||||||
use std::sync::{Mutex, OnceLock};
|
use std::sync::{Mutex, OnceLock};
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use crate::client::{AuthSource, OAuthTokenSet};
|
||||||
use crate::types::{ContentBlockDelta, MessageRequest};
|
use crate::types::{ContentBlockDelta, MessageRequest};
|
||||||
|
|
||||||
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
|
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
|
||||||
static ENV_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||||
ENV_LOCK
|
LOCK.get_or_init(|| Mutex::new(()))
|
||||||
.get_or_init(|| Mutex::new(()))
|
|
||||||
.lock()
|
.lock()
|
||||||
.expect("env lock should not be poisoned")
|
.expect("env lock")
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -357,6 +476,30 @@ mod tests {
|
|||||||
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn oauth_token_maps_to_bearer_auth_source() {
|
||||||
|
let auth = AuthSource::from(OAuthTokenSet {
|
||||||
|
access_token: "access-token".to_string(),
|
||||||
|
refresh_token: Some("refresh".to_string()),
|
||||||
|
expires_at: Some(123),
|
||||||
|
scopes: vec!["scope:a".to_string()],
|
||||||
|
});
|
||||||
|
assert_eq!(auth.bearer_token(), Some("access-token"));
|
||||||
|
assert_eq!(auth.api_key(), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn auth_source_from_env_combines_api_key_and_bearer_token() {
|
||||||
|
let _guard = env_lock();
|
||||||
|
std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
|
||||||
|
std::env::set_var("ANTHROPIC_API_KEY", "legacy-key");
|
||||||
|
let auth = AuthSource::from_env().expect("env auth");
|
||||||
|
assert_eq!(auth.api_key(), Some("legacy-key"));
|
||||||
|
assert_eq!(auth.bearer_token(), Some("auth-token"));
|
||||||
|
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
||||||
|
std::env::remove_var("ANTHROPIC_API_KEY");
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn message_request_stream_helper_sets_stream_true() {
|
fn message_request_stream_helper_sets_stream_true() {
|
||||||
let request = MessageRequest {
|
let request = MessageRequest {
|
||||||
@@ -436,4 +579,25 @@ mod tests {
|
|||||||
Some("req_fallback")
|
Some("req_fallback")
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn auth_source_applies_headers() {
|
||||||
|
let auth = AuthSource::ApiKeyAndBearer {
|
||||||
|
api_key: "test-key".to_string(),
|
||||||
|
bearer_token: "proxy-token".to_string(),
|
||||||
|
};
|
||||||
|
let request = auth
|
||||||
|
.apply(reqwest::Client::new().post("https://example.test"))
|
||||||
|
.build()
|
||||||
|
.expect("request build");
|
||||||
|
let headers = request.headers();
|
||||||
|
assert_eq!(
|
||||||
|
headers.get("x-api-key").and_then(|v| v.to_str().ok()),
|
||||||
|
Some("test-key")
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
headers.get("authorization").and_then(|v| v.to_str().ok()),
|
||||||
|
Some("Bearer proxy-token")
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ mod error;
|
|||||||
mod sse;
|
mod sse;
|
||||||
mod types;
|
mod types;
|
||||||
|
|
||||||
pub use client::{AnthropicClient, MessageStream};
|
pub use client::{AnthropicClient, AuthSource, MessageStream, OAuthTokenSet};
|
||||||
pub use error::ApiError;
|
pub use error::ApiError;
|
||||||
pub use sse::{parse_frame, SseParser};
|
pub use sse::{parse_frame, SseParser};
|
||||||
pub use types::{
|
pub use types::{
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ license.workspace = true
|
|||||||
publish.workspace = true
|
publish.workspace = true
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
sha2 = "0.10"
|
||||||
glob = "0.3"
|
glob = "0.3"
|
||||||
regex = "1"
|
regex = "1"
|
||||||
serde = { version = "1", features = ["derive"] }
|
serde = { version = "1", features = ["derive"] }
|
||||||
|
|||||||
@@ -24,6 +24,95 @@ pub struct ConfigEntry {
|
|||||||
pub struct RuntimeConfig {
|
pub struct RuntimeConfig {
|
||||||
merged: BTreeMap<String, JsonValue>,
|
merged: BTreeMap<String, JsonValue>,
|
||||||
loaded_entries: Vec<ConfigEntry>,
|
loaded_entries: Vec<ConfigEntry>,
|
||||||
|
feature_config: RuntimeFeatureConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq, Default)]
|
||||||
|
pub struct RuntimeFeatureConfig {
|
||||||
|
mcp: McpConfigCollection,
|
||||||
|
oauth: Option<OAuthConfig>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq, Default)]
|
||||||
|
pub struct McpConfigCollection {
|
||||||
|
servers: BTreeMap<String, ScopedMcpServerConfig>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct ScopedMcpServerConfig {
|
||||||
|
pub scope: ConfigSource,
|
||||||
|
pub config: McpServerConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub enum McpTransport {
|
||||||
|
Stdio,
|
||||||
|
Sse,
|
||||||
|
Http,
|
||||||
|
Ws,
|
||||||
|
Sdk,
|
||||||
|
ClaudeAiProxy,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub enum McpServerConfig {
|
||||||
|
Stdio(McpStdioServerConfig),
|
||||||
|
Sse(McpRemoteServerConfig),
|
||||||
|
Http(McpRemoteServerConfig),
|
||||||
|
Ws(McpWebSocketServerConfig),
|
||||||
|
Sdk(McpSdkServerConfig),
|
||||||
|
ClaudeAiProxy(McpClaudeAiProxyServerConfig),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct McpStdioServerConfig {
|
||||||
|
pub command: String,
|
||||||
|
pub args: Vec<String>,
|
||||||
|
pub env: BTreeMap<String, String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct McpRemoteServerConfig {
|
||||||
|
pub url: String,
|
||||||
|
pub headers: BTreeMap<String, String>,
|
||||||
|
pub headers_helper: Option<String>,
|
||||||
|
pub oauth: Option<McpOAuthConfig>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct McpWebSocketServerConfig {
|
||||||
|
pub url: String,
|
||||||
|
pub headers: BTreeMap<String, String>,
|
||||||
|
pub headers_helper: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct McpSdkServerConfig {
|
||||||
|
pub name: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct McpClaudeAiProxyServerConfig {
|
||||||
|
pub url: String,
|
||||||
|
pub id: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct McpOAuthConfig {
|
||||||
|
pub client_id: Option<String>,
|
||||||
|
pub callback_port: Option<u16>,
|
||||||
|
pub auth_server_metadata_url: Option<String>,
|
||||||
|
pub xaa: Option<bool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct OAuthConfig {
|
||||||
|
pub client_id: String,
|
||||||
|
pub authorize_url: String,
|
||||||
|
pub token_url: String,
|
||||||
|
pub callback_port: Option<u16>,
|
||||||
|
pub manual_redirect_url: Option<String>,
|
||||||
|
pub scopes: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@@ -95,18 +184,31 @@ impl ConfigLoader {
|
|||||||
pub fn load(&self) -> Result<RuntimeConfig, ConfigError> {
|
pub fn load(&self) -> Result<RuntimeConfig, ConfigError> {
|
||||||
let mut merged = BTreeMap::new();
|
let mut merged = BTreeMap::new();
|
||||||
let mut loaded_entries = Vec::new();
|
let mut loaded_entries = Vec::new();
|
||||||
|
let mut mcp_servers = BTreeMap::new();
|
||||||
|
|
||||||
for entry in self.discover() {
|
for entry in self.discover() {
|
||||||
let Some(value) = read_optional_json_object(&entry.path)? else {
|
let Some(value) = read_optional_json_object(&entry.path)? else {
|
||||||
continue;
|
continue;
|
||||||
};
|
};
|
||||||
|
merge_mcp_servers(&mut mcp_servers, entry.source, &value, &entry.path)?;
|
||||||
deep_merge_objects(&mut merged, &value);
|
deep_merge_objects(&mut merged, &value);
|
||||||
loaded_entries.push(entry);
|
loaded_entries.push(entry);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let feature_config = RuntimeFeatureConfig {
|
||||||
|
mcp: McpConfigCollection {
|
||||||
|
servers: mcp_servers,
|
||||||
|
},
|
||||||
|
oauth: parse_optional_oauth_config(
|
||||||
|
&JsonValue::Object(merged.clone()),
|
||||||
|
"merged settings.oauth",
|
||||||
|
)?,
|
||||||
|
};
|
||||||
|
|
||||||
Ok(RuntimeConfig {
|
Ok(RuntimeConfig {
|
||||||
merged,
|
merged,
|
||||||
loaded_entries,
|
loaded_entries,
|
||||||
|
feature_config,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -117,6 +219,7 @@ impl RuntimeConfig {
|
|||||||
Self {
|
Self {
|
||||||
merged: BTreeMap::new(),
|
merged: BTreeMap::new(),
|
||||||
loaded_entries: Vec::new(),
|
loaded_entries: Vec::new(),
|
||||||
|
feature_config: RuntimeFeatureConfig::default(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -139,6 +242,66 @@ impl RuntimeConfig {
|
|||||||
pub fn as_json(&self) -> JsonValue {
|
pub fn as_json(&self) -> JsonValue {
|
||||||
JsonValue::Object(self.merged.clone())
|
JsonValue::Object(self.merged.clone())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn feature_config(&self) -> &RuntimeFeatureConfig {
|
||||||
|
&self.feature_config
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn mcp(&self) -> &McpConfigCollection {
|
||||||
|
&self.feature_config.mcp
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn oauth(&self) -> Option<&OAuthConfig> {
|
||||||
|
self.feature_config.oauth.as_ref()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RuntimeFeatureConfig {
|
||||||
|
#[must_use]
|
||||||
|
pub fn mcp(&self) -> &McpConfigCollection {
|
||||||
|
&self.mcp
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn oauth(&self) -> Option<&OAuthConfig> {
|
||||||
|
self.oauth.as_ref()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl McpConfigCollection {
|
||||||
|
#[must_use]
|
||||||
|
pub fn servers(&self) -> &BTreeMap<String, ScopedMcpServerConfig> {
|
||||||
|
&self.servers
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn get(&self, name: &str) -> Option<&ScopedMcpServerConfig> {
|
||||||
|
self.servers.get(name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ScopedMcpServerConfig {
|
||||||
|
#[must_use]
|
||||||
|
pub fn transport(&self) -> McpTransport {
|
||||||
|
self.config.transport()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl McpServerConfig {
|
||||||
|
#[must_use]
|
||||||
|
pub fn transport(&self) -> McpTransport {
|
||||||
|
match self {
|
||||||
|
Self::Stdio(_) => McpTransport::Stdio,
|
||||||
|
Self::Sse(_) => McpTransport::Sse,
|
||||||
|
Self::Http(_) => McpTransport::Http,
|
||||||
|
Self::Ws(_) => McpTransport::Ws,
|
||||||
|
Self::Sdk(_) => McpTransport::Sdk,
|
||||||
|
Self::ClaudeAiProxy(_) => McpTransport::ClaudeAiProxy,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn read_optional_json_object(
|
fn read_optional_json_object(
|
||||||
@@ -165,6 +328,253 @@ fn read_optional_json_object(
|
|||||||
Ok(Some(object.clone()))
|
Ok(Some(object.clone()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn merge_mcp_servers(
|
||||||
|
target: &mut BTreeMap<String, ScopedMcpServerConfig>,
|
||||||
|
source: ConfigSource,
|
||||||
|
root: &BTreeMap<String, JsonValue>,
|
||||||
|
path: &Path,
|
||||||
|
) -> Result<(), ConfigError> {
|
||||||
|
let Some(mcp_servers) = root.get("mcpServers") else {
|
||||||
|
return Ok(());
|
||||||
|
};
|
||||||
|
let servers = expect_object(mcp_servers, &format!("{}: mcpServers", path.display()))?;
|
||||||
|
for (name, value) in servers {
|
||||||
|
let parsed = parse_mcp_server_config(
|
||||||
|
name,
|
||||||
|
value,
|
||||||
|
&format!("{}: mcpServers.{name}", path.display()),
|
||||||
|
)?;
|
||||||
|
target.insert(
|
||||||
|
name.clone(),
|
||||||
|
ScopedMcpServerConfig {
|
||||||
|
scope: source,
|
||||||
|
config: parsed,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_optional_oauth_config(
|
||||||
|
root: &JsonValue,
|
||||||
|
context: &str,
|
||||||
|
) -> Result<Option<OAuthConfig>, ConfigError> {
|
||||||
|
let Some(oauth_value) = root.as_object().and_then(|object| object.get("oauth")) else {
|
||||||
|
return Ok(None);
|
||||||
|
};
|
||||||
|
let object = expect_object(oauth_value, context)?;
|
||||||
|
let client_id = expect_string(object, "clientId", context)?.to_string();
|
||||||
|
let authorize_url = expect_string(object, "authorizeUrl", context)?.to_string();
|
||||||
|
let token_url = expect_string(object, "tokenUrl", context)?.to_string();
|
||||||
|
let callback_port = optional_u16(object, "callbackPort", context)?;
|
||||||
|
let manual_redirect_url =
|
||||||
|
optional_string(object, "manualRedirectUrl", context)?.map(str::to_string);
|
||||||
|
let scopes = optional_string_array(object, "scopes", context)?.unwrap_or_default();
|
||||||
|
Ok(Some(OAuthConfig {
|
||||||
|
client_id,
|
||||||
|
authorize_url,
|
||||||
|
token_url,
|
||||||
|
callback_port,
|
||||||
|
manual_redirect_url,
|
||||||
|
scopes,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_mcp_server_config(
|
||||||
|
server_name: &str,
|
||||||
|
value: &JsonValue,
|
||||||
|
context: &str,
|
||||||
|
) -> Result<McpServerConfig, ConfigError> {
|
||||||
|
let object = expect_object(value, context)?;
|
||||||
|
let server_type = optional_string(object, "type", context)?.unwrap_or("stdio");
|
||||||
|
match server_type {
|
||||||
|
"stdio" => Ok(McpServerConfig::Stdio(McpStdioServerConfig {
|
||||||
|
command: expect_string(object, "command", context)?.to_string(),
|
||||||
|
args: optional_string_array(object, "args", context)?.unwrap_or_default(),
|
||||||
|
env: optional_string_map(object, "env", context)?.unwrap_or_default(),
|
||||||
|
})),
|
||||||
|
"sse" => Ok(McpServerConfig::Sse(parse_mcp_remote_server_config(
|
||||||
|
object, context,
|
||||||
|
)?)),
|
||||||
|
"http" => Ok(McpServerConfig::Http(parse_mcp_remote_server_config(
|
||||||
|
object, context,
|
||||||
|
)?)),
|
||||||
|
"ws" => Ok(McpServerConfig::Ws(McpWebSocketServerConfig {
|
||||||
|
url: expect_string(object, "url", context)?.to_string(),
|
||||||
|
headers: optional_string_map(object, "headers", context)?.unwrap_or_default(),
|
||||||
|
headers_helper: optional_string(object, "headersHelper", context)?.map(str::to_string),
|
||||||
|
})),
|
||||||
|
"sdk" => Ok(McpServerConfig::Sdk(McpSdkServerConfig {
|
||||||
|
name: expect_string(object, "name", context)?.to_string(),
|
||||||
|
})),
|
||||||
|
"claudeai-proxy" => Ok(McpServerConfig::ClaudeAiProxy(
|
||||||
|
McpClaudeAiProxyServerConfig {
|
||||||
|
url: expect_string(object, "url", context)?.to_string(),
|
||||||
|
id: expect_string(object, "id", context)?.to_string(),
|
||||||
|
},
|
||||||
|
)),
|
||||||
|
other => Err(ConfigError::Parse(format!(
|
||||||
|
"{context}: unsupported MCP server type for {server_name}: {other}"
|
||||||
|
))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_mcp_remote_server_config(
|
||||||
|
object: &BTreeMap<String, JsonValue>,
|
||||||
|
context: &str,
|
||||||
|
) -> Result<McpRemoteServerConfig, ConfigError> {
|
||||||
|
Ok(McpRemoteServerConfig {
|
||||||
|
url: expect_string(object, "url", context)?.to_string(),
|
||||||
|
headers: optional_string_map(object, "headers", context)?.unwrap_or_default(),
|
||||||
|
headers_helper: optional_string(object, "headersHelper", context)?.map(str::to_string),
|
||||||
|
oauth: parse_optional_mcp_oauth_config(object, context)?,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_optional_mcp_oauth_config(
|
||||||
|
object: &BTreeMap<String, JsonValue>,
|
||||||
|
context: &str,
|
||||||
|
) -> Result<Option<McpOAuthConfig>, ConfigError> {
|
||||||
|
let Some(value) = object.get("oauth") else {
|
||||||
|
return Ok(None);
|
||||||
|
};
|
||||||
|
let oauth = expect_object(value, &format!("{context}.oauth"))?;
|
||||||
|
Ok(Some(McpOAuthConfig {
|
||||||
|
client_id: optional_string(oauth, "clientId", context)?.map(str::to_string),
|
||||||
|
callback_port: optional_u16(oauth, "callbackPort", context)?,
|
||||||
|
auth_server_metadata_url: optional_string(oauth, "authServerMetadataUrl", context)?
|
||||||
|
.map(str::to_string),
|
||||||
|
xaa: optional_bool(oauth, "xaa", context)?,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn expect_object<'a>(
|
||||||
|
value: &'a JsonValue,
|
||||||
|
context: &str,
|
||||||
|
) -> Result<&'a BTreeMap<String, JsonValue>, ConfigError> {
|
||||||
|
value
|
||||||
|
.as_object()
|
||||||
|
.ok_or_else(|| ConfigError::Parse(format!("{context}: expected JSON object")))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn expect_string<'a>(
|
||||||
|
object: &'a BTreeMap<String, JsonValue>,
|
||||||
|
key: &str,
|
||||||
|
context: &str,
|
||||||
|
) -> Result<&'a str, ConfigError> {
|
||||||
|
object
|
||||||
|
.get(key)
|
||||||
|
.and_then(JsonValue::as_str)
|
||||||
|
.ok_or_else(|| ConfigError::Parse(format!("{context}: missing string field {key}")))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn optional_string<'a>(
|
||||||
|
object: &'a BTreeMap<String, JsonValue>,
|
||||||
|
key: &str,
|
||||||
|
context: &str,
|
||||||
|
) -> Result<Option<&'a str>, ConfigError> {
|
||||||
|
match object.get(key) {
|
||||||
|
Some(value) => value
|
||||||
|
.as_str()
|
||||||
|
.map(Some)
|
||||||
|
.ok_or_else(|| ConfigError::Parse(format!("{context}: field {key} must be a string"))),
|
||||||
|
None => Ok(None),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn optional_bool(
|
||||||
|
object: &BTreeMap<String, JsonValue>,
|
||||||
|
key: &str,
|
||||||
|
context: &str,
|
||||||
|
) -> Result<Option<bool>, ConfigError> {
|
||||||
|
match object.get(key) {
|
||||||
|
Some(value) => value
|
||||||
|
.as_bool()
|
||||||
|
.map(Some)
|
||||||
|
.ok_or_else(|| ConfigError::Parse(format!("{context}: field {key} must be a boolean"))),
|
||||||
|
None => Ok(None),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn optional_u16(
|
||||||
|
object: &BTreeMap<String, JsonValue>,
|
||||||
|
key: &str,
|
||||||
|
context: &str,
|
||||||
|
) -> Result<Option<u16>, ConfigError> {
|
||||||
|
match object.get(key) {
|
||||||
|
Some(value) => {
|
||||||
|
let Some(number) = value.as_i64() else {
|
||||||
|
return Err(ConfigError::Parse(format!(
|
||||||
|
"{context}: field {key} must be an integer"
|
||||||
|
)));
|
||||||
|
};
|
||||||
|
let number = u16::try_from(number).map_err(|_| {
|
||||||
|
ConfigError::Parse(format!("{context}: field {key} is out of range"))
|
||||||
|
})?;
|
||||||
|
Ok(Some(number))
|
||||||
|
}
|
||||||
|
None => Ok(None),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn optional_string_array(
|
||||||
|
object: &BTreeMap<String, JsonValue>,
|
||||||
|
key: &str,
|
||||||
|
context: &str,
|
||||||
|
) -> Result<Option<Vec<String>>, ConfigError> {
|
||||||
|
match object.get(key) {
|
||||||
|
Some(value) => {
|
||||||
|
let Some(array) = value.as_array() else {
|
||||||
|
return Err(ConfigError::Parse(format!(
|
||||||
|
"{context}: field {key} must be an array"
|
||||||
|
)));
|
||||||
|
};
|
||||||
|
array
|
||||||
|
.iter()
|
||||||
|
.map(|item| {
|
||||||
|
item.as_str().map(ToOwned::to_owned).ok_or_else(|| {
|
||||||
|
ConfigError::Parse(format!(
|
||||||
|
"{context}: field {key} must contain only strings"
|
||||||
|
))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>, _>>()
|
||||||
|
.map(Some)
|
||||||
|
}
|
||||||
|
None => Ok(None),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn optional_string_map(
|
||||||
|
object: &BTreeMap<String, JsonValue>,
|
||||||
|
key: &str,
|
||||||
|
context: &str,
|
||||||
|
) -> Result<Option<BTreeMap<String, String>>, ConfigError> {
|
||||||
|
match object.get(key) {
|
||||||
|
Some(value) => {
|
||||||
|
let Some(map) = value.as_object() else {
|
||||||
|
return Err(ConfigError::Parse(format!(
|
||||||
|
"{context}: field {key} must be an object"
|
||||||
|
)));
|
||||||
|
};
|
||||||
|
map.iter()
|
||||||
|
.map(|(entry_key, entry_value)| {
|
||||||
|
entry_value
|
||||||
|
.as_str()
|
||||||
|
.map(|text| (entry_key.clone(), text.to_string()))
|
||||||
|
.ok_or_else(|| {
|
||||||
|
ConfigError::Parse(format!(
|
||||||
|
"{context}: field {key} must contain only string values"
|
||||||
|
))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect::<Result<BTreeMap<_, _>, _>>()
|
||||||
|
.map(Some)
|
||||||
|
}
|
||||||
|
None => Ok(None),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn deep_merge_objects(
|
fn deep_merge_objects(
|
||||||
target: &mut BTreeMap<String, JsonValue>,
|
target: &mut BTreeMap<String, JsonValue>,
|
||||||
source: &BTreeMap<String, JsonValue>,
|
source: &BTreeMap<String, JsonValue>,
|
||||||
@@ -183,7 +593,9 @@ fn deep_merge_objects(
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::{ConfigLoader, ConfigSource, CLAUDE_CODE_SETTINGS_SCHEMA_NAME};
|
use super::{
|
||||||
|
ConfigLoader, ConfigSource, McpServerConfig, McpTransport, CLAUDE_CODE_SETTINGS_SCHEMA_NAME,
|
||||||
|
};
|
||||||
use crate::json::JsonValue;
|
use crate::json::JsonValue;
|
||||||
use std::fs;
|
use std::fs;
|
||||||
use std::time::{SystemTime, UNIX_EPOCH};
|
use std::time::{SystemTime, UNIX_EPOCH};
|
||||||
@@ -266,4 +678,118 @@ mod tests {
|
|||||||
|
|
||||||
fs::remove_dir_all(root).expect("cleanup temp dir");
|
fs::remove_dir_all(root).expect("cleanup temp dir");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parses_typed_mcp_and_oauth_config() {
|
||||||
|
let root = temp_dir();
|
||||||
|
let cwd = root.join("project");
|
||||||
|
let home = root.join("home").join(".claude");
|
||||||
|
fs::create_dir_all(cwd.join(".claude")).expect("project config dir");
|
||||||
|
fs::create_dir_all(&home).expect("home config dir");
|
||||||
|
|
||||||
|
fs::write(
|
||||||
|
home.join("settings.json"),
|
||||||
|
r#"{
|
||||||
|
"mcpServers": {
|
||||||
|
"stdio-server": {
|
||||||
|
"command": "uvx",
|
||||||
|
"args": ["mcp-server"],
|
||||||
|
"env": {"TOKEN": "secret"}
|
||||||
|
},
|
||||||
|
"remote-server": {
|
||||||
|
"type": "http",
|
||||||
|
"url": "https://example.test/mcp",
|
||||||
|
"headers": {"Authorization": "Bearer token"},
|
||||||
|
"headersHelper": "helper.sh",
|
||||||
|
"oauth": {
|
||||||
|
"clientId": "mcp-client",
|
||||||
|
"callbackPort": 7777,
|
||||||
|
"authServerMetadataUrl": "https://issuer.test/.well-known/oauth-authorization-server",
|
||||||
|
"xaa": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"oauth": {
|
||||||
|
"clientId": "runtime-client",
|
||||||
|
"authorizeUrl": "https://console.test/oauth/authorize",
|
||||||
|
"tokenUrl": "https://console.test/oauth/token",
|
||||||
|
"callbackPort": 54545,
|
||||||
|
"manualRedirectUrl": "https://console.test/oauth/callback",
|
||||||
|
"scopes": ["org:read", "user:write"]
|
||||||
|
}
|
||||||
|
}"#,
|
||||||
|
)
|
||||||
|
.expect("write user settings");
|
||||||
|
fs::write(
|
||||||
|
cwd.join(".claude").join("settings.local.json"),
|
||||||
|
r#"{
|
||||||
|
"mcpServers": {
|
||||||
|
"remote-server": {
|
||||||
|
"type": "ws",
|
||||||
|
"url": "wss://override.test/mcp",
|
||||||
|
"headers": {"X-Env": "local"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}"#,
|
||||||
|
)
|
||||||
|
.expect("write local settings");
|
||||||
|
|
||||||
|
let loaded = ConfigLoader::new(&cwd, &home)
|
||||||
|
.load()
|
||||||
|
.expect("config should load");
|
||||||
|
|
||||||
|
let stdio_server = loaded
|
||||||
|
.mcp()
|
||||||
|
.get("stdio-server")
|
||||||
|
.expect("stdio server should exist");
|
||||||
|
assert_eq!(stdio_server.scope, ConfigSource::User);
|
||||||
|
assert_eq!(stdio_server.transport(), McpTransport::Stdio);
|
||||||
|
|
||||||
|
let remote_server = loaded
|
||||||
|
.mcp()
|
||||||
|
.get("remote-server")
|
||||||
|
.expect("remote server should exist");
|
||||||
|
assert_eq!(remote_server.scope, ConfigSource::Local);
|
||||||
|
assert_eq!(remote_server.transport(), McpTransport::Ws);
|
||||||
|
match &remote_server.config {
|
||||||
|
McpServerConfig::Ws(config) => {
|
||||||
|
assert_eq!(config.url, "wss://override.test/mcp");
|
||||||
|
assert_eq!(
|
||||||
|
config.headers.get("X-Env").map(String::as_str),
|
||||||
|
Some("local")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
other => panic!("expected ws config, got {other:?}"),
|
||||||
|
}
|
||||||
|
|
||||||
|
let oauth = loaded.oauth().expect("oauth config should exist");
|
||||||
|
assert_eq!(oauth.client_id, "runtime-client");
|
||||||
|
assert_eq!(oauth.callback_port, Some(54_545));
|
||||||
|
assert_eq!(oauth.scopes, vec!["org:read", "user:write"]);
|
||||||
|
|
||||||
|
fs::remove_dir_all(root).expect("cleanup temp dir");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn rejects_invalid_mcp_server_shapes() {
|
||||||
|
let root = temp_dir();
|
||||||
|
let cwd = root.join("project");
|
||||||
|
let home = root.join("home").join(".claude");
|
||||||
|
fs::create_dir_all(&home).expect("home config dir");
|
||||||
|
fs::create_dir_all(&cwd).expect("project dir");
|
||||||
|
fs::write(
|
||||||
|
home.join("settings.json"),
|
||||||
|
r#"{"mcpServers":{"broken":{"type":"http","url":123}}}"#,
|
||||||
|
)
|
||||||
|
.expect("write broken settings");
|
||||||
|
|
||||||
|
let error = ConfigLoader::new(&cwd, &home)
|
||||||
|
.load()
|
||||||
|
.expect_err("config should fail");
|
||||||
|
assert!(error
|
||||||
|
.to_string()
|
||||||
|
.contains("mcpServers.broken: missing string field url"));
|
||||||
|
|
||||||
|
fs::remove_dir_all(root).expect("cleanup temp dir");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -285,7 +285,7 @@ pub fn grep_search(input: &GrepSearchInput) -> io::Result<GrepSearchOutput> {
|
|||||||
.output_mode
|
.output_mode
|
||||||
.clone()
|
.clone()
|
||||||
.unwrap_or_else(|| String::from("files_with_matches"));
|
.unwrap_or_else(|| String::from("files_with_matches"));
|
||||||
let context_window = input.context.or(input.context_short).unwrap_or(0);
|
let context = input.context.or(input.context_short).unwrap_or(0);
|
||||||
|
|
||||||
let mut filenames = Vec::new();
|
let mut filenames = Vec::new();
|
||||||
let mut content_lines = Vec::new();
|
let mut content_lines = Vec::new();
|
||||||
@@ -325,8 +325,8 @@ pub fn grep_search(input: &GrepSearchInput) -> io::Result<GrepSearchOutput> {
|
|||||||
filenames.push(file_path.to_string_lossy().into_owned());
|
filenames.push(file_path.to_string_lossy().into_owned());
|
||||||
if output_mode == "content" {
|
if output_mode == "content" {
|
||||||
for index in matched_lines {
|
for index in matched_lines {
|
||||||
let start = index.saturating_sub(input.before.unwrap_or(context_window));
|
let start = index.saturating_sub(input.before.unwrap_or(context));
|
||||||
let end = (index + input.after.unwrap_or(context_window) + 1).min(lines.len());
|
let end = (index + input.after.unwrap_or(context) + 1).min(lines.len());
|
||||||
for (current, line_content) in lines.iter().enumerate().take(end).skip(start) {
|
for (current, line_content) in lines.iter().enumerate().take(end).skip(start) {
|
||||||
let prefix = if input.line_numbers.unwrap_or(true) {
|
let prefix = if input.line_numbers.unwrap_or(true) {
|
||||||
format!("{}:{}:", file_path.to_string_lossy(), current + 1)
|
format!("{}:{}:", file_path.to_string_lossy(), current + 1)
|
||||||
@@ -341,7 +341,7 @@ pub fn grep_search(input: &GrepSearchInput) -> io::Result<GrepSearchOutput> {
|
|||||||
|
|
||||||
let (filenames, applied_limit, applied_offset) =
|
let (filenames, applied_limit, applied_offset) =
|
||||||
apply_limit(filenames, input.head_limit, input.offset);
|
apply_limit(filenames, input.head_limit, input.offset);
|
||||||
let content = if output_mode == "content" {
|
let rendered_content = if output_mode == "content" {
|
||||||
let (lines, limit, offset) = apply_limit(content_lines, input.head_limit, input.offset);
|
let (lines, limit, offset) = apply_limit(content_lines, input.head_limit, input.offset);
|
||||||
return Ok(GrepSearchOutput {
|
return Ok(GrepSearchOutput {
|
||||||
mode: Some(output_mode),
|
mode: Some(output_mode),
|
||||||
@@ -361,7 +361,7 @@ pub fn grep_search(input: &GrepSearchInput) -> io::Result<GrepSearchOutput> {
|
|||||||
mode: Some(output_mode.clone()),
|
mode: Some(output_mode.clone()),
|
||||||
num_files: filenames.len(),
|
num_files: filenames.len(),
|
||||||
filenames,
|
filenames,
|
||||||
content,
|
content: rendered_content,
|
||||||
num_lines: None,
|
num_lines: None,
|
||||||
num_matches: (output_mode == "count").then_some(total_matches),
|
num_matches: (output_mode == "count").then_some(total_matches),
|
||||||
applied_limit,
|
applied_limit,
|
||||||
|
|||||||
@@ -5,8 +5,12 @@ mod config;
|
|||||||
mod conversation;
|
mod conversation;
|
||||||
mod file_ops;
|
mod file_ops;
|
||||||
mod json;
|
mod json;
|
||||||
|
mod mcp;
|
||||||
|
mod mcp_client;
|
||||||
|
mod oauth;
|
||||||
mod permissions;
|
mod permissions;
|
||||||
mod prompt;
|
mod prompt;
|
||||||
|
mod remote;
|
||||||
mod session;
|
mod session;
|
||||||
mod usage;
|
mod usage;
|
||||||
|
|
||||||
@@ -17,8 +21,10 @@ pub use compact::{
|
|||||||
get_compact_continuation_message, should_compact, CompactionConfig, CompactionResult,
|
get_compact_continuation_message, should_compact, CompactionConfig, CompactionResult,
|
||||||
};
|
};
|
||||||
pub use config::{
|
pub use config::{
|
||||||
ConfigEntry, ConfigError, ConfigLoader, ConfigSource, RuntimeConfig,
|
ConfigEntry, ConfigError, ConfigLoader, ConfigSource, McpClaudeAiProxyServerConfig,
|
||||||
CLAUDE_CODE_SETTINGS_SCHEMA_NAME,
|
McpConfigCollection, McpOAuthConfig, McpRemoteServerConfig, McpSdkServerConfig,
|
||||||
|
McpServerConfig, McpStdioServerConfig, McpTransport, McpWebSocketServerConfig, OAuthConfig,
|
||||||
|
RuntimeConfig, RuntimeFeatureConfig, ScopedMcpServerConfig, CLAUDE_CODE_SETTINGS_SCHEMA_NAME,
|
||||||
};
|
};
|
||||||
pub use conversation::{
|
pub use conversation::{
|
||||||
ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, RuntimeError, StaticToolExecutor,
|
ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, RuntimeError, StaticToolExecutor,
|
||||||
@@ -29,6 +35,19 @@ pub use file_ops::{
|
|||||||
GrepSearchInput, GrepSearchOutput, ReadFileOutput, StructuredPatchHunk, TextFilePayload,
|
GrepSearchInput, GrepSearchOutput, ReadFileOutput, StructuredPatchHunk, TextFilePayload,
|
||||||
WriteFileOutput,
|
WriteFileOutput,
|
||||||
};
|
};
|
||||||
|
pub use mcp::{
|
||||||
|
mcp_server_signature, mcp_tool_name, mcp_tool_prefix, normalize_name_for_mcp,
|
||||||
|
scoped_mcp_config_hash, unwrap_ccr_proxy_url,
|
||||||
|
};
|
||||||
|
pub use mcp_client::{
|
||||||
|
McpClaudeAiProxyTransport, McpClientAuth, McpClientBootstrap, McpClientTransport,
|
||||||
|
McpRemoteTransport, McpSdkTransport, McpStdioTransport,
|
||||||
|
};
|
||||||
|
pub use oauth::{
|
||||||
|
code_challenge_s256, generate_pkce_pair, generate_state, loopback_redirect_uri,
|
||||||
|
OAuthAuthorizationRequest, OAuthRefreshRequest, OAuthTokenExchangeRequest, OAuthTokenSet,
|
||||||
|
PkceChallengeMethod, PkceCodePair,
|
||||||
|
};
|
||||||
pub use permissions::{
|
pub use permissions::{
|
||||||
PermissionMode, PermissionOutcome, PermissionPolicy, PermissionPromptDecision,
|
PermissionMode, PermissionOutcome, PermissionPolicy, PermissionPromptDecision,
|
||||||
PermissionPrompter, PermissionRequest,
|
PermissionPrompter, PermissionRequest,
|
||||||
@@ -37,6 +56,11 @@ pub use prompt::{
|
|||||||
load_system_prompt, prepend_bullets, ContextFile, ProjectContext, PromptBuildError,
|
load_system_prompt, prepend_bullets, ContextFile, ProjectContext, PromptBuildError,
|
||||||
SystemPromptBuilder, FRONTIER_MODEL_NAME, SYSTEM_PROMPT_DYNAMIC_BOUNDARY,
|
SystemPromptBuilder, FRONTIER_MODEL_NAME, SYSTEM_PROMPT_DYNAMIC_BOUNDARY,
|
||||||
};
|
};
|
||||||
|
pub use remote::{
|
||||||
|
inherited_upstream_proxy_env, no_proxy_list, read_token, upstream_proxy_ws_url,
|
||||||
|
RemoteSessionContext, UpstreamProxyBootstrap, UpstreamProxyState, DEFAULT_REMOTE_BASE_URL,
|
||||||
|
DEFAULT_SESSION_TOKEN_PATH, DEFAULT_SYSTEM_CA_BUNDLE, NO_PROXY_HOSTS, UPSTREAM_PROXY_ENV_KEYS,
|
||||||
|
};
|
||||||
pub use session::{ContentBlock, ConversationMessage, MessageRole, Session, SessionError};
|
pub use session::{ContentBlock, ConversationMessage, MessageRole, Session, SessionError};
|
||||||
pub use usage::{
|
pub use usage::{
|
||||||
format_usd, pricing_for_model, ModelPricing, TokenUsage, UsageCostEstimate, UsageTracker,
|
format_usd, pricing_for_model, ModelPricing, TokenUsage, UsageCostEstimate, UsageTracker,
|
||||||
|
|||||||
300
rust/crates/runtime/src/mcp.rs
Normal file
300
rust/crates/runtime/src/mcp.rs
Normal file
@@ -0,0 +1,300 @@
|
|||||||
|
use crate::config::{McpServerConfig, ScopedMcpServerConfig};
|
||||||
|
|
||||||
|
const CLAUDEAI_SERVER_PREFIX: &str = "claude.ai ";
|
||||||
|
const CCR_PROXY_PATH_MARKERS: [&str; 2] = ["/v2/session_ingress/shttp/mcp/", "/v2/ccr-sessions/"];
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn normalize_name_for_mcp(name: &str) -> String {
|
||||||
|
let mut normalized = name
|
||||||
|
.chars()
|
||||||
|
.map(|ch| match ch {
|
||||||
|
'a'..='z' | 'A'..='Z' | '0'..='9' | '_' | '-' => ch,
|
||||||
|
_ => '_',
|
||||||
|
})
|
||||||
|
.collect::<String>();
|
||||||
|
|
||||||
|
if name.starts_with(CLAUDEAI_SERVER_PREFIX) {
|
||||||
|
normalized = collapse_underscores(&normalized)
|
||||||
|
.trim_matches('_')
|
||||||
|
.to_string();
|
||||||
|
}
|
||||||
|
|
||||||
|
normalized
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn mcp_tool_prefix(server_name: &str) -> String {
|
||||||
|
format!("mcp__{}__", normalize_name_for_mcp(server_name))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn mcp_tool_name(server_name: &str, tool_name: &str) -> String {
|
||||||
|
format!(
|
||||||
|
"{}{}",
|
||||||
|
mcp_tool_prefix(server_name),
|
||||||
|
normalize_name_for_mcp(tool_name)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn unwrap_ccr_proxy_url(url: &str) -> String {
|
||||||
|
if !CCR_PROXY_PATH_MARKERS
|
||||||
|
.iter()
|
||||||
|
.any(|marker| url.contains(marker))
|
||||||
|
{
|
||||||
|
return url.to_string();
|
||||||
|
}
|
||||||
|
|
||||||
|
let Some(query_start) = url.find('?') else {
|
||||||
|
return url.to_string();
|
||||||
|
};
|
||||||
|
let query = &url[query_start + 1..];
|
||||||
|
for pair in query.split('&') {
|
||||||
|
let mut parts = pair.splitn(2, '=');
|
||||||
|
if matches!(parts.next(), Some("mcp_url")) {
|
||||||
|
if let Some(value) = parts.next() {
|
||||||
|
return percent_decode(value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
url.to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn mcp_server_signature(config: &McpServerConfig) -> Option<String> {
|
||||||
|
match config {
|
||||||
|
McpServerConfig::Stdio(config) => {
|
||||||
|
let mut command = vec![config.command.clone()];
|
||||||
|
command.extend(config.args.clone());
|
||||||
|
Some(format!("stdio:{}", render_command_signature(&command)))
|
||||||
|
}
|
||||||
|
McpServerConfig::Sse(config) | McpServerConfig::Http(config) => {
|
||||||
|
Some(format!("url:{}", unwrap_ccr_proxy_url(&config.url)))
|
||||||
|
}
|
||||||
|
McpServerConfig::Ws(config) => Some(format!("url:{}", unwrap_ccr_proxy_url(&config.url))),
|
||||||
|
McpServerConfig::ClaudeAiProxy(config) => {
|
||||||
|
Some(format!("url:{}", unwrap_ccr_proxy_url(&config.url)))
|
||||||
|
}
|
||||||
|
McpServerConfig::Sdk(_) => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn scoped_mcp_config_hash(config: &ScopedMcpServerConfig) -> String {
|
||||||
|
let rendered = match &config.config {
|
||||||
|
McpServerConfig::Stdio(stdio) => format!(
|
||||||
|
"stdio|{}|{}|{}",
|
||||||
|
stdio.command,
|
||||||
|
render_command_signature(&stdio.args),
|
||||||
|
render_env_signature(&stdio.env)
|
||||||
|
),
|
||||||
|
McpServerConfig::Sse(remote) => format!(
|
||||||
|
"sse|{}|{}|{}|{}",
|
||||||
|
remote.url,
|
||||||
|
render_env_signature(&remote.headers),
|
||||||
|
remote.headers_helper.as_deref().unwrap_or(""),
|
||||||
|
render_oauth_signature(remote.oauth.as_ref())
|
||||||
|
),
|
||||||
|
McpServerConfig::Http(remote) => format!(
|
||||||
|
"http|{}|{}|{}|{}",
|
||||||
|
remote.url,
|
||||||
|
render_env_signature(&remote.headers),
|
||||||
|
remote.headers_helper.as_deref().unwrap_or(""),
|
||||||
|
render_oauth_signature(remote.oauth.as_ref())
|
||||||
|
),
|
||||||
|
McpServerConfig::Ws(ws) => format!(
|
||||||
|
"ws|{}|{}|{}",
|
||||||
|
ws.url,
|
||||||
|
render_env_signature(&ws.headers),
|
||||||
|
ws.headers_helper.as_deref().unwrap_or("")
|
||||||
|
),
|
||||||
|
McpServerConfig::Sdk(sdk) => format!("sdk|{}", sdk.name),
|
||||||
|
McpServerConfig::ClaudeAiProxy(proxy) => {
|
||||||
|
format!("claudeai-proxy|{}|{}", proxy.url, proxy.id)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
stable_hex_hash(&rendered)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn render_command_signature(command: &[String]) -> String {
|
||||||
|
let escaped = command
|
||||||
|
.iter()
|
||||||
|
.map(|part| part.replace('\\', "\\\\").replace('|', "\\|"))
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
format!("[{}]", escaped.join("|"))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn render_env_signature(map: &std::collections::BTreeMap<String, String>) -> String {
|
||||||
|
map.iter()
|
||||||
|
.map(|(key, value)| format!("{key}={value}"))
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(";")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn render_oauth_signature(oauth: Option<&crate::config::McpOAuthConfig>) -> String {
|
||||||
|
oauth.map_or_else(String::new, |oauth| {
|
||||||
|
format!(
|
||||||
|
"{}|{}|{}|{}",
|
||||||
|
oauth.client_id.as_deref().unwrap_or(""),
|
||||||
|
oauth
|
||||||
|
.callback_port
|
||||||
|
.map_or_else(String::new, |port| port.to_string()),
|
||||||
|
oauth.auth_server_metadata_url.as_deref().unwrap_or(""),
|
||||||
|
oauth.xaa.map_or_else(String::new, |flag| flag.to_string())
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn stable_hex_hash(value: &str) -> String {
|
||||||
|
let mut hash = 0xcbf2_9ce4_8422_2325_u64;
|
||||||
|
for byte in value.as_bytes() {
|
||||||
|
hash ^= u64::from(*byte);
|
||||||
|
hash = hash.wrapping_mul(0x0100_0000_01b3);
|
||||||
|
}
|
||||||
|
format!("{hash:016x}")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn collapse_underscores(value: &str) -> String {
|
||||||
|
let mut collapsed = String::with_capacity(value.len());
|
||||||
|
let mut last_was_underscore = false;
|
||||||
|
for ch in value.chars() {
|
||||||
|
if ch == '_' {
|
||||||
|
if !last_was_underscore {
|
||||||
|
collapsed.push(ch);
|
||||||
|
}
|
||||||
|
last_was_underscore = true;
|
||||||
|
} else {
|
||||||
|
collapsed.push(ch);
|
||||||
|
last_was_underscore = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
collapsed
|
||||||
|
}
|
||||||
|
|
||||||
|
fn percent_decode(value: &str) -> String {
|
||||||
|
let bytes = value.as_bytes();
|
||||||
|
let mut decoded = Vec::with_capacity(bytes.len());
|
||||||
|
let mut index = 0;
|
||||||
|
while index < bytes.len() {
|
||||||
|
match bytes[index] {
|
||||||
|
b'%' if index + 2 < bytes.len() => {
|
||||||
|
let hex = &value[index + 1..index + 3];
|
||||||
|
if let Ok(byte) = u8::from_str_radix(hex, 16) {
|
||||||
|
decoded.push(byte);
|
||||||
|
index += 3;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
decoded.push(bytes[index]);
|
||||||
|
index += 1;
|
||||||
|
}
|
||||||
|
b'+' => {
|
||||||
|
decoded.push(b' ');
|
||||||
|
index += 1;
|
||||||
|
}
|
||||||
|
byte => {
|
||||||
|
decoded.push(byte);
|
||||||
|
index += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
String::from_utf8_lossy(&decoded).into_owned()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use std::collections::BTreeMap;
|
||||||
|
|
||||||
|
use crate::config::{
|
||||||
|
ConfigSource, McpRemoteServerConfig, McpServerConfig, McpStdioServerConfig,
|
||||||
|
McpWebSocketServerConfig, ScopedMcpServerConfig,
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::{
|
||||||
|
mcp_server_signature, mcp_tool_name, normalize_name_for_mcp, scoped_mcp_config_hash,
|
||||||
|
unwrap_ccr_proxy_url,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn normalizes_server_names_for_mcp_tooling() {
|
||||||
|
assert_eq!(normalize_name_for_mcp("github.com"), "github_com");
|
||||||
|
assert_eq!(normalize_name_for_mcp("tool name!"), "tool_name_");
|
||||||
|
assert_eq!(
|
||||||
|
normalize_name_for_mcp("claude.ai Example Server!!"),
|
||||||
|
"claude_ai_Example_Server"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
mcp_tool_name("claude.ai Example Server", "weather tool"),
|
||||||
|
"mcp__claude_ai_Example_Server__weather_tool"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn unwraps_ccr_proxy_urls_for_signature_matching() {
|
||||||
|
let wrapped = "https://api.anthropic.com/v2/session_ingress/shttp/mcp/123?mcp_url=https%3A%2F%2Fvendor.example%2Fmcp&other=1";
|
||||||
|
assert_eq!(unwrap_ccr_proxy_url(wrapped), "https://vendor.example/mcp");
|
||||||
|
assert_eq!(
|
||||||
|
unwrap_ccr_proxy_url("https://vendor.example/mcp"),
|
||||||
|
"https://vendor.example/mcp"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn computes_signatures_for_stdio_and_remote_servers() {
|
||||||
|
let stdio = McpServerConfig::Stdio(McpStdioServerConfig {
|
||||||
|
command: "uvx".to_string(),
|
||||||
|
args: vec!["mcp-server".to_string()],
|
||||||
|
env: BTreeMap::from([("TOKEN".to_string(), "secret".to_string())]),
|
||||||
|
});
|
||||||
|
assert_eq!(
|
||||||
|
mcp_server_signature(&stdio),
|
||||||
|
Some("stdio:[uvx|mcp-server]".to_string())
|
||||||
|
);
|
||||||
|
|
||||||
|
let remote = McpServerConfig::Ws(McpWebSocketServerConfig {
|
||||||
|
url: "https://api.anthropic.com/v2/ccr-sessions/1?mcp_url=wss%3A%2F%2Fvendor.example%2Fmcp".to_string(),
|
||||||
|
headers: BTreeMap::new(),
|
||||||
|
headers_helper: None,
|
||||||
|
});
|
||||||
|
assert_eq!(
|
||||||
|
mcp_server_signature(&remote),
|
||||||
|
Some("url:wss://vendor.example/mcp".to_string())
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn scoped_hash_ignores_scope_but_tracks_config_content() {
|
||||||
|
let base_config = McpServerConfig::Http(McpRemoteServerConfig {
|
||||||
|
url: "https://vendor.example/mcp".to_string(),
|
||||||
|
headers: BTreeMap::from([("Authorization".to_string(), "Bearer token".to_string())]),
|
||||||
|
headers_helper: Some("helper.sh".to_string()),
|
||||||
|
oauth: None,
|
||||||
|
});
|
||||||
|
let user = ScopedMcpServerConfig {
|
||||||
|
scope: ConfigSource::User,
|
||||||
|
config: base_config.clone(),
|
||||||
|
};
|
||||||
|
let local = ScopedMcpServerConfig {
|
||||||
|
scope: ConfigSource::Local,
|
||||||
|
config: base_config,
|
||||||
|
};
|
||||||
|
assert_eq!(
|
||||||
|
scoped_mcp_config_hash(&user),
|
||||||
|
scoped_mcp_config_hash(&local)
|
||||||
|
);
|
||||||
|
|
||||||
|
let changed = ScopedMcpServerConfig {
|
||||||
|
scope: ConfigSource::Local,
|
||||||
|
config: McpServerConfig::Http(McpRemoteServerConfig {
|
||||||
|
url: "https://vendor.example/v2/mcp".to_string(),
|
||||||
|
headers: BTreeMap::new(),
|
||||||
|
headers_helper: None,
|
||||||
|
oauth: None,
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
assert_ne!(
|
||||||
|
scoped_mcp_config_hash(&user),
|
||||||
|
scoped_mcp_config_hash(&changed)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
236
rust/crates/runtime/src/mcp_client.rs
Normal file
236
rust/crates/runtime/src/mcp_client.rs
Normal file
@@ -0,0 +1,236 @@
|
|||||||
|
use std::collections::BTreeMap;
|
||||||
|
|
||||||
|
use crate::config::{McpOAuthConfig, McpServerConfig, ScopedMcpServerConfig};
|
||||||
|
use crate::mcp::{mcp_server_signature, mcp_tool_prefix, normalize_name_for_mcp};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub enum McpClientTransport {
|
||||||
|
Stdio(McpStdioTransport),
|
||||||
|
Sse(McpRemoteTransport),
|
||||||
|
Http(McpRemoteTransport),
|
||||||
|
WebSocket(McpRemoteTransport),
|
||||||
|
Sdk(McpSdkTransport),
|
||||||
|
ClaudeAiProxy(McpClaudeAiProxyTransport),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct McpStdioTransport {
|
||||||
|
pub command: String,
|
||||||
|
pub args: Vec<String>,
|
||||||
|
pub env: BTreeMap<String, String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct McpRemoteTransport {
|
||||||
|
pub url: String,
|
||||||
|
pub headers: BTreeMap<String, String>,
|
||||||
|
pub headers_helper: Option<String>,
|
||||||
|
pub auth: McpClientAuth,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct McpSdkTransport {
|
||||||
|
pub name: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct McpClaudeAiProxyTransport {
|
||||||
|
pub url: String,
|
||||||
|
pub id: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub enum McpClientAuth {
|
||||||
|
None,
|
||||||
|
OAuth(McpOAuthConfig),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct McpClientBootstrap {
|
||||||
|
pub server_name: String,
|
||||||
|
pub normalized_name: String,
|
||||||
|
pub tool_prefix: String,
|
||||||
|
pub signature: Option<String>,
|
||||||
|
pub transport: McpClientTransport,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl McpClientBootstrap {
|
||||||
|
#[must_use]
|
||||||
|
pub fn from_scoped_config(server_name: &str, config: &ScopedMcpServerConfig) -> Self {
|
||||||
|
Self {
|
||||||
|
server_name: server_name.to_string(),
|
||||||
|
normalized_name: normalize_name_for_mcp(server_name),
|
||||||
|
tool_prefix: mcp_tool_prefix(server_name),
|
||||||
|
signature: mcp_server_signature(&config.config),
|
||||||
|
transport: McpClientTransport::from_config(&config.config),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl McpClientTransport {
|
||||||
|
#[must_use]
|
||||||
|
pub fn from_config(config: &McpServerConfig) -> Self {
|
||||||
|
match config {
|
||||||
|
McpServerConfig::Stdio(config) => Self::Stdio(McpStdioTransport {
|
||||||
|
command: config.command.clone(),
|
||||||
|
args: config.args.clone(),
|
||||||
|
env: config.env.clone(),
|
||||||
|
}),
|
||||||
|
McpServerConfig::Sse(config) => Self::Sse(McpRemoteTransport {
|
||||||
|
url: config.url.clone(),
|
||||||
|
headers: config.headers.clone(),
|
||||||
|
headers_helper: config.headers_helper.clone(),
|
||||||
|
auth: McpClientAuth::from_oauth(config.oauth.clone()),
|
||||||
|
}),
|
||||||
|
McpServerConfig::Http(config) => Self::Http(McpRemoteTransport {
|
||||||
|
url: config.url.clone(),
|
||||||
|
headers: config.headers.clone(),
|
||||||
|
headers_helper: config.headers_helper.clone(),
|
||||||
|
auth: McpClientAuth::from_oauth(config.oauth.clone()),
|
||||||
|
}),
|
||||||
|
McpServerConfig::Ws(config) => Self::WebSocket(McpRemoteTransport {
|
||||||
|
url: config.url.clone(),
|
||||||
|
headers: config.headers.clone(),
|
||||||
|
headers_helper: config.headers_helper.clone(),
|
||||||
|
auth: McpClientAuth::None,
|
||||||
|
}),
|
||||||
|
McpServerConfig::Sdk(config) => Self::Sdk(McpSdkTransport {
|
||||||
|
name: config.name.clone(),
|
||||||
|
}),
|
||||||
|
McpServerConfig::ClaudeAiProxy(config) => {
|
||||||
|
Self::ClaudeAiProxy(McpClaudeAiProxyTransport {
|
||||||
|
url: config.url.clone(),
|
||||||
|
id: config.id.clone(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl McpClientAuth {
|
||||||
|
#[must_use]
|
||||||
|
pub fn from_oauth(oauth: Option<McpOAuthConfig>) -> Self {
|
||||||
|
oauth.map_or(Self::None, Self::OAuth)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub const fn requires_user_auth(&self) -> bool {
|
||||||
|
matches!(self, Self::OAuth(_))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use std::collections::BTreeMap;
|
||||||
|
|
||||||
|
use crate::config::{
|
||||||
|
ConfigSource, McpOAuthConfig, McpRemoteServerConfig, McpSdkServerConfig, McpServerConfig,
|
||||||
|
McpStdioServerConfig, McpWebSocketServerConfig, ScopedMcpServerConfig,
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::{McpClientAuth, McpClientBootstrap, McpClientTransport};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn bootstraps_stdio_servers_into_transport_targets() {
|
||||||
|
let config = ScopedMcpServerConfig {
|
||||||
|
scope: ConfigSource::User,
|
||||||
|
config: McpServerConfig::Stdio(McpStdioServerConfig {
|
||||||
|
command: "uvx".to_string(),
|
||||||
|
args: vec!["mcp-server".to_string()],
|
||||||
|
env: BTreeMap::from([("TOKEN".to_string(), "secret".to_string())]),
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
|
||||||
|
let bootstrap = McpClientBootstrap::from_scoped_config("stdio-server", &config);
|
||||||
|
assert_eq!(bootstrap.normalized_name, "stdio-server");
|
||||||
|
assert_eq!(bootstrap.tool_prefix, "mcp__stdio-server__");
|
||||||
|
assert_eq!(
|
||||||
|
bootstrap.signature.as_deref(),
|
||||||
|
Some("stdio:[uvx|mcp-server]")
|
||||||
|
);
|
||||||
|
match bootstrap.transport {
|
||||||
|
McpClientTransport::Stdio(transport) => {
|
||||||
|
assert_eq!(transport.command, "uvx");
|
||||||
|
assert_eq!(transport.args, vec!["mcp-server"]);
|
||||||
|
assert_eq!(
|
||||||
|
transport.env.get("TOKEN").map(String::as_str),
|
||||||
|
Some("secret")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
other => panic!("expected stdio transport, got {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn bootstraps_remote_servers_with_oauth_auth() {
|
||||||
|
let config = ScopedMcpServerConfig {
|
||||||
|
scope: ConfigSource::Project,
|
||||||
|
config: McpServerConfig::Http(McpRemoteServerConfig {
|
||||||
|
url: "https://vendor.example/mcp".to_string(),
|
||||||
|
headers: BTreeMap::from([("X-Test".to_string(), "1".to_string())]),
|
||||||
|
headers_helper: Some("helper.sh".to_string()),
|
||||||
|
oauth: Some(McpOAuthConfig {
|
||||||
|
client_id: Some("client-id".to_string()),
|
||||||
|
callback_port: Some(7777),
|
||||||
|
auth_server_metadata_url: Some(
|
||||||
|
"https://issuer.example/.well-known/oauth-authorization-server".to_string(),
|
||||||
|
),
|
||||||
|
xaa: Some(true),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
|
||||||
|
let bootstrap = McpClientBootstrap::from_scoped_config("remote server", &config);
|
||||||
|
assert_eq!(bootstrap.normalized_name, "remote_server");
|
||||||
|
match bootstrap.transport {
|
||||||
|
McpClientTransport::Http(transport) => {
|
||||||
|
assert_eq!(transport.url, "https://vendor.example/mcp");
|
||||||
|
assert_eq!(transport.headers_helper.as_deref(), Some("helper.sh"));
|
||||||
|
assert!(transport.auth.requires_user_auth());
|
||||||
|
match transport.auth {
|
||||||
|
McpClientAuth::OAuth(oauth) => {
|
||||||
|
assert_eq!(oauth.client_id.as_deref(), Some("client-id"));
|
||||||
|
}
|
||||||
|
other @ McpClientAuth::None => panic!("expected oauth auth, got {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
other => panic!("expected http transport, got {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn bootstraps_websocket_and_sdk_transports_without_oauth() {
|
||||||
|
let ws = ScopedMcpServerConfig {
|
||||||
|
scope: ConfigSource::Local,
|
||||||
|
config: McpServerConfig::Ws(McpWebSocketServerConfig {
|
||||||
|
url: "wss://vendor.example/mcp".to_string(),
|
||||||
|
headers: BTreeMap::new(),
|
||||||
|
headers_helper: None,
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
let sdk = ScopedMcpServerConfig {
|
||||||
|
scope: ConfigSource::Local,
|
||||||
|
config: McpServerConfig::Sdk(McpSdkServerConfig {
|
||||||
|
name: "sdk-server".to_string(),
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
|
||||||
|
let ws_bootstrap = McpClientBootstrap::from_scoped_config("ws server", &ws);
|
||||||
|
match ws_bootstrap.transport {
|
||||||
|
McpClientTransport::WebSocket(transport) => {
|
||||||
|
assert_eq!(transport.url, "wss://vendor.example/mcp");
|
||||||
|
assert!(!transport.auth.requires_user_auth());
|
||||||
|
}
|
||||||
|
other => panic!("expected websocket transport, got {other:?}"),
|
||||||
|
}
|
||||||
|
|
||||||
|
let sdk_bootstrap = McpClientBootstrap::from_scoped_config("sdk server", &sdk);
|
||||||
|
assert_eq!(sdk_bootstrap.signature, None);
|
||||||
|
match sdk_bootstrap.transport {
|
||||||
|
McpClientTransport::Sdk(transport) => {
|
||||||
|
assert_eq!(transport.name, "sdk-server");
|
||||||
|
}
|
||||||
|
other => panic!("expected sdk transport, got {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
338
rust/crates/runtime/src/oauth.rs
Normal file
338
rust/crates/runtime/src/oauth.rs
Normal file
@@ -0,0 +1,338 @@
|
|||||||
|
use std::collections::BTreeMap;
|
||||||
|
use std::fs::File;
|
||||||
|
use std::io::{self, Read};
|
||||||
|
|
||||||
|
use sha2::{Digest, Sha256};
|
||||||
|
|
||||||
|
use crate::config::OAuthConfig;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct OAuthTokenSet {
|
||||||
|
pub access_token: String,
|
||||||
|
pub refresh_token: Option<String>,
|
||||||
|
pub expires_at: Option<u64>,
|
||||||
|
pub scopes: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct PkceCodePair {
|
||||||
|
pub verifier: String,
|
||||||
|
pub challenge: String,
|
||||||
|
pub challenge_method: PkceChallengeMethod,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub enum PkceChallengeMethod {
|
||||||
|
S256,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PkceChallengeMethod {
|
||||||
|
#[must_use]
|
||||||
|
pub const fn as_str(self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
Self::S256 => "S256",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct OAuthAuthorizationRequest {
|
||||||
|
pub authorize_url: String,
|
||||||
|
pub client_id: String,
|
||||||
|
pub redirect_uri: String,
|
||||||
|
pub scopes: Vec<String>,
|
||||||
|
pub state: String,
|
||||||
|
pub code_challenge: String,
|
||||||
|
pub code_challenge_method: PkceChallengeMethod,
|
||||||
|
pub extra_params: BTreeMap<String, String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct OAuthTokenExchangeRequest {
|
||||||
|
pub grant_type: &'static str,
|
||||||
|
pub code: String,
|
||||||
|
pub redirect_uri: String,
|
||||||
|
pub client_id: String,
|
||||||
|
pub code_verifier: String,
|
||||||
|
pub state: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct OAuthRefreshRequest {
|
||||||
|
pub grant_type: &'static str,
|
||||||
|
pub refresh_token: String,
|
||||||
|
pub client_id: String,
|
||||||
|
pub scopes: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OAuthAuthorizationRequest {
|
||||||
|
#[must_use]
|
||||||
|
pub fn from_config(
|
||||||
|
config: &OAuthConfig,
|
||||||
|
redirect_uri: impl Into<String>,
|
||||||
|
state: impl Into<String>,
|
||||||
|
pkce: &PkceCodePair,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
authorize_url: config.authorize_url.clone(),
|
||||||
|
client_id: config.client_id.clone(),
|
||||||
|
redirect_uri: redirect_uri.into(),
|
||||||
|
scopes: config.scopes.clone(),
|
||||||
|
state: state.into(),
|
||||||
|
code_challenge: pkce.challenge.clone(),
|
||||||
|
code_challenge_method: pkce.challenge_method,
|
||||||
|
extra_params: BTreeMap::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn with_extra_param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
|
||||||
|
self.extra_params.insert(key.into(), value.into());
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn build_url(&self) -> String {
|
||||||
|
let mut params = vec![
|
||||||
|
("response_type", "code".to_string()),
|
||||||
|
("client_id", self.client_id.clone()),
|
||||||
|
("redirect_uri", self.redirect_uri.clone()),
|
||||||
|
("scope", self.scopes.join(" ")),
|
||||||
|
("state", self.state.clone()),
|
||||||
|
("code_challenge", self.code_challenge.clone()),
|
||||||
|
(
|
||||||
|
"code_challenge_method",
|
||||||
|
self.code_challenge_method.as_str().to_string(),
|
||||||
|
),
|
||||||
|
];
|
||||||
|
params.extend(
|
||||||
|
self.extra_params
|
||||||
|
.iter()
|
||||||
|
.map(|(key, value)| (key.as_str(), value.clone())),
|
||||||
|
);
|
||||||
|
let query = params
|
||||||
|
.into_iter()
|
||||||
|
.map(|(key, value)| format!("{}={}", percent_encode(key), percent_encode(&value)))
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join("&");
|
||||||
|
format!(
|
||||||
|
"{}{}{}",
|
||||||
|
self.authorize_url,
|
||||||
|
if self.authorize_url.contains('?') {
|
||||||
|
'&'
|
||||||
|
} else {
|
||||||
|
'?'
|
||||||
|
},
|
||||||
|
query
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OAuthTokenExchangeRequest {
|
||||||
|
#[must_use]
|
||||||
|
pub fn from_config(
|
||||||
|
config: &OAuthConfig,
|
||||||
|
code: impl Into<String>,
|
||||||
|
state: impl Into<String>,
|
||||||
|
verifier: impl Into<String>,
|
||||||
|
redirect_uri: impl Into<String>,
|
||||||
|
) -> Self {
|
||||||
|
let _ = config;
|
||||||
|
Self {
|
||||||
|
grant_type: "authorization_code",
|
||||||
|
code: code.into(),
|
||||||
|
redirect_uri: redirect_uri.into(),
|
||||||
|
client_id: config.client_id.clone(),
|
||||||
|
code_verifier: verifier.into(),
|
||||||
|
state: state.into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn form_params(&self) -> BTreeMap<&str, String> {
|
||||||
|
BTreeMap::from([
|
||||||
|
("grant_type", self.grant_type.to_string()),
|
||||||
|
("code", self.code.clone()),
|
||||||
|
("redirect_uri", self.redirect_uri.clone()),
|
||||||
|
("client_id", self.client_id.clone()),
|
||||||
|
("code_verifier", self.code_verifier.clone()),
|
||||||
|
("state", self.state.clone()),
|
||||||
|
])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OAuthRefreshRequest {
|
||||||
|
#[must_use]
|
||||||
|
pub fn from_config(
|
||||||
|
config: &OAuthConfig,
|
||||||
|
refresh_token: impl Into<String>,
|
||||||
|
scopes: Option<Vec<String>>,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
grant_type: "refresh_token",
|
||||||
|
refresh_token: refresh_token.into(),
|
||||||
|
client_id: config.client_id.clone(),
|
||||||
|
scopes: scopes.unwrap_or_else(|| config.scopes.clone()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn form_params(&self) -> BTreeMap<&str, String> {
|
||||||
|
BTreeMap::from([
|
||||||
|
("grant_type", self.grant_type.to_string()),
|
||||||
|
("refresh_token", self.refresh_token.clone()),
|
||||||
|
("client_id", self.client_id.clone()),
|
||||||
|
("scope", self.scopes.join(" ")),
|
||||||
|
])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn generate_pkce_pair() -> io::Result<PkceCodePair> {
|
||||||
|
let verifier = generate_random_token(32)?;
|
||||||
|
Ok(PkceCodePair {
|
||||||
|
challenge: code_challenge_s256(&verifier),
|
||||||
|
verifier,
|
||||||
|
challenge_method: PkceChallengeMethod::S256,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn generate_state() -> io::Result<String> {
|
||||||
|
generate_random_token(32)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn code_challenge_s256(verifier: &str) -> String {
|
||||||
|
let digest = Sha256::digest(verifier.as_bytes());
|
||||||
|
base64url_encode(&digest)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn loopback_redirect_uri(port: u16) -> String {
|
||||||
|
format!("http://localhost:{port}/callback")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn generate_random_token(bytes: usize) -> io::Result<String> {
|
||||||
|
let mut buffer = vec![0_u8; bytes];
|
||||||
|
File::open("/dev/urandom")?.read_exact(&mut buffer)?;
|
||||||
|
Ok(base64url_encode(&buffer))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn base64url_encode(bytes: &[u8]) -> String {
|
||||||
|
const TABLE: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
|
||||||
|
let mut output = String::new();
|
||||||
|
let mut index = 0;
|
||||||
|
while index + 3 <= bytes.len() {
|
||||||
|
let block = (u32::from(bytes[index]) << 16)
|
||||||
|
| (u32::from(bytes[index + 1]) << 8)
|
||||||
|
| u32::from(bytes[index + 2]);
|
||||||
|
output.push(TABLE[((block >> 18) & 0x3F) as usize] as char);
|
||||||
|
output.push(TABLE[((block >> 12) & 0x3F) as usize] as char);
|
||||||
|
output.push(TABLE[((block >> 6) & 0x3F) as usize] as char);
|
||||||
|
output.push(TABLE[(block & 0x3F) as usize] as char);
|
||||||
|
index += 3;
|
||||||
|
}
|
||||||
|
match bytes.len().saturating_sub(index) {
|
||||||
|
1 => {
|
||||||
|
let block = u32::from(bytes[index]) << 16;
|
||||||
|
output.push(TABLE[((block >> 18) & 0x3F) as usize] as char);
|
||||||
|
output.push(TABLE[((block >> 12) & 0x3F) as usize] as char);
|
||||||
|
}
|
||||||
|
2 => {
|
||||||
|
let block = (u32::from(bytes[index]) << 16) | (u32::from(bytes[index + 1]) << 8);
|
||||||
|
output.push(TABLE[((block >> 18) & 0x3F) as usize] as char);
|
||||||
|
output.push(TABLE[((block >> 12) & 0x3F) as usize] as char);
|
||||||
|
output.push(TABLE[((block >> 6) & 0x3F) as usize] as char);
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
output
|
||||||
|
}
|
||||||
|
|
||||||
|
fn percent_encode(value: &str) -> String {
|
||||||
|
let mut encoded = String::new();
|
||||||
|
for byte in value.bytes() {
|
||||||
|
match byte {
|
||||||
|
b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' => {
|
||||||
|
encoded.push(char::from(byte));
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
use std::fmt::Write as _;
|
||||||
|
let _ = write!(&mut encoded, "%{byte:02X}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
encoded
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::{
|
||||||
|
code_challenge_s256, generate_pkce_pair, generate_state, loopback_redirect_uri,
|
||||||
|
OAuthAuthorizationRequest, OAuthConfig, OAuthRefreshRequest, OAuthTokenExchangeRequest,
|
||||||
|
};
|
||||||
|
|
||||||
|
fn sample_config() -> OAuthConfig {
|
||||||
|
OAuthConfig {
|
||||||
|
client_id: "runtime-client".to_string(),
|
||||||
|
authorize_url: "https://console.test/oauth/authorize".to_string(),
|
||||||
|
token_url: "https://console.test/oauth/token".to_string(),
|
||||||
|
callback_port: Some(4545),
|
||||||
|
manual_redirect_url: Some("https://console.test/oauth/callback".to_string()),
|
||||||
|
scopes: vec!["org:read".to_string(), "user:write".to_string()],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn s256_challenge_matches_expected_vector() {
|
||||||
|
assert_eq!(
|
||||||
|
code_challenge_s256("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"),
|
||||||
|
"E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn generates_pkce_pair_and_state() {
|
||||||
|
let pair = generate_pkce_pair().expect("pkce pair");
|
||||||
|
let state = generate_state().expect("state");
|
||||||
|
assert!(!pair.verifier.is_empty());
|
||||||
|
assert!(!pair.challenge.is_empty());
|
||||||
|
assert!(!state.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn builds_authorize_url_and_form_requests() {
|
||||||
|
let config = sample_config();
|
||||||
|
let pair = generate_pkce_pair().expect("pkce");
|
||||||
|
let url = OAuthAuthorizationRequest::from_config(
|
||||||
|
&config,
|
||||||
|
loopback_redirect_uri(4545),
|
||||||
|
"state-123",
|
||||||
|
&pair,
|
||||||
|
)
|
||||||
|
.with_extra_param("login_hint", "user@example.com")
|
||||||
|
.build_url();
|
||||||
|
assert!(url.starts_with("https://console.test/oauth/authorize?"));
|
||||||
|
assert!(url.contains("response_type=code"));
|
||||||
|
assert!(url.contains("client_id=runtime-client"));
|
||||||
|
assert!(url.contains("scope=org%3Aread%20user%3Awrite"));
|
||||||
|
assert!(url.contains("login_hint=user%40example.com"));
|
||||||
|
|
||||||
|
let exchange = OAuthTokenExchangeRequest::from_config(
|
||||||
|
&config,
|
||||||
|
"auth-code",
|
||||||
|
"state-123",
|
||||||
|
pair.verifier,
|
||||||
|
loopback_redirect_uri(4545),
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
exchange.form_params().get("grant_type").map(String::as_str),
|
||||||
|
Some("authorization_code")
|
||||||
|
);
|
||||||
|
|
||||||
|
let refresh = OAuthRefreshRequest::from_config(&config, "refresh-token", None);
|
||||||
|
assert_eq!(
|
||||||
|
refresh.form_params().get("scope").map(String::as_str),
|
||||||
|
Some("org:read user:write")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
401
rust/crates/runtime/src/remote.rs
Normal file
401
rust/crates/runtime/src/remote.rs
Normal file
@@ -0,0 +1,401 @@
|
|||||||
|
use std::collections::BTreeMap;
|
||||||
|
use std::env;
|
||||||
|
use std::fs;
|
||||||
|
use std::io;
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
|
||||||
|
pub const DEFAULT_REMOTE_BASE_URL: &str = "https://api.anthropic.com";
|
||||||
|
pub const DEFAULT_SESSION_TOKEN_PATH: &str = "/run/ccr/session_token";
|
||||||
|
pub const DEFAULT_SYSTEM_CA_BUNDLE: &str = "/etc/ssl/certs/ca-certificates.crt";
|
||||||
|
|
||||||
|
pub const UPSTREAM_PROXY_ENV_KEYS: [&str; 8] = [
|
||||||
|
"HTTPS_PROXY",
|
||||||
|
"https_proxy",
|
||||||
|
"NO_PROXY",
|
||||||
|
"no_proxy",
|
||||||
|
"SSL_CERT_FILE",
|
||||||
|
"NODE_EXTRA_CA_CERTS",
|
||||||
|
"REQUESTS_CA_BUNDLE",
|
||||||
|
"CURL_CA_BUNDLE",
|
||||||
|
];
|
||||||
|
|
||||||
|
pub const NO_PROXY_HOSTS: [&str; 16] = [
|
||||||
|
"localhost",
|
||||||
|
"127.0.0.1",
|
||||||
|
"::1",
|
||||||
|
"169.254.0.0/16",
|
||||||
|
"10.0.0.0/8",
|
||||||
|
"172.16.0.0/12",
|
||||||
|
"192.168.0.0/16",
|
||||||
|
"anthropic.com",
|
||||||
|
".anthropic.com",
|
||||||
|
"*.anthropic.com",
|
||||||
|
"github.com",
|
||||||
|
"api.github.com",
|
||||||
|
"*.github.com",
|
||||||
|
"*.githubusercontent.com",
|
||||||
|
"registry.npmjs.org",
|
||||||
|
"index.crates.io",
|
||||||
|
];
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct RemoteSessionContext {
|
||||||
|
pub enabled: bool,
|
||||||
|
pub session_id: Option<String>,
|
||||||
|
pub base_url: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct UpstreamProxyBootstrap {
|
||||||
|
pub remote: RemoteSessionContext,
|
||||||
|
pub upstream_proxy_enabled: bool,
|
||||||
|
pub token_path: PathBuf,
|
||||||
|
pub ca_bundle_path: PathBuf,
|
||||||
|
pub system_ca_path: PathBuf,
|
||||||
|
pub token: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct UpstreamProxyState {
|
||||||
|
pub enabled: bool,
|
||||||
|
pub proxy_url: Option<String>,
|
||||||
|
pub ca_bundle_path: Option<PathBuf>,
|
||||||
|
pub no_proxy: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RemoteSessionContext {
|
||||||
|
#[must_use]
|
||||||
|
pub fn from_env() -> Self {
|
||||||
|
Self::from_env_map(&env::vars().collect())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn from_env_map(env_map: &BTreeMap<String, String>) -> Self {
|
||||||
|
Self {
|
||||||
|
enabled: env_truthy(env_map.get("CLAUDE_CODE_REMOTE")),
|
||||||
|
session_id: env_map
|
||||||
|
.get("CLAUDE_CODE_REMOTE_SESSION_ID")
|
||||||
|
.filter(|value| !value.is_empty())
|
||||||
|
.cloned(),
|
||||||
|
base_url: env_map
|
||||||
|
.get("ANTHROPIC_BASE_URL")
|
||||||
|
.filter(|value| !value.is_empty())
|
||||||
|
.cloned()
|
||||||
|
.unwrap_or_else(|| DEFAULT_REMOTE_BASE_URL.to_string()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UpstreamProxyBootstrap {
|
||||||
|
#[must_use]
|
||||||
|
pub fn from_env() -> Self {
|
||||||
|
Self::from_env_map(&env::vars().collect())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn from_env_map(env_map: &BTreeMap<String, String>) -> Self {
|
||||||
|
let remote = RemoteSessionContext::from_env_map(env_map);
|
||||||
|
let token_path = env_map
|
||||||
|
.get("CCR_SESSION_TOKEN_PATH")
|
||||||
|
.filter(|value| !value.is_empty())
|
||||||
|
.map_or_else(|| PathBuf::from(DEFAULT_SESSION_TOKEN_PATH), PathBuf::from);
|
||||||
|
let system_ca_path = env_map
|
||||||
|
.get("CCR_SYSTEM_CA_BUNDLE")
|
||||||
|
.filter(|value| !value.is_empty())
|
||||||
|
.map_or_else(|| PathBuf::from(DEFAULT_SYSTEM_CA_BUNDLE), PathBuf::from);
|
||||||
|
let ca_bundle_path = env_map
|
||||||
|
.get("CCR_CA_BUNDLE_PATH")
|
||||||
|
.filter(|value| !value.is_empty())
|
||||||
|
.map_or_else(default_ca_bundle_path, PathBuf::from);
|
||||||
|
let token = read_token(&token_path).ok().flatten();
|
||||||
|
|
||||||
|
Self {
|
||||||
|
remote,
|
||||||
|
upstream_proxy_enabled: env_truthy(env_map.get("CCR_UPSTREAM_PROXY_ENABLED")),
|
||||||
|
token_path,
|
||||||
|
ca_bundle_path,
|
||||||
|
system_ca_path,
|
||||||
|
token,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn should_enable(&self) -> bool {
|
||||||
|
self.remote.enabled
|
||||||
|
&& self.upstream_proxy_enabled
|
||||||
|
&& self.remote.session_id.is_some()
|
||||||
|
&& self.token.is_some()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn ws_url(&self) -> String {
|
||||||
|
upstream_proxy_ws_url(&self.remote.base_url)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn state_for_port(&self, port: u16) -> UpstreamProxyState {
|
||||||
|
if !self.should_enable() {
|
||||||
|
return UpstreamProxyState::disabled();
|
||||||
|
}
|
||||||
|
UpstreamProxyState {
|
||||||
|
enabled: true,
|
||||||
|
proxy_url: Some(format!("http://127.0.0.1:{port}")),
|
||||||
|
ca_bundle_path: Some(self.ca_bundle_path.clone()),
|
||||||
|
no_proxy: no_proxy_list(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UpstreamProxyState {
|
||||||
|
#[must_use]
|
||||||
|
pub fn disabled() -> Self {
|
||||||
|
Self {
|
||||||
|
enabled: false,
|
||||||
|
proxy_url: None,
|
||||||
|
ca_bundle_path: None,
|
||||||
|
no_proxy: no_proxy_list(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn subprocess_env(&self) -> BTreeMap<String, String> {
|
||||||
|
if !self.enabled {
|
||||||
|
return BTreeMap::new();
|
||||||
|
}
|
||||||
|
let Some(proxy_url) = &self.proxy_url else {
|
||||||
|
return BTreeMap::new();
|
||||||
|
};
|
||||||
|
let Some(ca_bundle_path) = &self.ca_bundle_path else {
|
||||||
|
return BTreeMap::new();
|
||||||
|
};
|
||||||
|
let ca_bundle_path = ca_bundle_path.to_string_lossy().into_owned();
|
||||||
|
BTreeMap::from([
|
||||||
|
("HTTPS_PROXY".to_string(), proxy_url.clone()),
|
||||||
|
("https_proxy".to_string(), proxy_url.clone()),
|
||||||
|
("NO_PROXY".to_string(), self.no_proxy.clone()),
|
||||||
|
("no_proxy".to_string(), self.no_proxy.clone()),
|
||||||
|
("SSL_CERT_FILE".to_string(), ca_bundle_path.clone()),
|
||||||
|
("NODE_EXTRA_CA_CERTS".to_string(), ca_bundle_path.clone()),
|
||||||
|
("REQUESTS_CA_BUNDLE".to_string(), ca_bundle_path.clone()),
|
||||||
|
("CURL_CA_BUNDLE".to_string(), ca_bundle_path),
|
||||||
|
])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn read_token(path: &Path) -> io::Result<Option<String>> {
|
||||||
|
match fs::read_to_string(path) {
|
||||||
|
Ok(contents) => {
|
||||||
|
let token = contents.trim();
|
||||||
|
if token.is_empty() {
|
||||||
|
Ok(None)
|
||||||
|
} else {
|
||||||
|
Ok(Some(token.to_string()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(error) if error.kind() == io::ErrorKind::NotFound => Ok(None),
|
||||||
|
Err(error) => Err(error),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn upstream_proxy_ws_url(base_url: &str) -> String {
|
||||||
|
let base = base_url.trim_end_matches('/');
|
||||||
|
let ws_base = if let Some(stripped) = base.strip_prefix("https://") {
|
||||||
|
format!("wss://{stripped}")
|
||||||
|
} else if let Some(stripped) = base.strip_prefix("http://") {
|
||||||
|
format!("ws://{stripped}")
|
||||||
|
} else {
|
||||||
|
format!("wss://{base}")
|
||||||
|
};
|
||||||
|
format!("{ws_base}/v1/code/upstreamproxy/ws")
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn no_proxy_list() -> String {
|
||||||
|
let mut hosts = NO_PROXY_HOSTS.to_vec();
|
||||||
|
hosts.extend(["pypi.org", "files.pythonhosted.org", "proxy.golang.org"]);
|
||||||
|
hosts.join(",")
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn inherited_upstream_proxy_env(
|
||||||
|
env_map: &BTreeMap<String, String>,
|
||||||
|
) -> BTreeMap<String, String> {
|
||||||
|
if !(env_map.contains_key("HTTPS_PROXY") && env_map.contains_key("SSL_CERT_FILE")) {
|
||||||
|
return BTreeMap::new();
|
||||||
|
}
|
||||||
|
UPSTREAM_PROXY_ENV_KEYS
|
||||||
|
.iter()
|
||||||
|
.filter_map(|key| {
|
||||||
|
env_map
|
||||||
|
.get(*key)
|
||||||
|
.map(|value| ((*key).to_string(), value.clone()))
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_ca_bundle_path() -> PathBuf {
|
||||||
|
env::var_os("HOME")
|
||||||
|
.map_or_else(|| PathBuf::from("."), PathBuf::from)
|
||||||
|
.join(".ccr")
|
||||||
|
.join("ca-bundle.crt")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn env_truthy(value: Option<&String>) -> bool {
|
||||||
|
value.is_some_and(|raw| {
|
||||||
|
matches!(
|
||||||
|
raw.trim().to_ascii_lowercase().as_str(),
|
||||||
|
"1" | "true" | "yes" | "on"
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::{
|
||||||
|
inherited_upstream_proxy_env, no_proxy_list, read_token, upstream_proxy_ws_url,
|
||||||
|
RemoteSessionContext, UpstreamProxyBootstrap,
|
||||||
|
};
|
||||||
|
use std::collections::BTreeMap;
|
||||||
|
use std::fs;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
use std::time::{SystemTime, UNIX_EPOCH};
|
||||||
|
|
||||||
|
fn temp_dir() -> PathBuf {
|
||||||
|
let nanos = SystemTime::now()
|
||||||
|
.duration_since(UNIX_EPOCH)
|
||||||
|
.expect("time should be after epoch")
|
||||||
|
.as_nanos();
|
||||||
|
std::env::temp_dir().join(format!("runtime-remote-{nanos}"))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn remote_context_reads_env_state() {
|
||||||
|
let env = BTreeMap::from([
|
||||||
|
("CLAUDE_CODE_REMOTE".to_string(), "true".to_string()),
|
||||||
|
(
|
||||||
|
"CLAUDE_CODE_REMOTE_SESSION_ID".to_string(),
|
||||||
|
"session-123".to_string(),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"ANTHROPIC_BASE_URL".to_string(),
|
||||||
|
"https://remote.test".to_string(),
|
||||||
|
),
|
||||||
|
]);
|
||||||
|
let context = RemoteSessionContext::from_env_map(&env);
|
||||||
|
assert!(context.enabled);
|
||||||
|
assert_eq!(context.session_id.as_deref(), Some("session-123"));
|
||||||
|
assert_eq!(context.base_url, "https://remote.test");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn bootstrap_fails_open_when_token_or_session_is_missing() {
|
||||||
|
let env = BTreeMap::from([
|
||||||
|
("CLAUDE_CODE_REMOTE".to_string(), "1".to_string()),
|
||||||
|
("CCR_UPSTREAM_PROXY_ENABLED".to_string(), "true".to_string()),
|
||||||
|
]);
|
||||||
|
let bootstrap = UpstreamProxyBootstrap::from_env_map(&env);
|
||||||
|
assert!(!bootstrap.should_enable());
|
||||||
|
assert!(!bootstrap.state_for_port(8080).enabled);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn bootstrap_derives_proxy_state_and_env() {
|
||||||
|
let root = temp_dir();
|
||||||
|
let token_path = root.join("session_token");
|
||||||
|
fs::create_dir_all(&root).expect("temp dir");
|
||||||
|
fs::write(&token_path, "secret-token\n").expect("write token");
|
||||||
|
|
||||||
|
let env = BTreeMap::from([
|
||||||
|
("CLAUDE_CODE_REMOTE".to_string(), "1".to_string()),
|
||||||
|
("CCR_UPSTREAM_PROXY_ENABLED".to_string(), "true".to_string()),
|
||||||
|
(
|
||||||
|
"CLAUDE_CODE_REMOTE_SESSION_ID".to_string(),
|
||||||
|
"session-123".to_string(),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"ANTHROPIC_BASE_URL".to_string(),
|
||||||
|
"https://remote.test".to_string(),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"CCR_SESSION_TOKEN_PATH".to_string(),
|
||||||
|
token_path.to_string_lossy().into_owned(),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"CCR_CA_BUNDLE_PATH".to_string(),
|
||||||
|
root.join("ca-bundle.crt").to_string_lossy().into_owned(),
|
||||||
|
),
|
||||||
|
]);
|
||||||
|
|
||||||
|
let bootstrap = UpstreamProxyBootstrap::from_env_map(&env);
|
||||||
|
assert!(bootstrap.should_enable());
|
||||||
|
assert_eq!(bootstrap.token.as_deref(), Some("secret-token"));
|
||||||
|
assert_eq!(
|
||||||
|
bootstrap.ws_url(),
|
||||||
|
"wss://remote.test/v1/code/upstreamproxy/ws"
|
||||||
|
);
|
||||||
|
|
||||||
|
let state = bootstrap.state_for_port(9443);
|
||||||
|
assert!(state.enabled);
|
||||||
|
let env = state.subprocess_env();
|
||||||
|
assert_eq!(
|
||||||
|
env.get("HTTPS_PROXY").map(String::as_str),
|
||||||
|
Some("http://127.0.0.1:9443")
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
env.get("SSL_CERT_FILE").map(String::as_str),
|
||||||
|
Some(root.join("ca-bundle.crt").to_string_lossy().as_ref())
|
||||||
|
);
|
||||||
|
|
||||||
|
fs::remove_dir_all(root).expect("cleanup temp dir");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn token_reader_trims_and_handles_missing_files() {
|
||||||
|
let root = temp_dir();
|
||||||
|
fs::create_dir_all(&root).expect("temp dir");
|
||||||
|
let token_path = root.join("session_token");
|
||||||
|
fs::write(&token_path, " abc123 \n").expect("write token");
|
||||||
|
assert_eq!(
|
||||||
|
read_token(&token_path).expect("read token").as_deref(),
|
||||||
|
Some("abc123")
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
read_token(&root.join("missing")).expect("missing token"),
|
||||||
|
None
|
||||||
|
);
|
||||||
|
fs::remove_dir_all(root).expect("cleanup temp dir");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn inherited_proxy_env_requires_proxy_and_ca() {
|
||||||
|
let env = BTreeMap::from([
|
||||||
|
(
|
||||||
|
"HTTPS_PROXY".to_string(),
|
||||||
|
"http://127.0.0.1:8888".to_string(),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"SSL_CERT_FILE".to_string(),
|
||||||
|
"/tmp/ca-bundle.crt".to_string(),
|
||||||
|
),
|
||||||
|
("NO_PROXY".to_string(), "localhost".to_string()),
|
||||||
|
]);
|
||||||
|
let inherited = inherited_upstream_proxy_env(&env);
|
||||||
|
assert_eq!(inherited.len(), 3);
|
||||||
|
assert_eq!(
|
||||||
|
inherited.get("NO_PROXY").map(String::as_str),
|
||||||
|
Some("localhost")
|
||||||
|
);
|
||||||
|
assert!(inherited_upstream_proxy_env(&BTreeMap::new()).is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn helper_outputs_match_expected_shapes() {
|
||||||
|
assert_eq!(
|
||||||
|
upstream_proxy_ws_url("http://localhost:3000/"),
|
||||||
|
"ws://localhost:3000/v1/code/upstreamproxy/ws"
|
||||||
|
);
|
||||||
|
assert!(no_proxy_list().contains("anthropic.com"));
|
||||||
|
assert!(no_proxy_list().contains("github.com"));
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user