wip: grok provider abstraction
This commit is contained in:
@@ -185,7 +185,7 @@ impl OpenAiCompatClient {
|
|||||||
&self,
|
&self,
|
||||||
request: &MessageRequest,
|
request: &MessageRequest,
|
||||||
) -> Result<reqwest::Response, ApiError> {
|
) -> Result<reqwest::Response, ApiError> {
|
||||||
let request_url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
|
let request_url = chat_completions_endpoint(&self.base_url);
|
||||||
self.http
|
self.http
|
||||||
.post(&request_url)
|
.post(&request_url)
|
||||||
.header("content-type", "application/json")
|
.header("content-type", "application/json")
|
||||||
@@ -866,6 +866,15 @@ pub fn read_base_url(config: OpenAiCompatConfig) -> String {
|
|||||||
std::env::var(config.base_url_env).unwrap_or_else(|_| config.default_base_url.to_string())
|
std::env::var(config.base_url_env).unwrap_or_else(|_| config.default_base_url.to_string())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn chat_completions_endpoint(base_url: &str) -> String {
|
||||||
|
let trimmed = base_url.trim_end_matches('/');
|
||||||
|
if trimmed.ends_with("/chat/completions") {
|
||||||
|
trimmed.to_string()
|
||||||
|
} else {
|
||||||
|
format!("{trimmed}/chat/completions")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option<String> {
|
fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option<String> {
|
||||||
headers
|
headers
|
||||||
.get(REQUEST_ID_HEADER)
|
.get(REQUEST_ID_HEADER)
|
||||||
@@ -927,8 +936,8 @@ impl StringExt for String {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::{
|
use super::{
|
||||||
build_chat_completion_request, normalize_finish_reason, openai_tool_choice,
|
build_chat_completion_request, chat_completions_endpoint, normalize_finish_reason,
|
||||||
parse_tool_arguments, OpenAiCompatClient, OpenAiCompatConfig,
|
openai_tool_choice, parse_tool_arguments, OpenAiCompatClient, OpenAiCompatConfig,
|
||||||
};
|
};
|
||||||
use crate::error::ApiError;
|
use crate::error::ApiError;
|
||||||
use crate::types::{
|
use crate::types::{
|
||||||
@@ -1010,6 +1019,22 @@ mod tests {
|
|||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn endpoint_builder_accepts_base_urls_and_full_endpoints() {
|
||||||
|
assert_eq!(
|
||||||
|
chat_completions_endpoint("https://api.x.ai/v1"),
|
||||||
|
"https://api.x.ai/v1/chat/completions"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
chat_completions_endpoint("https://api.x.ai/v1/"),
|
||||||
|
"https://api.x.ai/v1/chat/completions"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
chat_completions_endpoint("https://api.x.ai/v1/chat/completions"),
|
||||||
|
"https://api.x.ai/v1/chat/completions"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
|
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
|
||||||
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||||
LOCK.get_or_init(|| Mutex::new(()))
|
LOCK.get_or_init(|| Mutex::new(()))
|
||||||
|
|||||||
@@ -62,6 +62,41 @@ async fn send_message_uses_openai_compatible_endpoint_and_auth() {
|
|||||||
assert_eq!(body["tools"][0]["type"], json!("function"));
|
assert_eq!(body["tools"][0]["type"], json!("function"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn send_message_accepts_full_chat_completions_endpoint_override() {
|
||||||
|
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
|
||||||
|
let body = concat!(
|
||||||
|
"{",
|
||||||
|
"\"id\":\"chatcmpl_full_endpoint\",",
|
||||||
|
"\"model\":\"grok-3\",",
|
||||||
|
"\"choices\":[{",
|
||||||
|
"\"message\":{\"role\":\"assistant\",\"content\":\"Endpoint override works\",\"tool_calls\":[]},",
|
||||||
|
"\"finish_reason\":\"stop\"",
|
||||||
|
"}],",
|
||||||
|
"\"usage\":{\"prompt_tokens\":7,\"completion_tokens\":3}",
|
||||||
|
"}"
|
||||||
|
);
|
||||||
|
let server = spawn_server(
|
||||||
|
state.clone(),
|
||||||
|
vec![http_response("200 OK", "application/json", body)],
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let endpoint_url = format!("{}/chat/completions", server.base_url());
|
||||||
|
let client = OpenAiCompatClient::new("xai-test-key", OpenAiCompatConfig::xai())
|
||||||
|
.with_base_url(endpoint_url);
|
||||||
|
let response = client
|
||||||
|
.send_message(&sample_request(false))
|
||||||
|
.await
|
||||||
|
.expect("request should succeed");
|
||||||
|
|
||||||
|
assert_eq!(response.total_tokens(), 10);
|
||||||
|
|
||||||
|
let captured = state.lock().await;
|
||||||
|
let request = captured.first().expect("server should capture request");
|
||||||
|
assert_eq!(request.path, "/chat/completions");
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn stream_message_normalizes_text_and_multiple_tool_calls() {
|
async fn stream_message_normalizes_text_and_multiple_tool_calls() {
|
||||||
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
|
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
|
||||||
|
|||||||
@@ -1907,13 +1907,14 @@ fn build_runtime(
|
|||||||
permission_mode: PermissionMode,
|
permission_mode: PermissionMode,
|
||||||
) -> Result<ConversationRuntime<ProviderRuntimeClient, CliToolExecutor>, Box<dyn std::error::Error>>
|
) -> Result<ConversationRuntime<ProviderRuntimeClient, CliToolExecutor>, Box<dyn std::error::Error>>
|
||||||
{
|
{
|
||||||
|
let feature_config = build_runtime_feature_config()?;
|
||||||
Ok(ConversationRuntime::new_with_features(
|
Ok(ConversationRuntime::new_with_features(
|
||||||
session,
|
session,
|
||||||
ProviderRuntimeClient::new(model, enable_tools, emit_output, allowed_tools.clone())?,
|
ProviderRuntimeClient::new(model, enable_tools, emit_output, allowed_tools.clone())?,
|
||||||
CliToolExecutor::new(allowed_tools, emit_output),
|
CliToolExecutor::new(allowed_tools, emit_output),
|
||||||
permission_policy(permission_mode),
|
permission_policy(permission_mode),
|
||||||
system_prompt,
|
system_prompt,
|
||||||
build_runtime_feature_config()?,
|
&feature_config,
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user