Enable Claude OAuth login without requiring API keys
This adds an end-to-end OAuth PKCE login/logout path to the Rust CLI, persists OAuth credentials under the Claude config home, and teaches the API client to use persisted bearer credentials with refresh support when env-based API credentials are absent. Constraint: Reuse existing runtime OAuth primitives and keep browser/callback orchestration in the CLI Constraint: Preserve auth precedence as API key, then auth-token env, then persisted OAuth credentials Rejected: Put browser launch and token exchange entirely in runtime | caused boundary creep across shared crates Rejected: Duplicate credential parsing in CLI and api | increased drift and refresh inconsistency Confidence: medium Scope-risk: moderate Reversibility: clean Directive: Keep logout non-destructive to unrelated credentials.json fields and do not silently fall back to stale expired tokens Tested: cargo fmt; cargo clippy --workspace --all-targets -- -D warnings; cargo test Not-tested: Manual live Anthropic OAuth browser flow against real authorize/token endpoints
This commit is contained in:
@@ -52,8 +52,10 @@ pub use mcp_stdio::{
|
||||
McpStdioProcess, McpTool, McpToolCallContent, McpToolCallParams, McpToolCallResult,
|
||||
};
|
||||
pub use oauth::{
|
||||
code_challenge_s256, generate_pkce_pair, generate_state, loopback_redirect_uri,
|
||||
OAuthAuthorizationRequest, OAuthRefreshRequest, OAuthTokenExchangeRequest, OAuthTokenSet,
|
||||
clear_oauth_credentials, code_challenge_s256, credentials_path, generate_pkce_pair,
|
||||
generate_state, load_oauth_credentials, loopback_redirect_uri, parse_oauth_callback_query,
|
||||
parse_oauth_callback_request_target, save_oauth_credentials, OAuthAuthorizationRequest,
|
||||
OAuthCallbackParams, OAuthRefreshRequest, OAuthTokenExchangeRequest, OAuthTokenSet,
|
||||
PkceChallengeMethod, PkceCodePair,
|
||||
};
|
||||
pub use permissions::{
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
use std::collections::BTreeMap;
|
||||
use std::fs::File;
|
||||
use std::fs::{self, File};
|
||||
use std::io::{self, Read};
|
||||
use std::path::PathBuf;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{Map, Value};
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
use crate::config::OAuthConfig;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct OAuthTokenSet {
|
||||
pub access_token: String,
|
||||
pub refresh_token: Option<String>,
|
||||
@@ -65,6 +68,48 @@ pub struct OAuthRefreshRequest {
|
||||
pub scopes: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct OAuthCallbackParams {
|
||||
pub code: Option<String>,
|
||||
pub state: Option<String>,
|
||||
pub error: Option<String>,
|
||||
pub error_description: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct StoredOAuthCredentials {
|
||||
access_token: String,
|
||||
#[serde(default)]
|
||||
refresh_token: Option<String>,
|
||||
#[serde(default)]
|
||||
expires_at: Option<u64>,
|
||||
#[serde(default)]
|
||||
scopes: Vec<String>,
|
||||
}
|
||||
|
||||
impl From<OAuthTokenSet> for StoredOAuthCredentials {
|
||||
fn from(value: OAuthTokenSet) -> Self {
|
||||
Self {
|
||||
access_token: value.access_token,
|
||||
refresh_token: value.refresh_token,
|
||||
expires_at: value.expires_at,
|
||||
scopes: value.scopes,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<StoredOAuthCredentials> for OAuthTokenSet {
|
||||
fn from(value: StoredOAuthCredentials) -> Self {
|
||||
Self {
|
||||
access_token: value.access_token,
|
||||
refresh_token: value.refresh_token,
|
||||
expires_at: value.expires_at,
|
||||
scopes: value.scopes,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl OAuthAuthorizationRequest {
|
||||
#[must_use]
|
||||
pub fn from_config(
|
||||
@@ -137,7 +182,6 @@ impl OAuthTokenExchangeRequest {
|
||||
verifier: impl Into<String>,
|
||||
redirect_uri: impl Into<String>,
|
||||
) -> Self {
|
||||
let _ = config;
|
||||
Self {
|
||||
grant_type: "authorization_code",
|
||||
code: code.into(),
|
||||
@@ -211,12 +255,116 @@ pub fn loopback_redirect_uri(port: u16) -> String {
|
||||
format!("http://localhost:{port}/callback")
|
||||
}
|
||||
|
||||
pub fn credentials_path() -> io::Result<PathBuf> {
|
||||
Ok(credentials_home_dir()?.join("credentials.json"))
|
||||
}
|
||||
|
||||
pub fn load_oauth_credentials() -> io::Result<Option<OAuthTokenSet>> {
|
||||
let path = credentials_path()?;
|
||||
let root = read_credentials_root(&path)?;
|
||||
let Some(oauth) = root.get("oauth") else {
|
||||
return Ok(None);
|
||||
};
|
||||
if oauth.is_null() {
|
||||
return Ok(None);
|
||||
}
|
||||
let stored = serde_json::from_value::<StoredOAuthCredentials>(oauth.clone())
|
||||
.map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
|
||||
Ok(Some(stored.into()))
|
||||
}
|
||||
|
||||
pub fn save_oauth_credentials(token_set: &OAuthTokenSet) -> io::Result<()> {
|
||||
let path = credentials_path()?;
|
||||
let mut root = read_credentials_root(&path)?;
|
||||
root.insert(
|
||||
"oauth".to_string(),
|
||||
serde_json::to_value(StoredOAuthCredentials::from(token_set.clone()))
|
||||
.map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?,
|
||||
);
|
||||
write_credentials_root(&path, &root)
|
||||
}
|
||||
|
||||
pub fn clear_oauth_credentials() -> io::Result<()> {
|
||||
let path = credentials_path()?;
|
||||
let mut root = read_credentials_root(&path)?;
|
||||
root.remove("oauth");
|
||||
write_credentials_root(&path, &root)
|
||||
}
|
||||
|
||||
pub fn parse_oauth_callback_request_target(target: &str) -> Result<OAuthCallbackParams, String> {
|
||||
let (path, query) = target
|
||||
.split_once('?')
|
||||
.map_or((target, ""), |(path, query)| (path, query));
|
||||
if path != "/callback" {
|
||||
return Err(format!("unexpected callback path: {path}"));
|
||||
}
|
||||
parse_oauth_callback_query(query)
|
||||
}
|
||||
|
||||
pub fn parse_oauth_callback_query(query: &str) -> Result<OAuthCallbackParams, String> {
|
||||
let mut params = BTreeMap::new();
|
||||
for pair in query.split('&').filter(|pair| !pair.is_empty()) {
|
||||
let (key, value) = pair
|
||||
.split_once('=')
|
||||
.map_or((pair, ""), |(key, value)| (key, value));
|
||||
params.insert(percent_decode(key)?, percent_decode(value)?);
|
||||
}
|
||||
Ok(OAuthCallbackParams {
|
||||
code: params.get("code").cloned(),
|
||||
state: params.get("state").cloned(),
|
||||
error: params.get("error").cloned(),
|
||||
error_description: params.get("error_description").cloned(),
|
||||
})
|
||||
}
|
||||
|
||||
fn generate_random_token(bytes: usize) -> io::Result<String> {
|
||||
let mut buffer = vec![0_u8; bytes];
|
||||
File::open("/dev/urandom")?.read_exact(&mut buffer)?;
|
||||
Ok(base64url_encode(&buffer))
|
||||
}
|
||||
|
||||
fn credentials_home_dir() -> io::Result<PathBuf> {
|
||||
if let Some(path) = std::env::var_os("CLAUDE_CONFIG_HOME") {
|
||||
return Ok(PathBuf::from(path));
|
||||
}
|
||||
let home = std::env::var_os("HOME")
|
||||
.ok_or_else(|| io::Error::new(io::ErrorKind::NotFound, "HOME is not set"))?;
|
||||
Ok(PathBuf::from(home).join(".claude"))
|
||||
}
|
||||
|
||||
fn read_credentials_root(path: &PathBuf) -> io::Result<Map<String, Value>> {
|
||||
match fs::read_to_string(path) {
|
||||
Ok(contents) => {
|
||||
if contents.trim().is_empty() {
|
||||
return Ok(Map::new());
|
||||
}
|
||||
serde_json::from_str::<Value>(&contents)
|
||||
.map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?
|
||||
.as_object()
|
||||
.cloned()
|
||||
.ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"credentials file must contain a JSON object",
|
||||
)
|
||||
})
|
||||
}
|
||||
Err(error) if error.kind() == io::ErrorKind::NotFound => Ok(Map::new()),
|
||||
Err(error) => Err(error),
|
||||
}
|
||||
}
|
||||
|
||||
fn write_credentials_root(path: &PathBuf, root: &Map<String, Value>) -> io::Result<()> {
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)?;
|
||||
}
|
||||
let rendered = serde_json::to_string_pretty(&Value::Object(root.clone()))
|
||||
.map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
|
||||
let temp_path = path.with_extension("json.tmp");
|
||||
fs::write(&temp_path, format!("{rendered}\n"))?;
|
||||
fs::rename(temp_path, path)
|
||||
}
|
||||
|
||||
fn base64url_encode(bytes: &[u8]) -> String {
|
||||
const TABLE: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
|
||||
let mut output = String::new();
|
||||
@@ -264,11 +412,50 @@ fn percent_encode(value: &str) -> String {
|
||||
encoded
|
||||
}
|
||||
|
||||
fn percent_decode(value: &str) -> Result<String, String> {
|
||||
let mut decoded = Vec::with_capacity(value.len());
|
||||
let bytes = value.as_bytes();
|
||||
let mut index = 0;
|
||||
while index < bytes.len() {
|
||||
match bytes[index] {
|
||||
b'%' if index + 2 < bytes.len() => {
|
||||
let hi = decode_hex(bytes[index + 1])?;
|
||||
let lo = decode_hex(bytes[index + 2])?;
|
||||
decoded.push((hi << 4) | lo);
|
||||
index += 3;
|
||||
}
|
||||
b'+' => {
|
||||
decoded.push(b' ');
|
||||
index += 1;
|
||||
}
|
||||
byte => {
|
||||
decoded.push(byte);
|
||||
index += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
String::from_utf8(decoded).map_err(|error| error.to_string())
|
||||
}
|
||||
|
||||
fn decode_hex(byte: u8) -> Result<u8, String> {
|
||||
match byte {
|
||||
b'0'..=b'9' => Ok(byte - b'0'),
|
||||
b'a'..=b'f' => Ok(byte - b'a' + 10),
|
||||
b'A'..=b'F' => Ok(byte - b'A' + 10),
|
||||
_ => Err(format!("invalid percent-encoding byte: {byte}")),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::{Mutex, OnceLock};
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use super::{
|
||||
code_challenge_s256, generate_pkce_pair, generate_state, loopback_redirect_uri,
|
||||
OAuthAuthorizationRequest, OAuthConfig, OAuthRefreshRequest, OAuthTokenExchangeRequest,
|
||||
clear_oauth_credentials, code_challenge_s256, credentials_path, generate_pkce_pair,
|
||||
generate_state, load_oauth_credentials, loopback_redirect_uri, parse_oauth_callback_query,
|
||||
parse_oauth_callback_request_target, save_oauth_credentials, OAuthAuthorizationRequest,
|
||||
OAuthConfig, OAuthRefreshRequest, OAuthTokenExchangeRequest, OAuthTokenSet,
|
||||
};
|
||||
|
||||
fn sample_config() -> OAuthConfig {
|
||||
@@ -282,6 +469,24 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
|
||||
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
LOCK.get_or_init(|| Mutex::new(()))
|
||||
.lock()
|
||||
.expect("env lock")
|
||||
}
|
||||
|
||||
fn temp_config_home() -> std::path::PathBuf {
|
||||
std::env::temp_dir().join(format!(
|
||||
"runtime-oauth-test-{}-{}",
|
||||
std::process::id(),
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("time")
|
||||
.as_nanos()
|
||||
))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn s256_challenge_matches_expected_vector() {
|
||||
assert_eq!(
|
||||
@@ -335,4 +540,54 @@ mod tests {
|
||||
Some("org:read user:write")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn oauth_credentials_round_trip_and_clear_preserves_other_fields() {
|
||||
let _guard = env_lock();
|
||||
let config_home = temp_config_home();
|
||||
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
|
||||
let path = credentials_path().expect("credentials path");
|
||||
std::fs::create_dir_all(path.parent().expect("parent")).expect("create parent");
|
||||
std::fs::write(&path, "{\"other\":\"value\"}\n").expect("seed credentials");
|
||||
|
||||
let token_set = OAuthTokenSet {
|
||||
access_token: "access-token".to_string(),
|
||||
refresh_token: Some("refresh-token".to_string()),
|
||||
expires_at: Some(123),
|
||||
scopes: vec!["scope:a".to_string()],
|
||||
};
|
||||
save_oauth_credentials(&token_set).expect("save credentials");
|
||||
assert_eq!(
|
||||
load_oauth_credentials().expect("load credentials"),
|
||||
Some(token_set)
|
||||
);
|
||||
let saved = std::fs::read_to_string(&path).expect("read saved file");
|
||||
assert!(saved.contains("\"other\": \"value\""));
|
||||
assert!(saved.contains("\"oauth\""));
|
||||
|
||||
clear_oauth_credentials().expect("clear credentials");
|
||||
assert_eq!(load_oauth_credentials().expect("load cleared"), None);
|
||||
let cleared = std::fs::read_to_string(&path).expect("read cleared file");
|
||||
assert!(cleared.contains("\"other\": \"value\""));
|
||||
assert!(!cleared.contains("\"oauth\""));
|
||||
|
||||
std::env::remove_var("CLAUDE_CONFIG_HOME");
|
||||
std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parses_callback_query_and_target() {
|
||||
let params =
|
||||
parse_oauth_callback_query("code=abc123&state=state-1&error_description=needs%20login")
|
||||
.expect("parse query");
|
||||
assert_eq!(params.code.as_deref(), Some("abc123"));
|
||||
assert_eq!(params.state.as_deref(), Some("state-1"));
|
||||
assert_eq!(params.error_description.as_deref(), Some("needs login"));
|
||||
|
||||
let params = parse_oauth_callback_request_target("/callback?code=abc&state=xyz")
|
||||
.expect("parse callback target");
|
||||
assert_eq!(params.code.as_deref(), Some("abc"));
|
||||
assert_eq!(params.state.as_deref(), Some("xyz"));
|
||||
assert!(parse_oauth_callback_request_target("/wrong?code=abc").is_err());
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user