From a66c301fa3586bd82fab1601836cbcc971ae2885 Mon Sep 17 00:00:00 2001 From: Yeachan-Heo Date: Tue, 31 Mar 2026 19:47:02 +0000 Subject: [PATCH] Add reusable OAuth and auth-source foundations Add runtime OAuth primitives for PKCE generation, authorization URL building, token exchange request shaping, and refresh request shaping. Wire the API client to a real auth-source abstraction so future OAuth tokens can flow into Anthropic requests without bespoke header code. This keeps the slice bounded to foundations: no browser flow, callback listener, or token persistence. The API client still behaves compatibly for current API-key users while gaining explicit bearer-token and combined auth modeling. Constraint: Must keep the slice minimal and real while preserving current API client behavior Constraint: Repo verification requires fmt, tests, and clippy to pass cleanly Rejected: Implement full OAuth browser/listener flow now | too broad for the current parity-unblocking slice Rejected: Keep auth handling as ad hoc env reads only | blocks reuse by future OAuth integration paths Confidence: high Scope-risk: moderate Reversibility: clean Directive: Extend OAuth behavior by composing these request/auth primitives before adding session or storage orchestration Tested: cargo fmt --all; cargo clippy -p runtime -p api --all-targets -- -D warnings; cargo test -p runtime; cargo test -p api --tests Not-tested: live OAuth token exchange; callback listener flow; workspace-wide tests outside runtime/api --- rust/Cargo.lock | 72 +++++++ rust/crates/api/src/client.rs | 249 +++++++++++++++++++---- rust/crates/api/src/lib.rs | 2 +- rust/crates/runtime/Cargo.toml | 1 + rust/crates/runtime/src/lib.rs | 6 + rust/crates/runtime/src/oauth.rs | 338 +++++++++++++++++++++++++++++++ 6 files changed, 632 insertions(+), 36 deletions(-) create mode 100644 rust/crates/runtime/src/oauth.rs diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 308a108..806c309 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -54,6 +54,15 @@ version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af" +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + [[package]] name = "bumpalo" version = "3.20.2" @@ -104,6 +113,15 @@ dependencies = [ "tools", ] +[[package]] +name = "cpufeatures" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +dependencies = [ + "libc", +] + [[package]] name = "crc32fast" version = "1.5.0" @@ -138,6 +156,16 @@ dependencies = [ "winapi", ] +[[package]] +name = "crypto-common" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" +dependencies = [ + "generic-array", + "typenum", +] + [[package]] name = "deranged" version = "0.5.8" @@ -147,6 +175,16 @@ dependencies = [ "powerfmt", ] +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", +] + [[package]] name = "displaydoc" version = "0.2.5" @@ -238,6 +276,16 @@ dependencies = [ "slab", ] +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "getopts" version = "0.2.24" @@ -950,6 +998,7 @@ dependencies = [ "regex", "serde", "serde_json", + "sha2", "tokio", "walkdir", ] @@ -1106,6 +1155,17 @@ dependencies = [ "serde", ] +[[package]] +name = "sha2" +version = "0.10.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "shlex" version = "1.3.0" @@ -1427,6 +1487,12 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "typenum" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" + [[package]] name = "unicase" version = "2.9.0" @@ -1469,6 +1535,12 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + [[package]] name = "walkdir" version = "2.5.0" diff --git a/rust/crates/api/src/client.rs b/rust/crates/api/src/client.rs index d77cf9c..5e7d319 100644 --- a/rust/crates/api/src/client.rs +++ b/rust/crates/api/src/client.rs @@ -15,11 +15,90 @@ const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(200); const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(2); const DEFAULT_MAX_RETRIES: u32 = 2; +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum AuthSource { + None, + ApiKey(String), + BearerToken(String), + ApiKeyAndBearer { + api_key: String, + bearer_token: String, + }, +} + +impl AuthSource { + pub fn from_env() -> Result { + let api_key = read_env_non_empty("ANTHROPIC_API_KEY")?; + let auth_token = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")?; + match (api_key, auth_token) { + (Some(api_key), Some(bearer_token)) => Ok(Self::ApiKeyAndBearer { + api_key, + bearer_token, + }), + (Some(api_key), None) => Ok(Self::ApiKey(api_key)), + (None, Some(bearer_token)) => Ok(Self::BearerToken(bearer_token)), + (None, None) => Err(ApiError::MissingApiKey), + } + } + + #[must_use] + pub fn api_key(&self) -> Option<&str> { + match self { + Self::ApiKey(api_key) | Self::ApiKeyAndBearer { api_key, .. } => Some(api_key), + Self::None | Self::BearerToken(_) => None, + } + } + + #[must_use] + pub fn bearer_token(&self) -> Option<&str> { + match self { + Self::BearerToken(token) + | Self::ApiKeyAndBearer { + bearer_token: token, + .. + } => Some(token), + Self::None | Self::ApiKey(_) => None, + } + } + + #[must_use] + pub fn masked_authorization_header(&self) -> &'static str { + if self.bearer_token().is_some() { + "Bearer [REDACTED]" + } else { + "" + } + } + + pub fn apply(&self, mut request_builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder { + if let Some(api_key) = self.api_key() { + request_builder = request_builder.header("x-api-key", api_key); + } + if let Some(token) = self.bearer_token() { + request_builder = request_builder.bearer_auth(token); + } + request_builder + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct OAuthTokenSet { + pub access_token: String, + pub refresh_token: Option, + pub expires_at: Option, + pub scopes: Vec, +} + +impl From for AuthSource { + fn from(value: OAuthTokenSet) -> Self { + Self::BearerToken(value.access_token) + } +} + #[derive(Debug, Clone)] pub struct AnthropicClient { http: reqwest::Client, - api_key: String, - auth_token: Option, + auth: AuthSource, base_url: String, max_retries: u32, initial_backoff: Duration, @@ -31,8 +110,19 @@ impl AnthropicClient { pub fn new(api_key: impl Into) -> Self { Self { http: reqwest::Client::new(), - api_key: api_key.into(), - auth_token: None, + auth: AuthSource::ApiKey(api_key.into()), + base_url: DEFAULT_BASE_URL.to_string(), + max_retries: DEFAULT_MAX_RETRIES, + initial_backoff: DEFAULT_INITIAL_BACKOFF, + max_backoff: DEFAULT_MAX_BACKOFF, + } + } + + #[must_use] + pub fn from_auth(auth: AuthSource) -> Self { + Self { + http: reqwest::Client::new(), + auth, base_url: DEFAULT_BASE_URL.to_string(), max_retries: DEFAULT_MAX_RETRIES, initial_backoff: DEFAULT_INITIAL_BACKOFF, @@ -41,14 +131,37 @@ impl AnthropicClient { } pub fn from_env() -> Result { - Ok(Self::new(read_api_key()?) - .with_auth_token(read_auth_token()) - .with_base_url(read_base_url())) + Ok(Self::from_auth(AuthSource::from_env()?).with_base_url(read_base_url())) + } + + #[must_use] + pub fn with_auth_source(mut self, auth: AuthSource) -> Self { + self.auth = auth; + self } #[must_use] pub fn with_auth_token(mut self, auth_token: Option) -> Self { - self.auth_token = auth_token.filter(|token| !token.is_empty()); + match ( + self.auth.api_key().map(ToOwned::to_owned), + auth_token.filter(|token| !token.is_empty()), + ) { + (Some(api_key), Some(bearer_token)) => { + self.auth = AuthSource::ApiKeyAndBearer { + api_key, + bearer_token, + }; + } + (Some(api_key), None) => { + self.auth = AuthSource::ApiKey(api_key); + } + (None, Some(bearer_token)) => { + self.auth = AuthSource::BearerToken(bearer_token); + } + (None, None) => { + self.auth = AuthSource::None; + } + } self } @@ -71,6 +184,11 @@ impl AnthropicClient { self } + #[must_use] + pub fn auth_source(&self) -> &AuthSource { + &self.auth + } + pub async fn send_message( &self, request: &MessageRequest, @@ -151,25 +269,25 @@ impl AnthropicClient { let resolved_base_url = self.base_url.trim_end_matches('/'); eprintln!("[anthropic-client] resolved_base_url={resolved_base_url}"); eprintln!("[anthropic-client] request_url={request_url}"); - let mut request_builder = self + let request_builder = self .http .post(&request_url) - .header("x-api-key", &self.api_key) .header("anthropic-version", ANTHROPIC_VERSION) .header("content-type", "application/json"); + let mut request_builder = self.auth.apply(request_builder); - let auth_header = self.auth_token.as_ref().map(|_| "Bearer [REDACTED]").unwrap_or(""); - eprintln!("[anthropic-client] headers x-api-key=[REDACTED] authorization={auth_header} anthropic-version={ANTHROPIC_VERSION} content-type=application/json"); + eprintln!( + "[anthropic-client] headers x-api-key={} authorization={} anthropic-version={ANTHROPIC_VERSION} content-type=application/json", + if self.auth.api_key().is_some() { + "[REDACTED]" + } else { + "" + }, + self.auth.masked_authorization_header() + ); - if let Some(auth_token) = &self.auth_token { - request_builder = request_builder.bearer_auth(auth_token); - } - - request_builder - .json(request) - .send() - .await - .map_err(ApiError::from) + request_builder = request_builder.json(request); + request_builder.send().await.map_err(ApiError::from) } fn backoff_for_attempt(&self, attempt: u32) -> Result { @@ -186,25 +304,28 @@ impl AnthropicClient { } } -fn read_api_key() -> Result { - match std::env::var("ANTHROPIC_API_KEY") { - Ok(api_key) if !api_key.is_empty() => Ok(api_key), - Ok(_) => Err(ApiError::MissingApiKey), - Err(std::env::VarError::NotPresent) => match std::env::var("ANTHROPIC_AUTH_TOKEN") { - Ok(api_key) if !api_key.is_empty() => Ok(api_key), - Ok(_) => Err(ApiError::MissingApiKey), - Err(std::env::VarError::NotPresent) => Err(ApiError::MissingApiKey), - Err(error) => Err(ApiError::from(error)), - }, +fn read_env_non_empty(key: &str) -> Result, ApiError> { + match std::env::var(key) { + Ok(value) if !value.is_empty() => Ok(Some(value)), + Ok(_) | Err(std::env::VarError::NotPresent) => Ok(None), Err(error) => Err(ApiError::from(error)), } } +#[cfg(test)] +fn read_api_key() -> Result { + let auth = AuthSource::from_env()?; + auth.api_key() + .or_else(|| auth.bearer_token()) + .map(ToOwned::to_owned) + .ok_or(ApiError::MissingApiKey) +} + +#[cfg(test)] fn read_auth_token() -> Option { - match std::env::var("ANTHROPIC_AUTH_TOKEN") { - Ok(token) if !token.is_empty() => Some(token), - _ => None, - } + read_env_non_empty("ANTHROPIC_AUTH_TOKEN") + .ok() + .and_then(std::convert::identity) } fn read_base_url() -> String { @@ -303,12 +424,22 @@ struct AnthropicErrorBody { #[cfg(test)] mod tests { use super::{ALT_REQUEST_ID_HEADER, REQUEST_ID_HEADER}; + use std::sync::{Mutex, OnceLock}; use std::time::Duration; + use crate::client::{AuthSource, OAuthTokenSet}; use crate::types::{ContentBlockDelta, MessageRequest}; + fn env_lock() -> std::sync::MutexGuard<'static, ()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) + .lock() + .expect("env lock") + } + #[test] fn read_api_key_requires_presence() { + let _guard = env_lock(); std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); std::env::remove_var("ANTHROPIC_API_KEY"); let error = super::read_api_key().expect_err("missing key should error"); @@ -317,6 +448,7 @@ mod tests { #[test] fn read_api_key_requires_non_empty_value() { + let _guard = env_lock(); std::env::set_var("ANTHROPIC_AUTH_TOKEN", ""); std::env::remove_var("ANTHROPIC_API_KEY"); let error = super::read_api_key().expect_err("empty key should error"); @@ -325,6 +457,7 @@ mod tests { #[test] fn read_api_key_prefers_api_key_env() { + let _guard = env_lock(); std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); std::env::set_var("ANTHROPIC_API_KEY", "legacy-key"); assert_eq!( @@ -337,11 +470,36 @@ mod tests { #[test] fn read_auth_token_reads_auth_token_env() { + let _guard = env_lock(); std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); assert_eq!(super::read_auth_token().as_deref(), Some("auth-token")); std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); } + #[test] + fn oauth_token_maps_to_bearer_auth_source() { + let auth = AuthSource::from(OAuthTokenSet { + access_token: "access-token".to_string(), + refresh_token: Some("refresh".to_string()), + expires_at: Some(123), + scopes: vec!["scope:a".to_string()], + }); + assert_eq!(auth.bearer_token(), Some("access-token")); + assert_eq!(auth.api_key(), None); + } + + #[test] + fn auth_source_from_env_combines_api_key_and_bearer_token() { + let _guard = env_lock(); + std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); + std::env::set_var("ANTHROPIC_API_KEY", "legacy-key"); + let auth = AuthSource::from_env().expect("env auth"); + assert_eq!(auth.api_key(), Some("legacy-key")); + assert_eq!(auth.bearer_token(), Some("auth-token")); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + } + #[test] fn message_request_stream_helper_sets_stream_true() { let request = MessageRequest { @@ -421,4 +579,25 @@ mod tests { Some("req_fallback") ); } + + #[test] + fn auth_source_applies_headers() { + let auth = AuthSource::ApiKeyAndBearer { + api_key: "test-key".to_string(), + bearer_token: "proxy-token".to_string(), + }; + let request = auth + .apply(reqwest::Client::new().post("https://example.test")) + .build() + .expect("request build"); + let headers = request.headers(); + assert_eq!( + headers.get("x-api-key").and_then(|v| v.to_str().ok()), + Some("test-key") + ); + assert_eq!( + headers.get("authorization").and_then(|v| v.to_str().ok()), + Some("Bearer proxy-token") + ); + } } diff --git a/rust/crates/api/src/lib.rs b/rust/crates/api/src/lib.rs index e08e3d7..9d587ee 100644 --- a/rust/crates/api/src/lib.rs +++ b/rust/crates/api/src/lib.rs @@ -3,7 +3,7 @@ mod error; mod sse; mod types; -pub use client::{AnthropicClient, MessageStream}; +pub use client::{AnthropicClient, AuthSource, MessageStream, OAuthTokenSet}; pub use error::ApiError; pub use sse::{parse_frame, SseParser}; pub use types::{ diff --git a/rust/crates/runtime/Cargo.toml b/rust/crates/runtime/Cargo.toml index 8bd9a42..3803c10 100644 --- a/rust/crates/runtime/Cargo.toml +++ b/rust/crates/runtime/Cargo.toml @@ -6,6 +6,7 @@ license.workspace = true publish.workspace = true [dependencies] +sha2 = "0.10" glob = "0.3" regex = "1" serde = { version = "1", features = ["derive"] } diff --git a/rust/crates/runtime/src/lib.rs b/rust/crates/runtime/src/lib.rs index 358d367..4381166 100644 --- a/rust/crates/runtime/src/lib.rs +++ b/rust/crates/runtime/src/lib.rs @@ -5,6 +5,7 @@ mod config; mod conversation; mod file_ops; mod json; +mod oauth; mod permissions; mod prompt; mod session; @@ -31,6 +32,11 @@ pub use file_ops::{ GrepSearchInput, GrepSearchOutput, ReadFileOutput, StructuredPatchHunk, TextFilePayload, WriteFileOutput, }; +pub use oauth::{ + code_challenge_s256, generate_pkce_pair, generate_state, loopback_redirect_uri, + OAuthAuthorizationRequest, OAuthRefreshRequest, OAuthTokenExchangeRequest, OAuthTokenSet, + PkceChallengeMethod, PkceCodePair, +}; pub use permissions::{ PermissionMode, PermissionOutcome, PermissionPolicy, PermissionPromptDecision, PermissionPrompter, PermissionRequest, diff --git a/rust/crates/runtime/src/oauth.rs b/rust/crates/runtime/src/oauth.rs new file mode 100644 index 0000000..320a8ee --- /dev/null +++ b/rust/crates/runtime/src/oauth.rs @@ -0,0 +1,338 @@ +use std::collections::BTreeMap; +use std::fs::File; +use std::io::{self, Read}; + +use sha2::{Digest, Sha256}; + +use crate::config::OAuthConfig; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct OAuthTokenSet { + pub access_token: String, + pub refresh_token: Option, + pub expires_at: Option, + pub scopes: Vec, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PkceCodePair { + pub verifier: String, + pub challenge: String, + pub challenge_method: PkceChallengeMethod, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum PkceChallengeMethod { + S256, +} + +impl PkceChallengeMethod { + #[must_use] + pub const fn as_str(self) -> &'static str { + match self { + Self::S256 => "S256", + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct OAuthAuthorizationRequest { + pub authorize_url: String, + pub client_id: String, + pub redirect_uri: String, + pub scopes: Vec, + pub state: String, + pub code_challenge: String, + pub code_challenge_method: PkceChallengeMethod, + pub extra_params: BTreeMap, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct OAuthTokenExchangeRequest { + pub grant_type: &'static str, + pub code: String, + pub redirect_uri: String, + pub client_id: String, + pub code_verifier: String, + pub state: String, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct OAuthRefreshRequest { + pub grant_type: &'static str, + pub refresh_token: String, + pub client_id: String, + pub scopes: Vec, +} + +impl OAuthAuthorizationRequest { + #[must_use] + pub fn from_config( + config: &OAuthConfig, + redirect_uri: impl Into, + state: impl Into, + pkce: &PkceCodePair, + ) -> Self { + Self { + authorize_url: config.authorize_url.clone(), + client_id: config.client_id.clone(), + redirect_uri: redirect_uri.into(), + scopes: config.scopes.clone(), + state: state.into(), + code_challenge: pkce.challenge.clone(), + code_challenge_method: pkce.challenge_method, + extra_params: BTreeMap::new(), + } + } + + #[must_use] + pub fn with_extra_param(mut self, key: impl Into, value: impl Into) -> Self { + self.extra_params.insert(key.into(), value.into()); + self + } + + #[must_use] + pub fn build_url(&self) -> String { + let mut params = vec![ + ("response_type", "code".to_string()), + ("client_id", self.client_id.clone()), + ("redirect_uri", self.redirect_uri.clone()), + ("scope", self.scopes.join(" ")), + ("state", self.state.clone()), + ("code_challenge", self.code_challenge.clone()), + ( + "code_challenge_method", + self.code_challenge_method.as_str().to_string(), + ), + ]; + params.extend( + self.extra_params + .iter() + .map(|(key, value)| (key.as_str(), value.clone())), + ); + let query = params + .into_iter() + .map(|(key, value)| format!("{}={}", percent_encode(key), percent_encode(&value))) + .collect::>() + .join("&"); + format!( + "{}{}{}", + self.authorize_url, + if self.authorize_url.contains('?') { + '&' + } else { + '?' + }, + query + ) + } +} + +impl OAuthTokenExchangeRequest { + #[must_use] + pub fn from_config( + config: &OAuthConfig, + code: impl Into, + state: impl Into, + verifier: impl Into, + redirect_uri: impl Into, + ) -> Self { + let _ = config; + Self { + grant_type: "authorization_code", + code: code.into(), + redirect_uri: redirect_uri.into(), + client_id: config.client_id.clone(), + code_verifier: verifier.into(), + state: state.into(), + } + } + + #[must_use] + pub fn form_params(&self) -> BTreeMap<&str, String> { + BTreeMap::from([ + ("grant_type", self.grant_type.to_string()), + ("code", self.code.clone()), + ("redirect_uri", self.redirect_uri.clone()), + ("client_id", self.client_id.clone()), + ("code_verifier", self.code_verifier.clone()), + ("state", self.state.clone()), + ]) + } +} + +impl OAuthRefreshRequest { + #[must_use] + pub fn from_config( + config: &OAuthConfig, + refresh_token: impl Into, + scopes: Option>, + ) -> Self { + Self { + grant_type: "refresh_token", + refresh_token: refresh_token.into(), + client_id: config.client_id.clone(), + scopes: scopes.unwrap_or_else(|| config.scopes.clone()), + } + } + + #[must_use] + pub fn form_params(&self) -> BTreeMap<&str, String> { + BTreeMap::from([ + ("grant_type", self.grant_type.to_string()), + ("refresh_token", self.refresh_token.clone()), + ("client_id", self.client_id.clone()), + ("scope", self.scopes.join(" ")), + ]) + } +} + +pub fn generate_pkce_pair() -> io::Result { + let verifier = generate_random_token(32)?; + Ok(PkceCodePair { + challenge: code_challenge_s256(&verifier), + verifier, + challenge_method: PkceChallengeMethod::S256, + }) +} + +pub fn generate_state() -> io::Result { + generate_random_token(32) +} + +#[must_use] +pub fn code_challenge_s256(verifier: &str) -> String { + let digest = Sha256::digest(verifier.as_bytes()); + base64url_encode(&digest) +} + +#[must_use] +pub fn loopback_redirect_uri(port: u16) -> String { + format!("http://localhost:{port}/callback") +} + +fn generate_random_token(bytes: usize) -> io::Result { + let mut buffer = vec![0_u8; bytes]; + File::open("/dev/urandom")?.read_exact(&mut buffer)?; + Ok(base64url_encode(&buffer)) +} + +fn base64url_encode(bytes: &[u8]) -> String { + const TABLE: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"; + let mut output = String::new(); + let mut index = 0; + while index + 3 <= bytes.len() { + let block = (u32::from(bytes[index]) << 16) + | (u32::from(bytes[index + 1]) << 8) + | u32::from(bytes[index + 2]); + output.push(TABLE[((block >> 18) & 0x3F) as usize] as char); + output.push(TABLE[((block >> 12) & 0x3F) as usize] as char); + output.push(TABLE[((block >> 6) & 0x3F) as usize] as char); + output.push(TABLE[(block & 0x3F) as usize] as char); + index += 3; + } + match bytes.len().saturating_sub(index) { + 1 => { + let block = u32::from(bytes[index]) << 16; + output.push(TABLE[((block >> 18) & 0x3F) as usize] as char); + output.push(TABLE[((block >> 12) & 0x3F) as usize] as char); + } + 2 => { + let block = (u32::from(bytes[index]) << 16) | (u32::from(bytes[index + 1]) << 8); + output.push(TABLE[((block >> 18) & 0x3F) as usize] as char); + output.push(TABLE[((block >> 12) & 0x3F) as usize] as char); + output.push(TABLE[((block >> 6) & 0x3F) as usize] as char); + } + _ => {} + } + output +} + +fn percent_encode(value: &str) -> String { + let mut encoded = String::new(); + for byte in value.bytes() { + match byte { + b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' => { + encoded.push(char::from(byte)); + } + _ => { + use std::fmt::Write as _; + let _ = write!(&mut encoded, "%{byte:02X}"); + } + } + } + encoded +} + +#[cfg(test)] +mod tests { + use super::{ + code_challenge_s256, generate_pkce_pair, generate_state, loopback_redirect_uri, + OAuthAuthorizationRequest, OAuthConfig, OAuthRefreshRequest, OAuthTokenExchangeRequest, + }; + + fn sample_config() -> OAuthConfig { + OAuthConfig { + client_id: "runtime-client".to_string(), + authorize_url: "https://console.test/oauth/authorize".to_string(), + token_url: "https://console.test/oauth/token".to_string(), + callback_port: Some(4545), + manual_redirect_url: Some("https://console.test/oauth/callback".to_string()), + scopes: vec!["org:read".to_string(), "user:write".to_string()], + } + } + + #[test] + fn s256_challenge_matches_expected_vector() { + assert_eq!( + code_challenge_s256("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"), + "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM" + ); + } + + #[test] + fn generates_pkce_pair_and_state() { + let pair = generate_pkce_pair().expect("pkce pair"); + let state = generate_state().expect("state"); + assert!(!pair.verifier.is_empty()); + assert!(!pair.challenge.is_empty()); + assert!(!state.is_empty()); + } + + #[test] + fn builds_authorize_url_and_form_requests() { + let config = sample_config(); + let pair = generate_pkce_pair().expect("pkce"); + let url = OAuthAuthorizationRequest::from_config( + &config, + loopback_redirect_uri(4545), + "state-123", + &pair, + ) + .with_extra_param("login_hint", "user@example.com") + .build_url(); + assert!(url.starts_with("https://console.test/oauth/authorize?")); + assert!(url.contains("response_type=code")); + assert!(url.contains("client_id=runtime-client")); + assert!(url.contains("scope=org%3Aread%20user%3Awrite")); + assert!(url.contains("login_hint=user%40example.com")); + + let exchange = OAuthTokenExchangeRequest::from_config( + &config, + "auth-code", + "state-123", + pair.verifier, + loopback_redirect_uri(4545), + ); + assert_eq!( + exchange.form_params().get("grant_type").map(String::as_str), + Some("authorization_code") + ); + + let refresh = OAuthRefreshRequest::from_config(&config, "refresh-token", None); + assert_eq!( + refresh.form_params().get("scope").map(String::as_str), + Some("org:read user:write") + ); + } +}