Add reusable OAuth and auth-source foundations
Add runtime OAuth primitives for PKCE generation, authorization URL building, token exchange request shaping, and refresh request shaping. Wire the API client to a real auth-source abstraction so future OAuth tokens can flow into Anthropic requests without bespoke header code. This keeps the slice bounded to foundations: no browser flow, callback listener, or token persistence. The API client still behaves compatibly for current API-key users while gaining explicit bearer-token and combined auth modeling. Constraint: Must keep the slice minimal and real while preserving current API client behavior Constraint: Repo verification requires fmt, tests, and clippy to pass cleanly Rejected: Implement full OAuth browser/listener flow now | too broad for the current parity-unblocking slice Rejected: Keep auth handling as ad hoc env reads only | blocks reuse by future OAuth integration paths Confidence: high Scope-risk: moderate Reversibility: clean Directive: Extend OAuth behavior by composing these request/auth primitives before adding session or storage orchestration Tested: cargo fmt --all; cargo clippy -p runtime -p api --all-targets -- -D warnings; cargo test -p runtime; cargo test -p api --tests Not-tested: live OAuth token exchange; callback listener flow; workspace-wide tests outside runtime/api
This commit is contained in:
72
rust/Cargo.lock
generated
72
rust/Cargo.lock
generated
@@ -54,6 +54,15 @@ version = "2.11.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af"
|
checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "block-buffer"
|
||||||
|
version = "0.10.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71"
|
||||||
|
dependencies = [
|
||||||
|
"generic-array",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "bumpalo"
|
name = "bumpalo"
|
||||||
version = "3.20.2"
|
version = "3.20.2"
|
||||||
@@ -104,6 +113,15 @@ dependencies = [
|
|||||||
"tools",
|
"tools",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "cpufeatures"
|
||||||
|
version = "0.2.17"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280"
|
||||||
|
dependencies = [
|
||||||
|
"libc",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "crc32fast"
|
name = "crc32fast"
|
||||||
version = "1.5.0"
|
version = "1.5.0"
|
||||||
@@ -138,6 +156,16 @@ dependencies = [
|
|||||||
"winapi",
|
"winapi",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "crypto-common"
|
||||||
|
version = "0.1.7"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a"
|
||||||
|
dependencies = [
|
||||||
|
"generic-array",
|
||||||
|
"typenum",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "deranged"
|
name = "deranged"
|
||||||
version = "0.5.8"
|
version = "0.5.8"
|
||||||
@@ -147,6 +175,16 @@ dependencies = [
|
|||||||
"powerfmt",
|
"powerfmt",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "digest"
|
||||||
|
version = "0.10.7"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292"
|
||||||
|
dependencies = [
|
||||||
|
"block-buffer",
|
||||||
|
"crypto-common",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "displaydoc"
|
name = "displaydoc"
|
||||||
version = "0.2.5"
|
version = "0.2.5"
|
||||||
@@ -238,6 +276,16 @@ dependencies = [
|
|||||||
"slab",
|
"slab",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "generic-array"
|
||||||
|
version = "0.14.7"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a"
|
||||||
|
dependencies = [
|
||||||
|
"typenum",
|
||||||
|
"version_check",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "getopts"
|
name = "getopts"
|
||||||
version = "0.2.24"
|
version = "0.2.24"
|
||||||
@@ -950,6 +998,7 @@ dependencies = [
|
|||||||
"regex",
|
"regex",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
"sha2",
|
||||||
"tokio",
|
"tokio",
|
||||||
"walkdir",
|
"walkdir",
|
||||||
]
|
]
|
||||||
@@ -1106,6 +1155,17 @@ dependencies = [
|
|||||||
"serde",
|
"serde",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "sha2"
|
||||||
|
version = "0.10.9"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if",
|
||||||
|
"cpufeatures",
|
||||||
|
"digest",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "shlex"
|
name = "shlex"
|
||||||
version = "1.3.0"
|
version = "1.3.0"
|
||||||
@@ -1427,6 +1487,12 @@ version = "0.2.5"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
|
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "typenum"
|
||||||
|
version = "1.19.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "unicase"
|
name = "unicase"
|
||||||
version = "2.9.0"
|
version = "2.9.0"
|
||||||
@@ -1469,6 +1535,12 @@ version = "1.0.4"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be"
|
checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "version_check"
|
||||||
|
version = "0.9.5"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "walkdir"
|
name = "walkdir"
|
||||||
version = "2.5.0"
|
version = "2.5.0"
|
||||||
|
|||||||
@@ -15,11 +15,90 @@ const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(200);
|
|||||||
const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(2);
|
const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(2);
|
||||||
const DEFAULT_MAX_RETRIES: u32 = 2;
|
const DEFAULT_MAX_RETRIES: u32 = 2;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub enum AuthSource {
|
||||||
|
None,
|
||||||
|
ApiKey(String),
|
||||||
|
BearerToken(String),
|
||||||
|
ApiKeyAndBearer {
|
||||||
|
api_key: String,
|
||||||
|
bearer_token: String,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AuthSource {
|
||||||
|
pub fn from_env() -> Result<Self, ApiError> {
|
||||||
|
let api_key = read_env_non_empty("ANTHROPIC_API_KEY")?;
|
||||||
|
let auth_token = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")?;
|
||||||
|
match (api_key, auth_token) {
|
||||||
|
(Some(api_key), Some(bearer_token)) => Ok(Self::ApiKeyAndBearer {
|
||||||
|
api_key,
|
||||||
|
bearer_token,
|
||||||
|
}),
|
||||||
|
(Some(api_key), None) => Ok(Self::ApiKey(api_key)),
|
||||||
|
(None, Some(bearer_token)) => Ok(Self::BearerToken(bearer_token)),
|
||||||
|
(None, None) => Err(ApiError::MissingApiKey),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn api_key(&self) -> Option<&str> {
|
||||||
|
match self {
|
||||||
|
Self::ApiKey(api_key) | Self::ApiKeyAndBearer { api_key, .. } => Some(api_key),
|
||||||
|
Self::None | Self::BearerToken(_) => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn bearer_token(&self) -> Option<&str> {
|
||||||
|
match self {
|
||||||
|
Self::BearerToken(token)
|
||||||
|
| Self::ApiKeyAndBearer {
|
||||||
|
bearer_token: token,
|
||||||
|
..
|
||||||
|
} => Some(token),
|
||||||
|
Self::None | Self::ApiKey(_) => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn masked_authorization_header(&self) -> &'static str {
|
||||||
|
if self.bearer_token().is_some() {
|
||||||
|
"Bearer [REDACTED]"
|
||||||
|
} else {
|
||||||
|
"<absent>"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn apply(&self, mut request_builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
|
||||||
|
if let Some(api_key) = self.api_key() {
|
||||||
|
request_builder = request_builder.header("x-api-key", api_key);
|
||||||
|
}
|
||||||
|
if let Some(token) = self.bearer_token() {
|
||||||
|
request_builder = request_builder.bearer_auth(token);
|
||||||
|
}
|
||||||
|
request_builder
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct OAuthTokenSet {
|
||||||
|
pub access_token: String,
|
||||||
|
pub refresh_token: Option<String>,
|
||||||
|
pub expires_at: Option<u64>,
|
||||||
|
pub scopes: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<OAuthTokenSet> for AuthSource {
|
||||||
|
fn from(value: OAuthTokenSet) -> Self {
|
||||||
|
Self::BearerToken(value.access_token)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct AnthropicClient {
|
pub struct AnthropicClient {
|
||||||
http: reqwest::Client,
|
http: reqwest::Client,
|
||||||
api_key: String,
|
auth: AuthSource,
|
||||||
auth_token: Option<String>,
|
|
||||||
base_url: String,
|
base_url: String,
|
||||||
max_retries: u32,
|
max_retries: u32,
|
||||||
initial_backoff: Duration,
|
initial_backoff: Duration,
|
||||||
@@ -31,8 +110,19 @@ impl AnthropicClient {
|
|||||||
pub fn new(api_key: impl Into<String>) -> Self {
|
pub fn new(api_key: impl Into<String>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
http: reqwest::Client::new(),
|
http: reqwest::Client::new(),
|
||||||
api_key: api_key.into(),
|
auth: AuthSource::ApiKey(api_key.into()),
|
||||||
auth_token: None,
|
base_url: DEFAULT_BASE_URL.to_string(),
|
||||||
|
max_retries: DEFAULT_MAX_RETRIES,
|
||||||
|
initial_backoff: DEFAULT_INITIAL_BACKOFF,
|
||||||
|
max_backoff: DEFAULT_MAX_BACKOFF,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn from_auth(auth: AuthSource) -> Self {
|
||||||
|
Self {
|
||||||
|
http: reqwest::Client::new(),
|
||||||
|
auth,
|
||||||
base_url: DEFAULT_BASE_URL.to_string(),
|
base_url: DEFAULT_BASE_URL.to_string(),
|
||||||
max_retries: DEFAULT_MAX_RETRIES,
|
max_retries: DEFAULT_MAX_RETRIES,
|
||||||
initial_backoff: DEFAULT_INITIAL_BACKOFF,
|
initial_backoff: DEFAULT_INITIAL_BACKOFF,
|
||||||
@@ -41,14 +131,37 @@ impl AnthropicClient {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn from_env() -> Result<Self, ApiError> {
|
pub fn from_env() -> Result<Self, ApiError> {
|
||||||
Ok(Self::new(read_api_key()?)
|
Ok(Self::from_auth(AuthSource::from_env()?).with_base_url(read_base_url()))
|
||||||
.with_auth_token(read_auth_token())
|
}
|
||||||
.with_base_url(read_base_url()))
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn with_auth_source(mut self, auth: AuthSource) -> Self {
|
||||||
|
self.auth = auth;
|
||||||
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn with_auth_token(mut self, auth_token: Option<String>) -> Self {
|
pub fn with_auth_token(mut self, auth_token: Option<String>) -> Self {
|
||||||
self.auth_token = auth_token.filter(|token| !token.is_empty());
|
match (
|
||||||
|
self.auth.api_key().map(ToOwned::to_owned),
|
||||||
|
auth_token.filter(|token| !token.is_empty()),
|
||||||
|
) {
|
||||||
|
(Some(api_key), Some(bearer_token)) => {
|
||||||
|
self.auth = AuthSource::ApiKeyAndBearer {
|
||||||
|
api_key,
|
||||||
|
bearer_token,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
(Some(api_key), None) => {
|
||||||
|
self.auth = AuthSource::ApiKey(api_key);
|
||||||
|
}
|
||||||
|
(None, Some(bearer_token)) => {
|
||||||
|
self.auth = AuthSource::BearerToken(bearer_token);
|
||||||
|
}
|
||||||
|
(None, None) => {
|
||||||
|
self.auth = AuthSource::None;
|
||||||
|
}
|
||||||
|
}
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -71,6 +184,11 @@ impl AnthropicClient {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn auth_source(&self) -> &AuthSource {
|
||||||
|
&self.auth
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn send_message(
|
pub async fn send_message(
|
||||||
&self,
|
&self,
|
||||||
request: &MessageRequest,
|
request: &MessageRequest,
|
||||||
@@ -151,25 +269,25 @@ impl AnthropicClient {
|
|||||||
let resolved_base_url = self.base_url.trim_end_matches('/');
|
let resolved_base_url = self.base_url.trim_end_matches('/');
|
||||||
eprintln!("[anthropic-client] resolved_base_url={resolved_base_url}");
|
eprintln!("[anthropic-client] resolved_base_url={resolved_base_url}");
|
||||||
eprintln!("[anthropic-client] request_url={request_url}");
|
eprintln!("[anthropic-client] request_url={request_url}");
|
||||||
let mut request_builder = self
|
let request_builder = self
|
||||||
.http
|
.http
|
||||||
.post(&request_url)
|
.post(&request_url)
|
||||||
.header("x-api-key", &self.api_key)
|
|
||||||
.header("anthropic-version", ANTHROPIC_VERSION)
|
.header("anthropic-version", ANTHROPIC_VERSION)
|
||||||
.header("content-type", "application/json");
|
.header("content-type", "application/json");
|
||||||
|
let mut request_builder = self.auth.apply(request_builder);
|
||||||
|
|
||||||
let auth_header = self.auth_token.as_ref().map(|_| "Bearer [REDACTED]").unwrap_or("<absent>");
|
eprintln!(
|
||||||
eprintln!("[anthropic-client] headers x-api-key=[REDACTED] authorization={auth_header} anthropic-version={ANTHROPIC_VERSION} content-type=application/json");
|
"[anthropic-client] headers x-api-key={} authorization={} anthropic-version={ANTHROPIC_VERSION} content-type=application/json",
|
||||||
|
if self.auth.api_key().is_some() {
|
||||||
|
"[REDACTED]"
|
||||||
|
} else {
|
||||||
|
"<absent>"
|
||||||
|
},
|
||||||
|
self.auth.masked_authorization_header()
|
||||||
|
);
|
||||||
|
|
||||||
if let Some(auth_token) = &self.auth_token {
|
request_builder = request_builder.json(request);
|
||||||
request_builder = request_builder.bearer_auth(auth_token);
|
request_builder.send().await.map_err(ApiError::from)
|
||||||
}
|
|
||||||
|
|
||||||
request_builder
|
|
||||||
.json(request)
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.map_err(ApiError::from)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn backoff_for_attempt(&self, attempt: u32) -> Result<Duration, ApiError> {
|
fn backoff_for_attempt(&self, attempt: u32) -> Result<Duration, ApiError> {
|
||||||
@@ -186,25 +304,28 @@ impl AnthropicClient {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn read_api_key() -> Result<String, ApiError> {
|
fn read_env_non_empty(key: &str) -> Result<Option<String>, ApiError> {
|
||||||
match std::env::var("ANTHROPIC_API_KEY") {
|
match std::env::var(key) {
|
||||||
Ok(api_key) if !api_key.is_empty() => Ok(api_key),
|
Ok(value) if !value.is_empty() => Ok(Some(value)),
|
||||||
Ok(_) => Err(ApiError::MissingApiKey),
|
Ok(_) | Err(std::env::VarError::NotPresent) => Ok(None),
|
||||||
Err(std::env::VarError::NotPresent) => match std::env::var("ANTHROPIC_AUTH_TOKEN") {
|
|
||||||
Ok(api_key) if !api_key.is_empty() => Ok(api_key),
|
|
||||||
Ok(_) => Err(ApiError::MissingApiKey),
|
|
||||||
Err(std::env::VarError::NotPresent) => Err(ApiError::MissingApiKey),
|
|
||||||
Err(error) => Err(ApiError::from(error)),
|
|
||||||
},
|
|
||||||
Err(error) => Err(ApiError::from(error)),
|
Err(error) => Err(ApiError::from(error)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
fn read_api_key() -> Result<String, ApiError> {
|
||||||
|
let auth = AuthSource::from_env()?;
|
||||||
|
auth.api_key()
|
||||||
|
.or_else(|| auth.bearer_token())
|
||||||
|
.map(ToOwned::to_owned)
|
||||||
|
.ok_or(ApiError::MissingApiKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
fn read_auth_token() -> Option<String> {
|
fn read_auth_token() -> Option<String> {
|
||||||
match std::env::var("ANTHROPIC_AUTH_TOKEN") {
|
read_env_non_empty("ANTHROPIC_AUTH_TOKEN")
|
||||||
Ok(token) if !token.is_empty() => Some(token),
|
.ok()
|
||||||
_ => None,
|
.and_then(std::convert::identity)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn read_base_url() -> String {
|
fn read_base_url() -> String {
|
||||||
@@ -303,12 +424,22 @@ struct AnthropicErrorBody {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::{ALT_REQUEST_ID_HEADER, REQUEST_ID_HEADER};
|
use super::{ALT_REQUEST_ID_HEADER, REQUEST_ID_HEADER};
|
||||||
|
use std::sync::{Mutex, OnceLock};
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use crate::client::{AuthSource, OAuthTokenSet};
|
||||||
use crate::types::{ContentBlockDelta, MessageRequest};
|
use crate::types::{ContentBlockDelta, MessageRequest};
|
||||||
|
|
||||||
|
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
|
||||||
|
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||||
|
LOCK.get_or_init(|| Mutex::new(()))
|
||||||
|
.lock()
|
||||||
|
.expect("env lock")
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn read_api_key_requires_presence() {
|
fn read_api_key_requires_presence() {
|
||||||
|
let _guard = env_lock();
|
||||||
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
||||||
std::env::remove_var("ANTHROPIC_API_KEY");
|
std::env::remove_var("ANTHROPIC_API_KEY");
|
||||||
let error = super::read_api_key().expect_err("missing key should error");
|
let error = super::read_api_key().expect_err("missing key should error");
|
||||||
@@ -317,6 +448,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn read_api_key_requires_non_empty_value() {
|
fn read_api_key_requires_non_empty_value() {
|
||||||
|
let _guard = env_lock();
|
||||||
std::env::set_var("ANTHROPIC_AUTH_TOKEN", "");
|
std::env::set_var("ANTHROPIC_AUTH_TOKEN", "");
|
||||||
std::env::remove_var("ANTHROPIC_API_KEY");
|
std::env::remove_var("ANTHROPIC_API_KEY");
|
||||||
let error = super::read_api_key().expect_err("empty key should error");
|
let error = super::read_api_key().expect_err("empty key should error");
|
||||||
@@ -325,6 +457,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn read_api_key_prefers_api_key_env() {
|
fn read_api_key_prefers_api_key_env() {
|
||||||
|
let _guard = env_lock();
|
||||||
std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
|
std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
|
||||||
std::env::set_var("ANTHROPIC_API_KEY", "legacy-key");
|
std::env::set_var("ANTHROPIC_API_KEY", "legacy-key");
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@@ -337,11 +470,36 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn read_auth_token_reads_auth_token_env() {
|
fn read_auth_token_reads_auth_token_env() {
|
||||||
|
let _guard = env_lock();
|
||||||
std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
|
std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
|
||||||
assert_eq!(super::read_auth_token().as_deref(), Some("auth-token"));
|
assert_eq!(super::read_auth_token().as_deref(), Some("auth-token"));
|
||||||
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn oauth_token_maps_to_bearer_auth_source() {
|
||||||
|
let auth = AuthSource::from(OAuthTokenSet {
|
||||||
|
access_token: "access-token".to_string(),
|
||||||
|
refresh_token: Some("refresh".to_string()),
|
||||||
|
expires_at: Some(123),
|
||||||
|
scopes: vec!["scope:a".to_string()],
|
||||||
|
});
|
||||||
|
assert_eq!(auth.bearer_token(), Some("access-token"));
|
||||||
|
assert_eq!(auth.api_key(), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn auth_source_from_env_combines_api_key_and_bearer_token() {
|
||||||
|
let _guard = env_lock();
|
||||||
|
std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
|
||||||
|
std::env::set_var("ANTHROPIC_API_KEY", "legacy-key");
|
||||||
|
let auth = AuthSource::from_env().expect("env auth");
|
||||||
|
assert_eq!(auth.api_key(), Some("legacy-key"));
|
||||||
|
assert_eq!(auth.bearer_token(), Some("auth-token"));
|
||||||
|
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
||||||
|
std::env::remove_var("ANTHROPIC_API_KEY");
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn message_request_stream_helper_sets_stream_true() {
|
fn message_request_stream_helper_sets_stream_true() {
|
||||||
let request = MessageRequest {
|
let request = MessageRequest {
|
||||||
@@ -421,4 +579,25 @@ mod tests {
|
|||||||
Some("req_fallback")
|
Some("req_fallback")
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn auth_source_applies_headers() {
|
||||||
|
let auth = AuthSource::ApiKeyAndBearer {
|
||||||
|
api_key: "test-key".to_string(),
|
||||||
|
bearer_token: "proxy-token".to_string(),
|
||||||
|
};
|
||||||
|
let request = auth
|
||||||
|
.apply(reqwest::Client::new().post("https://example.test"))
|
||||||
|
.build()
|
||||||
|
.expect("request build");
|
||||||
|
let headers = request.headers();
|
||||||
|
assert_eq!(
|
||||||
|
headers.get("x-api-key").and_then(|v| v.to_str().ok()),
|
||||||
|
Some("test-key")
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
headers.get("authorization").and_then(|v| v.to_str().ok()),
|
||||||
|
Some("Bearer proxy-token")
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ mod error;
|
|||||||
mod sse;
|
mod sse;
|
||||||
mod types;
|
mod types;
|
||||||
|
|
||||||
pub use client::{AnthropicClient, MessageStream};
|
pub use client::{AnthropicClient, AuthSource, MessageStream, OAuthTokenSet};
|
||||||
pub use error::ApiError;
|
pub use error::ApiError;
|
||||||
pub use sse::{parse_frame, SseParser};
|
pub use sse::{parse_frame, SseParser};
|
||||||
pub use types::{
|
pub use types::{
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ license.workspace = true
|
|||||||
publish.workspace = true
|
publish.workspace = true
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
sha2 = "0.10"
|
||||||
glob = "0.3"
|
glob = "0.3"
|
||||||
regex = "1"
|
regex = "1"
|
||||||
serde = { version = "1", features = ["derive"] }
|
serde = { version = "1", features = ["derive"] }
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ mod config;
|
|||||||
mod conversation;
|
mod conversation;
|
||||||
mod file_ops;
|
mod file_ops;
|
||||||
mod json;
|
mod json;
|
||||||
|
mod oauth;
|
||||||
mod permissions;
|
mod permissions;
|
||||||
mod prompt;
|
mod prompt;
|
||||||
mod session;
|
mod session;
|
||||||
@@ -31,6 +32,11 @@ pub use file_ops::{
|
|||||||
GrepSearchInput, GrepSearchOutput, ReadFileOutput, StructuredPatchHunk, TextFilePayload,
|
GrepSearchInput, GrepSearchOutput, ReadFileOutput, StructuredPatchHunk, TextFilePayload,
|
||||||
WriteFileOutput,
|
WriteFileOutput,
|
||||||
};
|
};
|
||||||
|
pub use oauth::{
|
||||||
|
code_challenge_s256, generate_pkce_pair, generate_state, loopback_redirect_uri,
|
||||||
|
OAuthAuthorizationRequest, OAuthRefreshRequest, OAuthTokenExchangeRequest, OAuthTokenSet,
|
||||||
|
PkceChallengeMethod, PkceCodePair,
|
||||||
|
};
|
||||||
pub use permissions::{
|
pub use permissions::{
|
||||||
PermissionMode, PermissionOutcome, PermissionPolicy, PermissionPromptDecision,
|
PermissionMode, PermissionOutcome, PermissionPolicy, PermissionPromptDecision,
|
||||||
PermissionPrompter, PermissionRequest,
|
PermissionPrompter, PermissionRequest,
|
||||||
|
|||||||
338
rust/crates/runtime/src/oauth.rs
Normal file
338
rust/crates/runtime/src/oauth.rs
Normal file
@@ -0,0 +1,338 @@
|
|||||||
|
use std::collections::BTreeMap;
|
||||||
|
use std::fs::File;
|
||||||
|
use std::io::{self, Read};
|
||||||
|
|
||||||
|
use sha2::{Digest, Sha256};
|
||||||
|
|
||||||
|
use crate::config::OAuthConfig;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct OAuthTokenSet {
|
||||||
|
pub access_token: String,
|
||||||
|
pub refresh_token: Option<String>,
|
||||||
|
pub expires_at: Option<u64>,
|
||||||
|
pub scopes: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct PkceCodePair {
|
||||||
|
pub verifier: String,
|
||||||
|
pub challenge: String,
|
||||||
|
pub challenge_method: PkceChallengeMethod,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub enum PkceChallengeMethod {
|
||||||
|
S256,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PkceChallengeMethod {
|
||||||
|
#[must_use]
|
||||||
|
pub const fn as_str(self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
Self::S256 => "S256",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct OAuthAuthorizationRequest {
|
||||||
|
pub authorize_url: String,
|
||||||
|
pub client_id: String,
|
||||||
|
pub redirect_uri: String,
|
||||||
|
pub scopes: Vec<String>,
|
||||||
|
pub state: String,
|
||||||
|
pub code_challenge: String,
|
||||||
|
pub code_challenge_method: PkceChallengeMethod,
|
||||||
|
pub extra_params: BTreeMap<String, String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct OAuthTokenExchangeRequest {
|
||||||
|
pub grant_type: &'static str,
|
||||||
|
pub code: String,
|
||||||
|
pub redirect_uri: String,
|
||||||
|
pub client_id: String,
|
||||||
|
pub code_verifier: String,
|
||||||
|
pub state: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct OAuthRefreshRequest {
|
||||||
|
pub grant_type: &'static str,
|
||||||
|
pub refresh_token: String,
|
||||||
|
pub client_id: String,
|
||||||
|
pub scopes: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OAuthAuthorizationRequest {
|
||||||
|
#[must_use]
|
||||||
|
pub fn from_config(
|
||||||
|
config: &OAuthConfig,
|
||||||
|
redirect_uri: impl Into<String>,
|
||||||
|
state: impl Into<String>,
|
||||||
|
pkce: &PkceCodePair,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
authorize_url: config.authorize_url.clone(),
|
||||||
|
client_id: config.client_id.clone(),
|
||||||
|
redirect_uri: redirect_uri.into(),
|
||||||
|
scopes: config.scopes.clone(),
|
||||||
|
state: state.into(),
|
||||||
|
code_challenge: pkce.challenge.clone(),
|
||||||
|
code_challenge_method: pkce.challenge_method,
|
||||||
|
extra_params: BTreeMap::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn with_extra_param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
|
||||||
|
self.extra_params.insert(key.into(), value.into());
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn build_url(&self) -> String {
|
||||||
|
let mut params = vec![
|
||||||
|
("response_type", "code".to_string()),
|
||||||
|
("client_id", self.client_id.clone()),
|
||||||
|
("redirect_uri", self.redirect_uri.clone()),
|
||||||
|
("scope", self.scopes.join(" ")),
|
||||||
|
("state", self.state.clone()),
|
||||||
|
("code_challenge", self.code_challenge.clone()),
|
||||||
|
(
|
||||||
|
"code_challenge_method",
|
||||||
|
self.code_challenge_method.as_str().to_string(),
|
||||||
|
),
|
||||||
|
];
|
||||||
|
params.extend(
|
||||||
|
self.extra_params
|
||||||
|
.iter()
|
||||||
|
.map(|(key, value)| (key.as_str(), value.clone())),
|
||||||
|
);
|
||||||
|
let query = params
|
||||||
|
.into_iter()
|
||||||
|
.map(|(key, value)| format!("{}={}", percent_encode(key), percent_encode(&value)))
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join("&");
|
||||||
|
format!(
|
||||||
|
"{}{}{}",
|
||||||
|
self.authorize_url,
|
||||||
|
if self.authorize_url.contains('?') {
|
||||||
|
'&'
|
||||||
|
} else {
|
||||||
|
'?'
|
||||||
|
},
|
||||||
|
query
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OAuthTokenExchangeRequest {
|
||||||
|
#[must_use]
|
||||||
|
pub fn from_config(
|
||||||
|
config: &OAuthConfig,
|
||||||
|
code: impl Into<String>,
|
||||||
|
state: impl Into<String>,
|
||||||
|
verifier: impl Into<String>,
|
||||||
|
redirect_uri: impl Into<String>,
|
||||||
|
) -> Self {
|
||||||
|
let _ = config;
|
||||||
|
Self {
|
||||||
|
grant_type: "authorization_code",
|
||||||
|
code: code.into(),
|
||||||
|
redirect_uri: redirect_uri.into(),
|
||||||
|
client_id: config.client_id.clone(),
|
||||||
|
code_verifier: verifier.into(),
|
||||||
|
state: state.into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn form_params(&self) -> BTreeMap<&str, String> {
|
||||||
|
BTreeMap::from([
|
||||||
|
("grant_type", self.grant_type.to_string()),
|
||||||
|
("code", self.code.clone()),
|
||||||
|
("redirect_uri", self.redirect_uri.clone()),
|
||||||
|
("client_id", self.client_id.clone()),
|
||||||
|
("code_verifier", self.code_verifier.clone()),
|
||||||
|
("state", self.state.clone()),
|
||||||
|
])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OAuthRefreshRequest {
|
||||||
|
#[must_use]
|
||||||
|
pub fn from_config(
|
||||||
|
config: &OAuthConfig,
|
||||||
|
refresh_token: impl Into<String>,
|
||||||
|
scopes: Option<Vec<String>>,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
grant_type: "refresh_token",
|
||||||
|
refresh_token: refresh_token.into(),
|
||||||
|
client_id: config.client_id.clone(),
|
||||||
|
scopes: scopes.unwrap_or_else(|| config.scopes.clone()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn form_params(&self) -> BTreeMap<&str, String> {
|
||||||
|
BTreeMap::from([
|
||||||
|
("grant_type", self.grant_type.to_string()),
|
||||||
|
("refresh_token", self.refresh_token.clone()),
|
||||||
|
("client_id", self.client_id.clone()),
|
||||||
|
("scope", self.scopes.join(" ")),
|
||||||
|
])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn generate_pkce_pair() -> io::Result<PkceCodePair> {
|
||||||
|
let verifier = generate_random_token(32)?;
|
||||||
|
Ok(PkceCodePair {
|
||||||
|
challenge: code_challenge_s256(&verifier),
|
||||||
|
verifier,
|
||||||
|
challenge_method: PkceChallengeMethod::S256,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn generate_state() -> io::Result<String> {
|
||||||
|
generate_random_token(32)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn code_challenge_s256(verifier: &str) -> String {
|
||||||
|
let digest = Sha256::digest(verifier.as_bytes());
|
||||||
|
base64url_encode(&digest)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn loopback_redirect_uri(port: u16) -> String {
|
||||||
|
format!("http://localhost:{port}/callback")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn generate_random_token(bytes: usize) -> io::Result<String> {
|
||||||
|
let mut buffer = vec![0_u8; bytes];
|
||||||
|
File::open("/dev/urandom")?.read_exact(&mut buffer)?;
|
||||||
|
Ok(base64url_encode(&buffer))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn base64url_encode(bytes: &[u8]) -> String {
|
||||||
|
const TABLE: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
|
||||||
|
let mut output = String::new();
|
||||||
|
let mut index = 0;
|
||||||
|
while index + 3 <= bytes.len() {
|
||||||
|
let block = (u32::from(bytes[index]) << 16)
|
||||||
|
| (u32::from(bytes[index + 1]) << 8)
|
||||||
|
| u32::from(bytes[index + 2]);
|
||||||
|
output.push(TABLE[((block >> 18) & 0x3F) as usize] as char);
|
||||||
|
output.push(TABLE[((block >> 12) & 0x3F) as usize] as char);
|
||||||
|
output.push(TABLE[((block >> 6) & 0x3F) as usize] as char);
|
||||||
|
output.push(TABLE[(block & 0x3F) as usize] as char);
|
||||||
|
index += 3;
|
||||||
|
}
|
||||||
|
match bytes.len().saturating_sub(index) {
|
||||||
|
1 => {
|
||||||
|
let block = u32::from(bytes[index]) << 16;
|
||||||
|
output.push(TABLE[((block >> 18) & 0x3F) as usize] as char);
|
||||||
|
output.push(TABLE[((block >> 12) & 0x3F) as usize] as char);
|
||||||
|
}
|
||||||
|
2 => {
|
||||||
|
let block = (u32::from(bytes[index]) << 16) | (u32::from(bytes[index + 1]) << 8);
|
||||||
|
output.push(TABLE[((block >> 18) & 0x3F) as usize] as char);
|
||||||
|
output.push(TABLE[((block >> 12) & 0x3F) as usize] as char);
|
||||||
|
output.push(TABLE[((block >> 6) & 0x3F) as usize] as char);
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
output
|
||||||
|
}
|
||||||
|
|
||||||
|
fn percent_encode(value: &str) -> String {
|
||||||
|
let mut encoded = String::new();
|
||||||
|
for byte in value.bytes() {
|
||||||
|
match byte {
|
||||||
|
b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' => {
|
||||||
|
encoded.push(char::from(byte));
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
use std::fmt::Write as _;
|
||||||
|
let _ = write!(&mut encoded, "%{byte:02X}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
encoded
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::{
|
||||||
|
code_challenge_s256, generate_pkce_pair, generate_state, loopback_redirect_uri,
|
||||||
|
OAuthAuthorizationRequest, OAuthConfig, OAuthRefreshRequest, OAuthTokenExchangeRequest,
|
||||||
|
};
|
||||||
|
|
||||||
|
fn sample_config() -> OAuthConfig {
|
||||||
|
OAuthConfig {
|
||||||
|
client_id: "runtime-client".to_string(),
|
||||||
|
authorize_url: "https://console.test/oauth/authorize".to_string(),
|
||||||
|
token_url: "https://console.test/oauth/token".to_string(),
|
||||||
|
callback_port: Some(4545),
|
||||||
|
manual_redirect_url: Some("https://console.test/oauth/callback".to_string()),
|
||||||
|
scopes: vec!["org:read".to_string(), "user:write".to_string()],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn s256_challenge_matches_expected_vector() {
|
||||||
|
assert_eq!(
|
||||||
|
code_challenge_s256("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"),
|
||||||
|
"E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn generates_pkce_pair_and_state() {
|
||||||
|
let pair = generate_pkce_pair().expect("pkce pair");
|
||||||
|
let state = generate_state().expect("state");
|
||||||
|
assert!(!pair.verifier.is_empty());
|
||||||
|
assert!(!pair.challenge.is_empty());
|
||||||
|
assert!(!state.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn builds_authorize_url_and_form_requests() {
|
||||||
|
let config = sample_config();
|
||||||
|
let pair = generate_pkce_pair().expect("pkce");
|
||||||
|
let url = OAuthAuthorizationRequest::from_config(
|
||||||
|
&config,
|
||||||
|
loopback_redirect_uri(4545),
|
||||||
|
"state-123",
|
||||||
|
&pair,
|
||||||
|
)
|
||||||
|
.with_extra_param("login_hint", "user@example.com")
|
||||||
|
.build_url();
|
||||||
|
assert!(url.starts_with("https://console.test/oauth/authorize?"));
|
||||||
|
assert!(url.contains("response_type=code"));
|
||||||
|
assert!(url.contains("client_id=runtime-client"));
|
||||||
|
assert!(url.contains("scope=org%3Aread%20user%3Awrite"));
|
||||||
|
assert!(url.contains("login_hint=user%40example.com"));
|
||||||
|
|
||||||
|
let exchange = OAuthTokenExchangeRequest::from_config(
|
||||||
|
&config,
|
||||||
|
"auth-code",
|
||||||
|
"state-123",
|
||||||
|
pair.verifier,
|
||||||
|
loopback_redirect_uri(4545),
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
exchange.form_params().get("grant_type").map(String::as_str),
|
||||||
|
Some("authorization_code")
|
||||||
|
);
|
||||||
|
|
||||||
|
let refresh = OAuthRefreshRequest::from_config(&config, "refresh-token", None);
|
||||||
|
assert_eq!(
|
||||||
|
refresh.form_params().get("scope").map(String::as_str),
|
||||||
|
Some("org:read user:write")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user