//! Conversation session data model and its JSON serialization/deserialization.
use std::collections::BTreeMap;
|
|
use std::fmt::{Display, Formatter};
|
|
use std::fs;
|
|
use std::path::Path;
|
|
|
|
use crate::json::{JsonError, JsonValue};
|
|
use crate::usage::TokenUsage;
|
|
|
|
/// Identifies who authored a [`ConversationMessage`].
///
/// Serialized by `ConversationMessage::to_json` as the lowercase strings
/// `"system"`, `"user"`, `"assistant"` and `"tool"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageRole {
    /// Serialized as `"system"`.
    System,
    /// Serialized as `"user"`.
    User,
    /// Serialized as `"assistant"`.
    Assistant,
    /// Serialized as `"tool"`; used by messages that carry tool results.
    Tool,
}
|
|
|
|
/// One piece of content inside a conversation message.
///
/// Every variant stores its payload as plain strings (plus one flag for tool
/// results); the JSON form tags each variant with a `"type"` field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ContentBlock {
    /// Plain text content; serialized with `"type": "text"`.
    Text {
        /// The text itself.
        text: String,
    },
    /// A tool invocation; serialized with `"type": "tool_use"`.
    ToolUse {
        /// Identifier of this invocation.
        id: String,
        /// Name of the tool being invoked.
        name: String,
        /// Tool input, kept and serialized as an opaque string.
        input: String,
    },
    /// The outcome of a tool invocation; serialized with `"type": "tool_result"`.
    ToolResult {
        /// Identifier of the invocation this result belongs to.
        tool_use_id: String,
        /// Name of the tool that produced the result.
        tool_name: String,
        /// Tool output, kept and serialized as an opaque string.
        output: String,
        /// Whether the result represents an error.
        is_error: bool,
    },
}
|
|
|
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
pub struct ConversationMessage {
|
|
pub role: MessageRole,
|
|
pub blocks: Vec<ContentBlock>,
|
|
pub usage: Option<TokenUsage>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
pub struct Session {
|
|
pub version: u32,
|
|
pub messages: Vec<ConversationMessage>,
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
pub enum SessionError {
|
|
Io(std::io::Error),
|
|
Json(JsonError),
|
|
Format(String),
|
|
}
|
|
|
|
impl Display for SessionError {
|
|
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
|
match self {
|
|
Self::Io(error) => write!(f, "{error}"),
|
|
Self::Json(error) => write!(f, "{error}"),
|
|
Self::Format(error) => write!(f, "{error}"),
|
|
}
|
|
}
|
|
}
|
|
|
|
// Relies entirely on the trait's default methods, so `source()` returns `None`;
// the wrapped error is surfaced only through the `Display` output.
impl std::error::Error for SessionError {}
|
|
|
|
impl From<std::io::Error> for SessionError {
|
|
fn from(value: std::io::Error) -> Self {
|
|
Self::Io(value)
|
|
}
|
|
}
|
|
|
|
impl From<JsonError> for SessionError {
|
|
fn from(value: JsonError) -> Self {
|
|
Self::Json(value)
|
|
}
|
|
}
|
|
|
|
impl Session {
|
|
#[must_use]
|
|
pub fn new() -> Self {
|
|
Self {
|
|
version: 1,
|
|
messages: Vec::new(),
|
|
}
|
|
}
|
|
|
|
pub fn save_to_path(&self, path: impl AsRef<Path>) -> Result<(), SessionError> {
|
|
fs::write(path, self.to_json().render())?;
|
|
Ok(())
|
|
}
|
|
|
|
pub fn load_from_path(path: impl AsRef<Path>) -> Result<Self, SessionError> {
|
|
let contents = fs::read_to_string(path)?;
|
|
Self::from_json(&JsonValue::parse(&contents)?)
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn to_json(&self) -> JsonValue {
|
|
let mut object = BTreeMap::new();
|
|
object.insert(
|
|
"version".to_string(),
|
|
JsonValue::Number(i64::from(self.version)),
|
|
);
|
|
object.insert(
|
|
"messages".to_string(),
|
|
JsonValue::Array(
|
|
self.messages
|
|
.iter()
|
|
.map(ConversationMessage::to_json)
|
|
.collect(),
|
|
),
|
|
);
|
|
JsonValue::Object(object)
|
|
}
|
|
|
|
pub fn from_json(value: &JsonValue) -> Result<Self, SessionError> {
|
|
let object = value
|
|
.as_object()
|
|
.ok_or_else(|| SessionError::Format("session must be an object".to_string()))?;
|
|
let version = object
|
|
.get("version")
|
|
.and_then(JsonValue::as_i64)
|
|
.ok_or_else(|| SessionError::Format("missing version".to_string()))?;
|
|
let version = u32::try_from(version)
|
|
.map_err(|_| SessionError::Format("version out of range".to_string()))?;
|
|
let messages = object
|
|
.get("messages")
|
|
.and_then(JsonValue::as_array)
|
|
.ok_or_else(|| SessionError::Format("missing messages".to_string()))?
|
|
.iter()
|
|
.map(ConversationMessage::from_json)
|
|
.collect::<Result<Vec<_>, _>>()?;
|
|
Ok(Self { version, messages })
|
|
}
|
|
}
|
|
|
|
impl Default for Session {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
impl ConversationMessage {
|
|
#[must_use]
|
|
pub fn user_text(text: impl Into<String>) -> Self {
|
|
Self {
|
|
role: MessageRole::User,
|
|
blocks: vec![ContentBlock::Text { text: text.into() }],
|
|
usage: None,
|
|
}
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn assistant(blocks: Vec<ContentBlock>) -> Self {
|
|
Self {
|
|
role: MessageRole::Assistant,
|
|
blocks,
|
|
usage: None,
|
|
}
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn assistant_with_usage(blocks: Vec<ContentBlock>, usage: Option<TokenUsage>) -> Self {
|
|
Self {
|
|
role: MessageRole::Assistant,
|
|
blocks,
|
|
usage,
|
|
}
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn tool_result(
|
|
tool_use_id: impl Into<String>,
|
|
tool_name: impl Into<String>,
|
|
output: impl Into<String>,
|
|
is_error: bool,
|
|
) -> Self {
|
|
Self {
|
|
role: MessageRole::Tool,
|
|
blocks: vec![ContentBlock::ToolResult {
|
|
tool_use_id: tool_use_id.into(),
|
|
tool_name: tool_name.into(),
|
|
output: output.into(),
|
|
is_error,
|
|
}],
|
|
usage: None,
|
|
}
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn to_json(&self) -> JsonValue {
|
|
let mut object = BTreeMap::new();
|
|
object.insert(
|
|
"role".to_string(),
|
|
JsonValue::String(
|
|
match self.role {
|
|
MessageRole::System => "system",
|
|
MessageRole::User => "user",
|
|
MessageRole::Assistant => "assistant",
|
|
MessageRole::Tool => "tool",
|
|
}
|
|
.to_string(),
|
|
),
|
|
);
|
|
object.insert(
|
|
"blocks".to_string(),
|
|
JsonValue::Array(self.blocks.iter().map(ContentBlock::to_json).collect()),
|
|
);
|
|
if let Some(usage) = self.usage {
|
|
object.insert("usage".to_string(), usage_to_json(usage));
|
|
}
|
|
JsonValue::Object(object)
|
|
}
|
|
|
|
fn from_json(value: &JsonValue) -> Result<Self, SessionError> {
|
|
let object = value
|
|
.as_object()
|
|
.ok_or_else(|| SessionError::Format("message must be an object".to_string()))?;
|
|
let role = match object
|
|
.get("role")
|
|
.and_then(JsonValue::as_str)
|
|
.ok_or_else(|| SessionError::Format("missing role".to_string()))?
|
|
{
|
|
"system" => MessageRole::System,
|
|
"user" => MessageRole::User,
|
|
"assistant" => MessageRole::Assistant,
|
|
"tool" => MessageRole::Tool,
|
|
other => {
|
|
return Err(SessionError::Format(format!(
|
|
"unsupported message role: {other}"
|
|
)))
|
|
}
|
|
};
|
|
let blocks = object
|
|
.get("blocks")
|
|
.and_then(JsonValue::as_array)
|
|
.ok_or_else(|| SessionError::Format("missing blocks".to_string()))?
|
|
.iter()
|
|
.map(ContentBlock::from_json)
|
|
.collect::<Result<Vec<_>, _>>()?;
|
|
let usage = object.get("usage").map(usage_from_json).transpose()?;
|
|
Ok(Self {
|
|
role,
|
|
blocks,
|
|
usage,
|
|
})
|
|
}
|
|
}
|
|
|
|
impl ContentBlock {
|
|
#[must_use]
|
|
pub fn to_json(&self) -> JsonValue {
|
|
let mut object = BTreeMap::new();
|
|
match self {
|
|
Self::Text { text } => {
|
|
object.insert("type".to_string(), JsonValue::String("text".to_string()));
|
|
object.insert("text".to_string(), JsonValue::String(text.clone()));
|
|
}
|
|
Self::ToolUse { id, name, input } => {
|
|
object.insert(
|
|
"type".to_string(),
|
|
JsonValue::String("tool_use".to_string()),
|
|
);
|
|
object.insert("id".to_string(), JsonValue::String(id.clone()));
|
|
object.insert("name".to_string(), JsonValue::String(name.clone()));
|
|
object.insert("input".to_string(), JsonValue::String(input.clone()));
|
|
}
|
|
Self::ToolResult {
|
|
tool_use_id,
|
|
tool_name,
|
|
output,
|
|
is_error,
|
|
} => {
|
|
object.insert(
|
|
"type".to_string(),
|
|
JsonValue::String("tool_result".to_string()),
|
|
);
|
|
object.insert(
|
|
"tool_use_id".to_string(),
|
|
JsonValue::String(tool_use_id.clone()),
|
|
);
|
|
object.insert(
|
|
"tool_name".to_string(),
|
|
JsonValue::String(tool_name.clone()),
|
|
);
|
|
object.insert("output".to_string(), JsonValue::String(output.clone()));
|
|
object.insert("is_error".to_string(), JsonValue::Bool(*is_error));
|
|
}
|
|
}
|
|
JsonValue::Object(object)
|
|
}
|
|
|
|
fn from_json(value: &JsonValue) -> Result<Self, SessionError> {
|
|
let object = value
|
|
.as_object()
|
|
.ok_or_else(|| SessionError::Format("block must be an object".to_string()))?;
|
|
match object
|
|
.get("type")
|
|
.and_then(JsonValue::as_str)
|
|
.ok_or_else(|| SessionError::Format("missing block type".to_string()))?
|
|
{
|
|
"text" => Ok(Self::Text {
|
|
text: required_string(object, "text")?,
|
|
}),
|
|
"tool_use" => Ok(Self::ToolUse {
|
|
id: required_string(object, "id")?,
|
|
name: required_string(object, "name")?,
|
|
input: required_string(object, "input")?,
|
|
}),
|
|
"tool_result" => Ok(Self::ToolResult {
|
|
tool_use_id: required_string(object, "tool_use_id")?,
|
|
tool_name: required_string(object, "tool_name")?,
|
|
output: required_string(object, "output")?,
|
|
is_error: object
|
|
.get("is_error")
|
|
.and_then(JsonValue::as_bool)
|
|
.ok_or_else(|| SessionError::Format("missing is_error".to_string()))?,
|
|
}),
|
|
other => Err(SessionError::Format(format!(
|
|
"unsupported block type: {other}"
|
|
))),
|
|
}
|
|
}
|
|
}
|
|
|
|
fn usage_to_json(usage: TokenUsage) -> JsonValue {
|
|
let mut object = BTreeMap::new();
|
|
object.insert(
|
|
"input_tokens".to_string(),
|
|
JsonValue::Number(i64::from(usage.input_tokens)),
|
|
);
|
|
object.insert(
|
|
"output_tokens".to_string(),
|
|
JsonValue::Number(i64::from(usage.output_tokens)),
|
|
);
|
|
object.insert(
|
|
"cache_creation_input_tokens".to_string(),
|
|
JsonValue::Number(i64::from(usage.cache_creation_input_tokens)),
|
|
);
|
|
object.insert(
|
|
"cache_read_input_tokens".to_string(),
|
|
JsonValue::Number(i64::from(usage.cache_read_input_tokens)),
|
|
);
|
|
JsonValue::Object(object)
|
|
}
|
|
|
|
fn usage_from_json(value: &JsonValue) -> Result<TokenUsage, SessionError> {
|
|
let object = value
|
|
.as_object()
|
|
.ok_or_else(|| SessionError::Format("usage must be an object".to_string()))?;
|
|
Ok(TokenUsage {
|
|
input_tokens: required_u32(object, "input_tokens")?,
|
|
output_tokens: required_u32(object, "output_tokens")?,
|
|
cache_creation_input_tokens: required_u32(object, "cache_creation_input_tokens")?,
|
|
cache_read_input_tokens: required_u32(object, "cache_read_input_tokens")?,
|
|
})
|
|
}
|
|
|
|
fn required_string(
|
|
object: &BTreeMap<String, JsonValue>,
|
|
key: &str,
|
|
) -> Result<String, SessionError> {
|
|
object
|
|
.get(key)
|
|
.and_then(JsonValue::as_str)
|
|
.map(ToOwned::to_owned)
|
|
.ok_or_else(|| SessionError::Format(format!("missing {key}")))
|
|
}
|
|
|
|
fn required_u32(object: &BTreeMap<String, JsonValue>, key: &str) -> Result<u32, SessionError> {
|
|
let value = object
|
|
.get(key)
|
|
.and_then(JsonValue::as_i64)
|
|
.ok_or_else(|| SessionError::Format(format!("missing {key}")))?;
|
|
u32::try_from(value).map_err(|_| SessionError::Format(format!("{key} out of range")))
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::{ContentBlock, ConversationMessage, MessageRole, Session};
    use crate::usage::TokenUsage;
    use std::fs;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Round-trips a three-message session (user text, assistant text plus
    /// tool use with usage, tool result) through a temp file and checks that
    /// the restored value equals the original.
    #[test]
    fn persists_and_restores_session_json() {
        let mut session = Session::new();
        session
            .messages
            .push(ConversationMessage::user_text("hello"));
        session
            .messages
            .push(ConversationMessage::assistant_with_usage(
                vec![
                    ContentBlock::Text {
                        text: "thinking".to_string(),
                    },
                    ContentBlock::ToolUse {
                        id: "tool-1".to_string(),
                        name: "bash".to_string(),
                        input: "echo hi".to_string(),
                    },
                ],
                Some(TokenUsage {
                    input_tokens: 10,
                    output_tokens: 4,
                    cache_creation_input_tokens: 1,
                    cache_read_input_tokens: 2,
                }),
            ));
        session.messages.push(ConversationMessage::tool_result(
            "tool-1", "bash", "hi", false,
        ));

        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system time should be after epoch")
            .as_nanos();
        // The process id keeps concurrent test binaries from colliding even if
        // they observe the same timestamp.
        let path = std::env::temp_dir().join(format!(
            "runtime-session-{}-{nanos}.json",
            std::process::id()
        ));

        // Run save, load and cleanup unconditionally and only then assert, so
        // a failing save or load does not leave the temp file behind.
        let saved = session.save_to_path(&path);
        let loaded = Session::load_from_path(&path);
        let removed = fs::remove_file(&path);

        saved.expect("session should save");
        let restored = loaded.expect("session should load");
        removed.expect("temp file should be removable");

        assert_eq!(restored, session);
        assert_eq!(restored.messages[2].role, MessageRole::Tool);
        assert_eq!(
            restored.messages[1].usage.expect("usage").total_tokens(),
            17
        );
    }
}
|