From 32e89df6310e48afedda8052fdf4f1f42c87f450 Mon Sep 17 00:00:00 2001 From: Yeachan-Heo Date: Tue, 31 Mar 2026 23:38:05 +0000 Subject: [PATCH] Enable Claude OAuth login without requiring API keys This adds an end-to-end OAuth PKCE login/logout path to the Rust CLI, persists OAuth credentials under the Claude config home, and teaches the API client to use persisted bearer credentials with refresh support when env-based API credentials are absent. Constraint: Reuse existing runtime OAuth primitives and keep browser/callback orchestration in the CLI Constraint: Preserve auth precedence as API key, then auth-token env, then persisted OAuth credentials Rejected: Put browser launch and token exchange entirely in runtime | caused boundary creep across shared crates Rejected: Duplicate credential parsing in CLI and api | increased drift and refresh inconsistency Confidence: medium Scope-risk: moderate Reversibility: clean Directive: Keep logout non-destructive to unrelated credentials.json fields and do not silently fall back to stale expired tokens Tested: cargo fmt; cargo clippy --workspace --all-targets -- -D warnings; cargo test Not-tested: Manual live Anthropic OAuth browser flow against real authorize/token endpoints --- rust/Cargo.lock | 1 + rust/README.md | 25 ++- rust/crates/api/Cargo.toml | 1 + rust/crates/api/src/client.rs | 274 ++++++++++++++++++++++- rust/crates/api/src/error.rs | 11 + rust/crates/api/src/lib.rs | 5 +- rust/crates/runtime/src/lib.rs | 6 +- rust/crates/runtime/src/oauth.rs | 265 +++++++++++++++++++++- rust/crates/rusty-claude-cli/src/args.rs | 13 ++ rust/crates/rusty-claude-cli/src/main.rs | 179 ++++++++++++++- 10 files changed, 753 insertions(+), 27 deletions(-) diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 548466a..9030127 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -22,6 +22,7 @@ name = "api" version = "0.1.0" dependencies = [ "reqwest", + "runtime", "serde", "serde_json", "tokio", diff --git a/rust/README.md b/rust/README.md index dadefe3..8bc5787 100644 --- a/rust/README.md +++ b/rust/README.md @@ -64,6 +64,26 @@ cd rust cargo run -p rusty-claude-cli -- --version ``` +### Login with OAuth + +Configure `settings.json` with an `oauth` block containing `clientId`, `authorizeUrl`, `tokenUrl`, optional `callbackPort`, and optional `scopes`, then run: + +```bash +cd rust +cargo run -p rusty-claude-cli -- login +``` + +This opens the browser, listens on the configured localhost callback, exchanges the auth code for tokens, and stores OAuth credentials in `~/.claude/credentials.json` (or `$CLAUDE_CONFIG_HOME/credentials.json`). + +### Logout + +```bash +cd rust +cargo run -p rusty-claude-cli -- logout +``` + +This removes only the stored OAuth credentials and preserves unrelated JSON fields in `credentials.json`. + ## Usage examples ### 1) Prompt mode @@ -153,8 +173,9 @@ cargo run -p rusty-claude-cli -- --resume session.json /memory /config ### Anthropic/API -- `ANTHROPIC_AUTH_TOKEN` — preferred bearer token for API auth -- `ANTHROPIC_API_KEY` — legacy API key fallback if auth token is unset +- `ANTHROPIC_API_KEY` — highest-precedence API credential +- `ANTHROPIC_AUTH_TOKEN` — bearer-token override used when no API key is set +- Persisted OAuth credentials in `~/.claude/credentials.json` — used when neither env var is set - `ANTHROPIC_BASE_URL` — override the Anthropic API base URL - `ANTHROPIC_MODEL` — default model used by selected live integration tests diff --git a/rust/crates/api/Cargo.toml b/rust/crates/api/Cargo.toml index 32c4865..c5e152e 100644 --- a/rust/crates/api/Cargo.toml +++ b/rust/crates/api/Cargo.toml @@ -7,6 +7,7 @@ publish.workspace = true [dependencies] reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] } +runtime = { path = "../runtime" } serde = { version = "1", features = ["derive"] } serde_json = "1" tokio = { version = "1", features = ["io-util", "macros", "net", "rt-multi-thread", "time"] } diff --git a/rust/crates/api/src/client.rs b/rust/crates/api/src/client.rs index 5e7d319..9bfe422 100644 --- a/rust/crates/api/src/client.rs +++ b/rust/crates/api/src/client.rs @@ -1,6 +1,10 @@ use std::collections::VecDeque; -use std::time::Duration; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use runtime::{ + load_oauth_credentials, save_oauth_credentials, OAuthConfig, OAuthRefreshRequest, + OAuthTokenExchangeRequest, +}; use serde::Deserialize; use crate::error::ApiError; @@ -81,11 +85,12 @@ impl AuthSource { } } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Deserialize)] pub struct OAuthTokenSet { pub access_token: String, pub refresh_token: Option, pub expires_at: Option, + #[serde(default)] pub scopes: Vec, } @@ -131,7 +136,7 @@ impl AnthropicClient { } pub fn from_env() -> Result { - Ok(Self::from_auth(AuthSource::from_env()?).with_base_url(read_base_url())) + Ok(Self::from_auth(AuthSource::from_env_or_saved()?).with_base_url(read_base_url())) } #[must_use] @@ -225,6 +230,46 @@ impl AnthropicClient { }) } + pub async fn exchange_oauth_code( + &self, + config: &OAuthConfig, + request: &OAuthTokenExchangeRequest, + ) -> Result { + let response = self + .http + .post(&config.token_url) + .header("content-type", "application/x-www-form-urlencoded") + .form(&request.form_params()) + .send() + .await + .map_err(ApiError::from)?; + let response = expect_success(response).await?; + response + .json::() + .await + .map_err(ApiError::from) + } + + pub async fn refresh_oauth_token( + &self, + config: &OAuthConfig, + request: &OAuthRefreshRequest, + ) -> Result { + let response = self + .http + .post(&config.token_url) + .header("content-type", "application/x-www-form-urlencoded") + .form(&request.form_params()) + .send() + .await + .map_err(ApiError::from)?; + let response = expect_success(response).await?; + response + .json::() + .await + .map_err(ApiError::from) + } + async fn send_with_retry( &self, request: &MessageRequest, @@ -304,6 +349,99 @@ impl AnthropicClient { } } +impl AuthSource { + pub fn from_env_or_saved() -> Result { + if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? { + return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? { + Some(bearer_token) => Ok(Self::ApiKeyAndBearer { + api_key, + bearer_token, + }), + None => Ok(Self::ApiKey(api_key)), + }; + } + if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? { + return Ok(Self::BearerToken(bearer_token)); + } + match load_saved_oauth_token() { + Ok(Some(token_set)) if oauth_token_is_expired(&token_set) => { + if token_set.refresh_token.is_some() { + Err(ApiError::Auth( + "saved OAuth token is expired; load runtime OAuth config to refresh it" + .to_string(), + )) + } else { + Err(ApiError::ExpiredOAuthToken) + } + } + Ok(Some(token_set)) => Ok(Self::BearerToken(token_set.access_token)), + Ok(None) => Err(ApiError::MissingApiKey), + Err(error) => Err(error), + } + } +} + +#[must_use] +pub fn oauth_token_is_expired(token_set: &OAuthTokenSet) -> bool { + token_set + .expires_at + .is_some_and(|expires_at| expires_at <= now_unix_timestamp()) +} + +pub fn resolve_saved_oauth_token(config: &OAuthConfig) -> Result, ApiError> { + let Some(token_set) = load_saved_oauth_token()? else { + return Ok(None); + }; + if !oauth_token_is_expired(&token_set) { + return Ok(Some(token_set)); + } + let Some(refresh_token) = token_set.refresh_token.clone() else { + return Err(ApiError::ExpiredOAuthToken); + }; + let client = AnthropicClient::from_auth(AuthSource::None).with_base_url(read_base_url()); + let refreshed = client_runtime_block_on(async { + client + .refresh_oauth_token( + config, + &OAuthRefreshRequest::from_config(config, refresh_token, Some(token_set.scopes)), + ) + .await + })?; + save_oauth_credentials(&runtime::OAuthTokenSet { + access_token: refreshed.access_token.clone(), + refresh_token: refreshed.refresh_token.clone(), + expires_at: refreshed.expires_at, + scopes: refreshed.scopes.clone(), + }) + .map_err(ApiError::from)?; + Ok(Some(refreshed)) +} + +fn client_runtime_block_on(future: F) -> Result +where + F: std::future::Future>, +{ + tokio::runtime::Runtime::new() + .map_err(ApiError::from)? + .block_on(future) +} + +fn load_saved_oauth_token() -> Result, ApiError> { + let token_set = load_oauth_credentials().map_err(ApiError::from)?; + Ok(token_set.map(|token_set| OAuthTokenSet { + access_token: token_set.access_token, + refresh_token: token_set.refresh_token, + expires_at: token_set.expires_at, + scopes: token_set.scopes, + })) +} + +fn now_unix_timestamp() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map_or(0, |duration| duration.as_secs()) +} + fn read_env_non_empty(key: &str) -> Result, ApiError> { match std::env::var(key) { Ok(value) if !value.is_empty() => Ok(Some(value)), @@ -314,7 +452,7 @@ fn read_env_non_empty(key: &str) -> Result, ApiError> { #[cfg(test)] fn read_api_key() -> Result { - let auth = AuthSource::from_env()?; + let auth = AuthSource::from_env_or_saved()?; auth.api_key() .or_else(|| auth.bearer_token()) .map(ToOwned::to_owned) @@ -424,10 +562,18 @@ struct AnthropicErrorBody { #[cfg(test)] mod tests { use super::{ALT_REQUEST_ID_HEADER, REQUEST_ID_HEADER}; + use std::io::{Read, Write}; + use std::net::TcpListener; use std::sync::{Mutex, OnceLock}; - use std::time::Duration; + use std::thread; + use std::time::{Duration, SystemTime, UNIX_EPOCH}; - use crate::client::{AuthSource, OAuthTokenSet}; + use runtime::{clear_oauth_credentials, save_oauth_credentials, OAuthConfig}; + + use crate::client::{ + now_unix_timestamp, oauth_token_is_expired, resolve_saved_oauth_token, AnthropicClient, + AuthSource, OAuthTokenSet, + }; use crate::types::{ContentBlockDelta, MessageRequest}; fn env_lock() -> std::sync::MutexGuard<'static, ()> { @@ -437,11 +583,53 @@ mod tests { .expect("env lock") } + fn temp_config_home() -> std::path::PathBuf { + std::env::temp_dir().join(format!( + "api-oauth-test-{}-{}", + std::process::id(), + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time") + .as_nanos() + )) + } + + fn sample_oauth_config(token_url: String) -> OAuthConfig { + OAuthConfig { + client_id: "runtime-client".to_string(), + authorize_url: "https://console.test/oauth/authorize".to_string(), + token_url, + callback_port: Some(4545), + manual_redirect_url: Some("https://console.test/oauth/callback".to_string()), + scopes: vec!["org:read".to_string(), "user:write".to_string()], + } + } + + fn spawn_token_server(response_body: &'static str) -> String { + let listener = TcpListener::bind("127.0.0.1:0").expect("bind listener"); + let address = listener.local_addr().expect("local addr"); + thread::spawn(move || { + let (mut stream, _) = listener.accept().expect("accept connection"); + let mut buffer = [0_u8; 4096]; + let _ = stream.read(&mut buffer).expect("read request"); + let response = format!( + "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\n\r\n{}", + response_body.len(), + response_body + ); + stream + .write_all(response.as_bytes()) + .expect("write response"); + }); + format!("http://{address}/oauth/token") + } + #[test] fn read_api_key_requires_presence() { let _guard = env_lock(); std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); std::env::remove_var("ANTHROPIC_API_KEY"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); let error = super::read_api_key().expect_err("missing key should error"); assert!(matches!(error, crate::error::ApiError::MissingApiKey)); } @@ -453,6 +641,7 @@ mod tests { std::env::remove_var("ANTHROPIC_API_KEY"); let error = super::read_api_key().expect_err("empty key should error"); assert!(matches!(error, crate::error::ApiError::MissingApiKey)); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); } #[test] @@ -500,6 +689,77 @@ mod tests { std::env::remove_var("ANTHROPIC_API_KEY"); } + #[test] + fn auth_source_from_saved_oauth_when_env_absent() { + let _guard = env_lock(); + let config_home = temp_config_home(); + std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + save_oauth_credentials(&runtime::OAuthTokenSet { + access_token: "saved-access-token".to_string(), + refresh_token: Some("refresh".to_string()), + expires_at: Some(now_unix_timestamp() + 300), + scopes: vec!["scope:a".to_string()], + }) + .expect("save oauth credentials"); + + let auth = AuthSource::from_env_or_saved().expect("saved auth"); + assert_eq!(auth.bearer_token(), Some("saved-access-token")); + + clear_oauth_credentials().expect("clear credentials"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); + std::fs::remove_dir_all(config_home).expect("cleanup temp dir"); + } + + #[test] + fn oauth_token_expiry_uses_expires_at_timestamp() { + assert!(oauth_token_is_expired(&OAuthTokenSet { + access_token: "access-token".to_string(), + refresh_token: None, + expires_at: Some(1), + scopes: Vec::new(), + })); + assert!(!oauth_token_is_expired(&OAuthTokenSet { + access_token: "access-token".to_string(), + refresh_token: None, + expires_at: Some(now_unix_timestamp() + 60), + scopes: Vec::new(), + })); + } + + #[test] + fn resolve_saved_oauth_token_refreshes_expired_credentials() { + let _guard = env_lock(); + let config_home = temp_config_home(); + std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + save_oauth_credentials(&runtime::OAuthTokenSet { + access_token: "expired-access-token".to_string(), + refresh_token: Some("refresh-token".to_string()), + expires_at: Some(1), + scopes: vec!["scope:a".to_string()], + }) + .expect("save expired oauth credentials"); + + let token_url = spawn_token_server( + "{\"access_token\":\"refreshed-token\",\"refresh_token\":\"fresh-refresh\",\"expires_at\":9999999999,\"scopes\":[\"scope:a\"]}", + ); + let resolved = resolve_saved_oauth_token(&sample_oauth_config(token_url)) + .expect("resolve refreshed token") + .expect("token set present"); + assert_eq!(resolved.access_token, "refreshed-token"); + let stored = runtime::load_oauth_credentials() + .expect("load stored credentials") + .expect("stored token set"); + assert_eq!(stored.access_token, "refreshed-token"); + + clear_oauth_credentials().expect("clear credentials"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); + std::fs::remove_dir_all(config_home).expect("cleanup temp dir"); + } + #[test] fn message_request_stream_helper_sets_stream_true() { let request = MessageRequest { @@ -517,7 +777,7 @@ mod tests { #[test] fn backoff_doubles_until_maximum() { - let client = super::AnthropicClient::new("test-key").with_retry_policy( + let client = AnthropicClient::new("test-key").with_retry_policy( 3, Duration::from_millis(10), Duration::from_millis(25), diff --git a/rust/crates/api/src/error.rs b/rust/crates/api/src/error.rs index 02ec584..2c31691 100644 --- a/rust/crates/api/src/error.rs +++ b/rust/crates/api/src/error.rs @@ -5,6 +5,8 @@ use std::time::Duration; #[derive(Debug)] pub enum ApiError { MissingApiKey, + ExpiredOAuthToken, + Auth(String), InvalidApiKeyEnv(VarError), Http(reqwest::Error), Io(std::io::Error), @@ -35,6 +37,8 @@ impl ApiError { Self::Api { retryable, .. } => *retryable, Self::RetriesExhausted { last_error, .. } => last_error.is_retryable(), Self::MissingApiKey + | Self::ExpiredOAuthToken + | Self::Auth(_) | Self::InvalidApiKeyEnv(_) | Self::Io(_) | Self::Json(_) @@ -53,6 +57,13 @@ impl Display for ApiError { "ANTHROPIC_AUTH_TOKEN or ANTHROPIC_API_KEY is not set; export one before calling the Anthropic API" ) } + Self::ExpiredOAuthToken => { + write!( + f, + "saved OAuth token is expired and no refresh token is available" + ) + } + Self::Auth(message) => write!(f, "auth error: {message}"), Self::InvalidApiKeyEnv(error) => { write!( f, diff --git a/rust/crates/api/src/lib.rs b/rust/crates/api/src/lib.rs index 9d587ee..048cd58 100644 --- a/rust/crates/api/src/lib.rs +++ b/rust/crates/api/src/lib.rs @@ -3,7 +3,10 @@ mod error; mod sse; mod types; -pub use client::{AnthropicClient, AuthSource, MessageStream, OAuthTokenSet}; +pub use client::{ + oauth_token_is_expired, resolve_saved_oauth_token, AnthropicClient, AuthSource, MessageStream, + OAuthTokenSet, +}; pub use error::ApiError; pub use sse::{parse_frame, SseParser}; pub use types::{ diff --git a/rust/crates/runtime/src/lib.rs b/rust/crates/runtime/src/lib.rs index 1d7af28..1f22571 100644 --- a/rust/crates/runtime/src/lib.rs +++ b/rust/crates/runtime/src/lib.rs @@ -52,8 +52,10 @@ pub use mcp_stdio::{ McpStdioProcess, McpTool, McpToolCallContent, McpToolCallParams, McpToolCallResult, }; pub use oauth::{ - code_challenge_s256, generate_pkce_pair, generate_state, loopback_redirect_uri, - OAuthAuthorizationRequest, OAuthRefreshRequest, OAuthTokenExchangeRequest, OAuthTokenSet, + clear_oauth_credentials, code_challenge_s256, credentials_path, generate_pkce_pair, + generate_state, load_oauth_credentials, loopback_redirect_uri, parse_oauth_callback_query, + parse_oauth_callback_request_target, save_oauth_credentials, OAuthAuthorizationRequest, + OAuthCallbackParams, OAuthRefreshRequest, OAuthTokenExchangeRequest, OAuthTokenSet, PkceChallengeMethod, PkceCodePair, }; pub use permissions::{ diff --git a/rust/crates/runtime/src/oauth.rs b/rust/crates/runtime/src/oauth.rs index 320a8ee..db68bf9 100644 --- a/rust/crates/runtime/src/oauth.rs +++ b/rust/crates/runtime/src/oauth.rs @@ -1,12 +1,15 @@ use std::collections::BTreeMap; -use std::fs::File; +use std::fs::{self, File}; use std::io::{self, Read}; +use std::path::PathBuf; +use serde::{Deserialize, Serialize}; +use serde_json::{Map, Value}; use sha2::{Digest, Sha256}; use crate::config::OAuthConfig; -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct OAuthTokenSet { pub access_token: String, pub refresh_token: Option, @@ -65,6 +68,48 @@ pub struct OAuthRefreshRequest { pub scopes: Vec, } +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct OAuthCallbackParams { + pub code: Option, + pub state: Option, + pub error: Option, + pub error_description: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +struct StoredOAuthCredentials { + access_token: String, + #[serde(default)] + refresh_token: Option, + #[serde(default)] + expires_at: Option, + #[serde(default)] + scopes: Vec, +} + +impl From for StoredOAuthCredentials { + fn from(value: OAuthTokenSet) -> Self { + Self { + access_token: value.access_token, + refresh_token: value.refresh_token, + expires_at: value.expires_at, + scopes: value.scopes, + } + } +} + +impl From for OAuthTokenSet { + fn from(value: StoredOAuthCredentials) -> Self { + Self { + access_token: value.access_token, + refresh_token: value.refresh_token, + expires_at: value.expires_at, + scopes: value.scopes, + } + } +} + impl OAuthAuthorizationRequest { #[must_use] pub fn from_config( @@ -137,7 +182,6 @@ impl OAuthTokenExchangeRequest { verifier: impl Into, redirect_uri: impl Into, ) -> Self { - let _ = config; Self { grant_type: "authorization_code", code: code.into(), @@ -211,12 +255,116 @@ pub fn loopback_redirect_uri(port: u16) -> String { format!("http://localhost:{port}/callback") } +pub fn credentials_path() -> io::Result { + Ok(credentials_home_dir()?.join("credentials.json")) +} + +pub fn load_oauth_credentials() -> io::Result> { + let path = credentials_path()?; + let root = read_credentials_root(&path)?; + let Some(oauth) = root.get("oauth") else { + return Ok(None); + }; + if oauth.is_null() { + return Ok(None); + } + let stored = serde_json::from_value::(oauth.clone()) + .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?; + Ok(Some(stored.into())) +} + +pub fn save_oauth_credentials(token_set: &OAuthTokenSet) -> io::Result<()> { + let path = credentials_path()?; + let mut root = read_credentials_root(&path)?; + root.insert( + "oauth".to_string(), + serde_json::to_value(StoredOAuthCredentials::from(token_set.clone())) + .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?, + ); + write_credentials_root(&path, &root) +} + +pub fn clear_oauth_credentials() -> io::Result<()> { + let path = credentials_path()?; + let mut root = read_credentials_root(&path)?; + root.remove("oauth"); + write_credentials_root(&path, &root) +} + +pub fn parse_oauth_callback_request_target(target: &str) -> Result { + let (path, query) = target + .split_once('?') + .map_or((target, ""), |(path, query)| (path, query)); + if path != "/callback" { + return Err(format!("unexpected callback path: {path}")); + } + parse_oauth_callback_query(query) +} + +pub fn parse_oauth_callback_query(query: &str) -> Result { + let mut params = BTreeMap::new(); + for pair in query.split('&').filter(|pair| !pair.is_empty()) { + let (key, value) = pair + .split_once('=') + .map_or((pair, ""), |(key, value)| (key, value)); + params.insert(percent_decode(key)?, percent_decode(value)?); + } + Ok(OAuthCallbackParams { + code: params.get("code").cloned(), + state: params.get("state").cloned(), + error: params.get("error").cloned(), + error_description: params.get("error_description").cloned(), + }) +} + fn generate_random_token(bytes: usize) -> io::Result { let mut buffer = vec![0_u8; bytes]; File::open("/dev/urandom")?.read_exact(&mut buffer)?; Ok(base64url_encode(&buffer)) } +fn credentials_home_dir() -> io::Result { + if let Some(path) = std::env::var_os("CLAUDE_CONFIG_HOME") { + return Ok(PathBuf::from(path)); + } + let home = std::env::var_os("HOME") + .ok_or_else(|| io::Error::new(io::ErrorKind::NotFound, "HOME is not set"))?; + Ok(PathBuf::from(home).join(".claude")) +} + +fn read_credentials_root(path: &PathBuf) -> io::Result> { + match fs::read_to_string(path) { + Ok(contents) => { + if contents.trim().is_empty() { + return Ok(Map::new()); + } + serde_json::from_str::(&contents) + .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))? + .as_object() + .cloned() + .ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidData, + "credentials file must contain a JSON object", + ) + }) + } + Err(error) if error.kind() == io::ErrorKind::NotFound => Ok(Map::new()), + Err(error) => Err(error), + } +} + +fn write_credentials_root(path: &PathBuf, root: &Map) -> io::Result<()> { + if let Some(parent) = path.parent() { + fs::create_dir_all(parent)?; + } + let rendered = serde_json::to_string_pretty(&Value::Object(root.clone())) + .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?; + let temp_path = path.with_extension("json.tmp"); + fs::write(&temp_path, format!("{rendered}\n"))?; + fs::rename(temp_path, path) +} + fn base64url_encode(bytes: &[u8]) -> String { const TABLE: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"; let mut output = String::new(); @@ -264,11 +412,50 @@ fn percent_encode(value: &str) -> String { encoded } +fn percent_decode(value: &str) -> Result { + let mut decoded = Vec::with_capacity(value.len()); + let bytes = value.as_bytes(); + let mut index = 0; + while index < bytes.len() { + match bytes[index] { + b'%' if index + 2 < bytes.len() => { + let hi = decode_hex(bytes[index + 1])?; + let lo = decode_hex(bytes[index + 2])?; + decoded.push((hi << 4) | lo); + index += 3; + } + b'+' => { + decoded.push(b' '); + index += 1; + } + byte => { + decoded.push(byte); + index += 1; + } + } + } + String::from_utf8(decoded).map_err(|error| error.to_string()) +} + +fn decode_hex(byte: u8) -> Result { + match byte { + b'0'..=b'9' => Ok(byte - b'0'), + b'a'..=b'f' => Ok(byte - b'a' + 10), + b'A'..=b'F' => Ok(byte - b'A' + 10), + _ => Err(format!("invalid percent-encoding byte: {byte}")), + } +} + #[cfg(test)] mod tests { + use std::sync::{Mutex, OnceLock}; + use std::time::{SystemTime, UNIX_EPOCH}; + use super::{ - code_challenge_s256, generate_pkce_pair, generate_state, loopback_redirect_uri, - OAuthAuthorizationRequest, OAuthConfig, OAuthRefreshRequest, OAuthTokenExchangeRequest, + clear_oauth_credentials, code_challenge_s256, credentials_path, generate_pkce_pair, + generate_state, load_oauth_credentials, loopback_redirect_uri, parse_oauth_callback_query, + parse_oauth_callback_request_target, save_oauth_credentials, OAuthAuthorizationRequest, + OAuthConfig, OAuthRefreshRequest, OAuthTokenExchangeRequest, OAuthTokenSet, }; fn sample_config() -> OAuthConfig { @@ -282,6 +469,24 @@ mod tests { } } + fn env_lock() -> std::sync::MutexGuard<'static, ()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) + .lock() + .expect("env lock") + } + + fn temp_config_home() -> std::path::PathBuf { + std::env::temp_dir().join(format!( + "runtime-oauth-test-{}-{}", + std::process::id(), + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time") + .as_nanos() + )) + } + #[test] fn s256_challenge_matches_expected_vector() { assert_eq!( @@ -335,4 +540,54 @@ mod tests { Some("org:read user:write") ); } + + #[test] + fn oauth_credentials_round_trip_and_clear_preserves_other_fields() { + let _guard = env_lock(); + let config_home = temp_config_home(); + std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); + let path = credentials_path().expect("credentials path"); + std::fs::create_dir_all(path.parent().expect("parent")).expect("create parent"); + std::fs::write(&path, "{\"other\":\"value\"}\n").expect("seed credentials"); + + let token_set = OAuthTokenSet { + access_token: "access-token".to_string(), + refresh_token: Some("refresh-token".to_string()), + expires_at: Some(123), + scopes: vec!["scope:a".to_string()], + }; + save_oauth_credentials(&token_set).expect("save credentials"); + assert_eq!( + load_oauth_credentials().expect("load credentials"), + Some(token_set) + ); + let saved = std::fs::read_to_string(&path).expect("read saved file"); + assert!(saved.contains("\"other\": \"value\"")); + assert!(saved.contains("\"oauth\"")); + + clear_oauth_credentials().expect("clear credentials"); + assert_eq!(load_oauth_credentials().expect("load cleared"), None); + let cleared = std::fs::read_to_string(&path).expect("read cleared file"); + assert!(cleared.contains("\"other\": \"value\"")); + assert!(!cleared.contains("\"oauth\"")); + + std::env::remove_var("CLAUDE_CONFIG_HOME"); + std::fs::remove_dir_all(config_home).expect("cleanup temp dir"); + } + + #[test] + fn parses_callback_query_and_target() { + let params = + parse_oauth_callback_query("code=abc123&state=state-1&error_description=needs%20login") + .expect("parse query"); + assert_eq!(params.code.as_deref(), Some("abc123")); + assert_eq!(params.state.as_deref(), Some("state-1")); + assert_eq!(params.error_description.as_deref(), Some("needs login")); + + let params = parse_oauth_callback_request_target("/callback?code=abc&state=xyz") + .expect("parse callback target"); + assert_eq!(params.code.as_deref(), Some("abc")); + assert_eq!(params.state.as_deref(), Some("xyz")); + assert!(parse_oauth_callback_request_target("/wrong?code=abc").is_err()); + } } diff --git a/rust/crates/rusty-claude-cli/src/args.rs b/rust/crates/rusty-claude-cli/src/args.rs index d2e0851..6c98269 100644 --- a/rust/crates/rusty-claude-cli/src/args.rs +++ b/rust/crates/rusty-claude-cli/src/args.rs @@ -31,6 +31,10 @@ pub enum Command { DumpManifests, /// Print the current bootstrap phase skeleton BootstrapPlan, + /// Start the OAuth login flow + Login, + /// Clear saved OAuth credentials + Logout, /// Run a non-interactive prompt and exit Prompt { prompt: Vec }, } @@ -86,4 +90,13 @@ mod tests { }) ); } + + #[test] + fn parses_login_and_logout_commands() { + let login = Cli::parse_from(["rusty-claude-cli", "login"]); + assert_eq!(login.command, Some(Command::Login)); + + let logout = Cli::parse_from(["rusty-claude-cli", "logout"]); + assert_eq!(logout.command, Some(Command::Logout)); + } } diff --git a/rust/crates/rusty-claude-cli/src/main.rs b/rust/crates/rusty-claude-cli/src/main.rs index afbd550..e9a68e2 100644 --- a/rust/crates/rusty-claude-cli/src/main.rs +++ b/rust/crates/rusty-claude-cli/src/main.rs @@ -3,24 +3,28 @@ mod render; use std::env; use std::fs; -use std::io::{self, Write}; +use std::io::{self, Read, Write}; +use std::net::TcpListener; use std::path::{Path, PathBuf}; +use std::process::Command; use std::time::{SystemTime, UNIX_EPOCH}; use api::{ - AnthropicClient, ContentBlockDelta, InputContentBlock, InputMessage, MessageRequest, - MessageResponse, OutputContentBlock, StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, - ToolResultContentBlock, + resolve_saved_oauth_token, AnthropicClient, AuthSource, ContentBlockDelta, InputContentBlock, + InputMessage, MessageRequest, MessageResponse, OutputContentBlock, + StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, ToolResultContentBlock, }; use commands::{render_slash_command_help, resume_supported_slash_commands, SlashCommand}; use compat_harness::{extract_manifest, UpstreamPaths}; use render::{Spinner, TerminalRenderer}; use runtime::{ - load_system_prompt, ApiClient, ApiRequest, AssistantEvent, CompactionConfig, ConfigLoader, - ConfigSource, ContentBlock, ConversationMessage, ConversationRuntime, MessageRole, - PermissionMode, PermissionPolicy, ProjectContext, RuntimeError, Session, TokenUsage, ToolError, - ToolExecutor, UsageTracker, + clear_oauth_credentials, generate_pkce_pair, generate_state, load_system_prompt, + parse_oauth_callback_request_target, save_oauth_credentials, ApiClient, ApiRequest, + AssistantEvent, CompactionConfig, ConfigLoader, ConfigSource, ContentBlock, + ConversationMessage, ConversationRuntime, MessageRole, OAuthAuthorizationRequest, + OAuthTokenExchangeRequest, PermissionMode, PermissionPolicy, ProjectContext, RuntimeError, + Session, TokenUsage, ToolError, ToolExecutor, UsageTracker, }; use serde_json::json; use tools::{execute_tool, mvp_tool_specs}; @@ -28,6 +32,7 @@ use tools::{execute_tool, mvp_tool_specs}; const DEFAULT_MODEL: &str = "claude-sonnet-4-20250514"; const DEFAULT_MAX_TOKENS: u32 = 32; const DEFAULT_DATE: &str = "2026-03-31"; +const DEFAULT_OAUTH_CALLBACK_PORT: u16 = 4545; const VERSION: &str = env!("CARGO_PKG_VERSION"); const BUILD_TARGET: Option<&str> = option_env!("TARGET"); const GIT_SHA: Option<&str> = option_env!("GIT_SHA"); @@ -58,6 +63,8 @@ fn run() -> Result<(), Box> { model, output_format, } => LiveCli::new(model, false)?.run_turn_with_output(&prompt, output_format)?, + CliAction::Login => run_login()?, + CliAction::Logout => run_logout()?, CliAction::Repl { model } => run_repl(model)?, CliAction::Help => print_help(), } @@ -81,6 +88,8 @@ enum CliAction { model: String, output_format: CliOutputFormat, }, + Login, + Logout, Repl { model: String, }, @@ -157,6 +166,8 @@ fn parse_args(args: &[String]) -> Result { "dump-manifests" => Ok(CliAction::DumpManifests), "bootstrap-plan" => Ok(CliAction::BootstrapPlan), "system-prompt" => parse_system_prompt_args(&rest[1..]), + "login" => Ok(CliAction::Login), + "logout" => Ok(CliAction::Logout), "prompt" => { let prompt = rest[1..].join(" "); if prompt.trim().is_empty() { @@ -245,6 +256,122 @@ fn print_bootstrap_plan() { } } +fn run_login() -> Result<(), Box> { + let cwd = env::current_dir()?; + let config = ConfigLoader::default_for(&cwd).load()?; + let oauth = config.oauth().ok_or_else(|| { + io::Error::new( + io::ErrorKind::NotFound, + "OAuth config is missing. Add settings.oauth.clientId/authorizeUrl/tokenUrl first.", + ) + })?; + let callback_port = oauth.callback_port.unwrap_or(DEFAULT_OAUTH_CALLBACK_PORT); + let redirect_uri = runtime::loopback_redirect_uri(callback_port); + let pkce = generate_pkce_pair()?; + let state = generate_state()?; + let authorize_url = + OAuthAuthorizationRequest::from_config(oauth, redirect_uri.clone(), state.clone(), &pkce) + .build_url(); + + println!("Starting Claude OAuth login..."); + println!("Listening for callback on {redirect_uri}"); + if let Err(error) = open_browser(&authorize_url) { + eprintln!("warning: failed to open browser automatically: {error}"); + println!("Open this URL manually:\n{authorize_url}"); + } + + let callback = wait_for_oauth_callback(callback_port)?; + if let Some(error) = callback.error { + let description = callback + .error_description + .unwrap_or_else(|| "authorization failed".to_string()); + return Err(io::Error::other(format!("{error}: {description}")).into()); + } + let code = callback.code.ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidData, "callback did not include code") + })?; + let returned_state = callback.state.ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidData, "callback did not include state") + })?; + if returned_state != state { + return Err(io::Error::new(io::ErrorKind::InvalidData, "oauth state mismatch").into()); + } + + let client = AnthropicClient::from_auth(AuthSource::None); + let exchange_request = + OAuthTokenExchangeRequest::from_config(oauth, code, state, pkce.verifier, redirect_uri); + let runtime = tokio::runtime::Runtime::new()?; + let token_set = runtime.block_on(client.exchange_oauth_code(oauth, &exchange_request))?; + save_oauth_credentials(&runtime::OAuthTokenSet { + access_token: token_set.access_token, + refresh_token: token_set.refresh_token, + expires_at: token_set.expires_at, + scopes: token_set.scopes, + })?; + println!("Claude OAuth login complete."); + Ok(()) +} + +fn run_logout() -> Result<(), Box> { + clear_oauth_credentials()?; + println!("Claude OAuth credentials cleared."); + Ok(()) +} + +fn open_browser(url: &str) -> io::Result<()> { + let commands = if cfg!(target_os = "macos") { + vec![("open", vec![url])] + } else if cfg!(target_os = "windows") { + vec![("cmd", vec!["/C", "start", "", url])] + } else { + vec![("xdg-open", vec![url])] + }; + for (program, args) in commands { + match Command::new(program).args(args).spawn() { + Ok(_) => return Ok(()), + Err(error) if error.kind() == io::ErrorKind::NotFound => {} + Err(error) => return Err(error), + } + } + Err(io::Error::new( + io::ErrorKind::NotFound, + "no supported browser opener command found", + )) +} + +fn wait_for_oauth_callback( + port: u16, +) -> Result> { + let listener = TcpListener::bind(("127.0.0.1", port))?; + let (mut stream, _) = listener.accept()?; + let mut buffer = [0_u8; 4096]; + let bytes_read = stream.read(&mut buffer)?; + let request = String::from_utf8_lossy(&buffer[..bytes_read]); + let request_line = request.lines().next().ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidData, "missing callback request line") + })?; + let target = request_line.split_whitespace().nth(1).ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidData, + "missing callback request target", + ) + })?; + let callback = parse_oauth_callback_request_target(target) + .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?; + let body = if callback.error.is_some() { + "Claude OAuth login failed. You can close this window." + } else { + "Claude OAuth login succeeded. You can close this window." + }; + let response = format!( + "HTTP/1.1 200 OK\r\ncontent-type: text/plain; charset=utf-8\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{}", + body.len(), + body + ); + stream.write_all(response.as_bytes())?; + Ok(callback) +} + fn print_system_prompt(cwd: PathBuf, date: String) { match load_system_prompt(cwd, date, env::consts::OS, "unknown") { Ok(sections) => println!("{}", sections.join("\n\n")), @@ -727,7 +854,7 @@ impl LiveCli { } fn run_prompt_json(&mut self, input: &str) -> Result<(), Box> { - let client = AnthropicClient::from_env()?; + let client = AnthropicClient::from_auth(resolve_cli_auth_source()?); let request = MessageRequest { model: self.model.clone(), max_tokens: DEFAULT_MAX_TOKENS, @@ -1610,13 +1737,30 @@ impl AnthropicRuntimeClient { fn new(model: String, enable_tools: bool) -> Result> { Ok(Self { runtime: tokio::runtime::Runtime::new()?, - client: AnthropicClient::from_env()?, + client: AnthropicClient::from_auth(resolve_cli_auth_source()?), model, enable_tools, }) } } +fn resolve_cli_auth_source() -> Result> { + match AuthSource::from_env() { + Ok(auth) => Ok(auth), + Err(api::ApiError::MissingApiKey) => { + let cwd = env::current_dir()?; + let config = ConfigLoader::default_for(&cwd).load()?; + if let Some(oauth) = config.oauth() { + if let Some(token_set) = resolve_saved_oauth_token(oauth)? { + return Ok(AuthSource::from(token_set)); + } + } + Ok(AuthSource::from_env_or_saved()?) + } + Err(error) => Err(Box::new(error)), + } +} + impl ApiClient for AnthropicRuntimeClient { #[allow(clippy::too_many_lines)] fn stream(&mut self, request: ApiRequest) -> Result, RuntimeError> { @@ -1875,6 +2019,8 @@ fn print_help() { println!(" rusty-claude-cli dump-manifests"); println!(" rusty-claude-cli bootstrap-plan"); println!(" rusty-claude-cli system-prompt [--cwd PATH] [--date YYYY-MM-DD]"); + println!(" rusty-claude-cli login"); + println!(" rusty-claude-cli logout"); println!(); println!("Flags:"); println!(" --model MODEL Override the active model"); @@ -1896,6 +2042,7 @@ fn print_help() { println!(" rusty-claude-cli --model claude-opus \"summarize this repo\""); println!(" rusty-claude-cli --output-format json prompt \"explain src/main.rs\""); println!(" rusty-claude-cli --resume session.json /status /diff /export notes.txt"); + println!(" rusty-claude-cli login"); } #[cfg(test)] @@ -1975,6 +2122,18 @@ mod tests { ); } + #[test] + fn parses_login_and_logout_subcommands() { + assert_eq!( + parse_args(&["login".to_string()]).expect("login should parse"), + CliAction::Login + ); + assert_eq!( + parse_args(&["logout".to_string()]).expect("logout should parse"), + CliAction::Logout + ); + } + #[test] fn parses_resume_flag_with_slash_command() { let args = vec![