From 0a547e0c4421f1a7dbbf246b0c2c787947c48c53 Mon Sep 17 00:00:00 2001
From: Drew Galbraith
Date: Mon, 23 Feb 2026 22:26:18 -0800
Subject: [PATCH] Claude Wrapper.

---
 Cargo.lock             |   2 +
 Cargo.toml             |   2 +-
 src/provider/claude.rs | 440 +++++++++++++++++++++++++++++++++++++++++
 src/provider/mod.rs    |  20 ++
 4 files changed, 463 insertions(+), 1 deletion(-)
 create mode 100644 src/provider/claude.rs

diff --git a/Cargo.lock b/Cargo.lock
index 4c51a7e..2bea462 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1714,6 +1714,8 @@ dependencies = [
  "rustls",
  "rustls-pki-types",
  "rustls-platform-verifier",
+ "serde",
+ "serde_json",
  "sync_wrapper",
  "tokio",
  "tokio-rustls",
diff --git a/Cargo.toml b/Cargo.toml
index 11e6c0c..528b3b5 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -12,5 +12,5 @@ serde_json = "1"
 thiserror = "2"
 tracing = "0.1"
 tracing-subscriber = { version = "0.3", features = ["env-filter"] }
-reqwest = { version = "0.13", features = ["stream"] }
+reqwest = { version = "0.13", features = ["stream", "json"] }
 futures = "0.3"
diff --git a/src/provider/claude.rs b/src/provider/claude.rs
new file mode 100644
index 0000000..0857625
--- /dev/null
+++ b/src/provider/claude.rs
@@ -0,0 +1,440 @@
+// Items used only in later phases or tests.
+#![allow(dead_code)]
+
+use futures::{SinkExt, Stream, StreamExt};
+use reqwest::Client;
+use serde::Deserialize;
+
+use crate::core::types::{ConversationMessage, StreamEvent};
+
+use super::ModelProvider;
+
+/// Errors that can occur when constructing or using a [`ClaudeProvider`].
+#[derive(Debug, thiserror::Error)]
+pub enum ClaudeProviderError {
+    /// The `ANTHROPIC_API_KEY` environment variable is not set.
+    #[error("ANTHROPIC_API_KEY environment variable not set")]
+    MissingApiKey,
+    /// An HTTP-level error from the reqwest client.
+    #[error("HTTP request failed: {0}")]
+    Http(#[from] reqwest::Error),
+}
+
+/// [`ModelProvider`] implementation that streams responses from the Anthropic Messages API.
+///
+/// Calls `POST /v1/messages` with `"stream": true` and parses the resulting
+/// [Server-Sent Events][sse] stream into [`StreamEvent`]s.
+///
+/// # Authentication
+///
+/// Reads the API key from the `ANTHROPIC_API_KEY` environment variable.
+/// See the [Anthropic authentication docs][auth] for how to obtain a key.
+///
+/// # API version
+///
+/// Sends the `anthropic-version: 2023-06-01` header on every request, which is
+/// the stable baseline version required by the API. See the
+/// [versioning docs][versioning] for details on how Anthropic handles API versions.
+///
+/// [sse]: https://html.spec.whatwg.org/multipage/server-sent-events.html
+/// [auth]: https://docs.anthropic.com/en/api/getting-started#authentication
+/// [versioning]: https://docs.anthropic.com/en/api/versioning
+pub struct ClaudeProvider {
+    /// Secret sent in the `x-api-key` request header.
+    api_key: String,
+    /// HTTP client; cloned into the background task on each `stream` call.
+    client: Client,
+    /// Model ID placed in the `model` field of the request body.
+    model: String,
+}
+
+impl ClaudeProvider {
+    /// Create a `ClaudeProvider` reading `ANTHROPIC_API_KEY` from the environment.
+    /// The caller must supply the model ID (e.g. `"claude-opus-4-6"`).
+    ///
+    /// See the [models overview][models] for available model IDs.
+ /// + /// [models]: https://docs.anthropic.com/en/docs/about-claude/models/overview + pub fn from_env(model: impl Into) -> Result { + let api_key = + std::env::var("ANTHROPIC_API_KEY").map_err(|_| ClaudeProviderError::MissingApiKey)?; + Ok(Self { + api_key, + client: Client::new(), + model: model.into(), + }) + } +} + +impl ModelProvider for ClaudeProvider { + fn stream<'a>( + &'a self, + messages: &'a [ConversationMessage], + ) -> impl Stream + Send + 'a { + let (mut tx, rx) = futures::channel::mpsc::channel(32); + let client = self.client.clone(); + let api_key = self.api_key.clone(); + let model = self.model.clone(); + let messages = messages.to_vec(); + + tokio::spawn(async move { + run_stream(client, api_key, model, messages, &mut tx).await; + }); + + rx + } +} + +/// POST to `/v1/messages` with `stream: true`, then parse the SSE response into +/// [`StreamEvent`]s and forward them to `tx`. +/// +/// # Request shape +/// +/// ```json +/// { +/// "model": "", +/// "max_tokens": 8192, +/// "stream": true, +/// "messages": [{ "role": "user"|"assistant", "content": "" }, …] +/// } +/// ``` +/// +/// See the [Messages API reference][messages-api] for the full schema. +/// +/// # SSE stream lifecycle +/// +/// With streaming enabled the API sends a sequence of +/// [Server-Sent Events][sse] separated by blank lines (`\n\n`). Each event +/// has an `event:` line naming its type and a `data:` line containing a JSON +/// object. 
The full event sequence for a successful turn is: +/// +/// ```text +/// event: message_start → InputTokens(n) +/// event: content_block_start → (ignored — signals a new content block) +/// event: ping → (ignored — keepalive) +/// event: content_block_delta → TextDelta(chunk) (repeated) +/// event: content_block_stop → (ignored — signals end of content block) +/// event: message_delta → OutputTokens(n) +/// event: message_stop → Done +/// ``` +/// +/// We stop reading as soon as `Done` is emitted; any bytes arriving after +/// `message_stop` are discarded. +/// +/// See the [streaming reference][streaming] for the authoritative description +/// of each event type and its JSON payload. +/// +/// [messages-api]: https://docs.anthropic.com/en/api/messages +/// [streaming]: https://docs.anthropic.com/en/api/messages-streaming +/// [sse]: https://html.spec.whatwg.org/multipage/server-sent-events.html +async fn run_stream( + client: Client, + api_key: String, + model: String, + messages: Vec, + tx: &mut futures::channel::mpsc::Sender, +) { + let body = serde_json::json!({ + "model": model, + "max_tokens": 8192, + "stream": true, + "messages": messages, + }); + + let response = match client + .post("https://api.anthropic.com/v1/messages") + .header("x-api-key", &api_key) + .header("anthropic-version", "2023-06-01") + .header("content-type", "application/json") + .json(&body) + .send() + .await + { + Ok(r) => r, + Err(e) => { + let _ = tx.send(StreamEvent::Error(e.to_string())).await; + return; + } + }; + + if !response.status().is_success() { + let status = response.status(); + let body_text = response.text().await.unwrap_or_default(); + let _ = tx + .send(StreamEvent::Error(format!("HTTP {status}: {body_text}"))) + .await; + return; + } + + let mut stream = response.bytes_stream(); + let mut buffer: Vec = Vec::new(); + + while let Some(chunk) = stream.next().await { + match chunk { + Err(e) => { + let _ = tx.send(StreamEvent::Error(e.to_string())).await; + return; + } + 
Ok(bytes) => { + buffer.extend_from_slice(&bytes); + // Drain complete SSE events (delimited by blank lines). + loop { + match find_double_newline(&buffer) { + None => break, + Some(pos) => { + let event_bytes: Vec = buffer.drain(..pos + 2).collect(); + let event_str = String::from_utf8_lossy(&event_bytes); + if let Some(event) = parse_sse_event(&event_str) { + let is_done = matches!(event, StreamEvent::Done); + let _ = tx.send(event).await; + if is_done { + return; + } + } + } + } + } + } + } + } + + let _ = tx.send(StreamEvent::Done).await; +} + +/// Return the byte offset of the first `\n\n` in `buf`, or `None`. +/// +/// SSE uses a blank line (two consecutive newlines) as the event boundary. +/// See [§9.2.6 of the SSE spec][sse-dispatch]. +/// +/// [sse-dispatch]: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation +fn find_double_newline(buf: &[u8]) -> Option { + buf.windows(2).position(|w| w == b"\n\n") +} + +// ── SSE JSON types ──────────────────────────────────────────────────────────── +// +// These structs mirror the subset of the Anthropic SSE payload we actually +// consume. Unknown fields are silently ignored by serde. Full schemas are +// documented in the [streaming reference][streaming]. +// +// [streaming]: https://docs.anthropic.com/en/api/messages-streaming + +/// Top-level SSE data object. The `type` field selects which other fields +/// are present; we use `Option` for all of them so a single struct covers +/// every event type without needing an enum. +#[derive(Deserialize, Debug)] +struct SseEvent { + #[serde(rename = "type")] + event_type: String, + /// Present on `content_block_delta` events. + delta: Option, + /// Present on `message_start` events; carries initial token usage. + message: Option, + /// Present on `message_delta` events; carries final output token count. + usage: Option, +} + +/// The `delta` object inside a `content_block_delta` event. 
+///
+/// `type` is `"text_delta"` for plain text chunks; other delta types
+/// (e.g. `"input_json_delta"` for tool-use blocks) are not yet handled.
+/// Both fields are optional because the `delta` object of a `message_delta`
+/// event carries neither `type` nor `text`.
+#[derive(Deserialize, Debug)]
+struct SseDelta {
+    #[serde(rename = "type")]
+    delta_type: Option<String>,
+    /// The text chunk; present when `delta_type == "text_delta"`.
+    text: Option<String>,
+}
+
+/// The `message` object inside a `message_start` event.
+#[derive(Deserialize, Debug)]
+struct SseMessageStart {
+    usage: Option<SseUsage>,
+}
+
+/// Token counts reported at the start and end of a turn.
+///
+/// `input_tokens` is set in the `message_start` event;
+/// `output_tokens` is set in the `message_delta` event.
+/// Both fields are `Option` so the same struct works for both events.
+// NOTE(review): the integer type must equal the payload type of
+// `StreamEvent::InputTokens` / `StreamEvent::OutputTokens`, which is declared
+// in `crate::core::types` and not visible here. `u32` is assumed — confirm.
+#[derive(Deserialize, Debug)]
+struct SseUsage {
+    input_tokens: Option<u32>,
+    output_tokens: Option<u32>,
+}
+
+/// Parse a single SSE event string into a [`StreamEvent`], returning `None` for
+/// event types we don't care about (`ping`, `content_block_start`,
+/// `content_block_stop`).
+///
+/// # SSE format
+///
+/// Each event is a block of `field: value` lines. We only read the `data:`
+/// field; the `event:` line is redundant with the `type` key inside the JSON
+/// payload so we ignore it. See the [SSE spec][sse-fields] for the full field
+/// grammar.
+/// +/// # Mapping to [`StreamEvent`] +/// +/// | API event type | JSON path | Emits | +/// |----------------------|------------------------------------|------------------------------| +/// | `message_start` | `.message.usage.input_tokens` | `InputTokens(n)` | +/// | `content_block_delta`| `.delta.type == "text_delta"` | `TextDelta(chunk)` | +/// | `message_delta` | `.usage.output_tokens` | `OutputTokens(n)` | +/// | `message_stop` | — | `Done` | +/// | everything else | — | `None` (caller skips) | +/// +/// [sse-fields]: https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream +fn parse_sse_event(event_str: &str) -> Option { + // SSE events may have multiple fields; we only need `data:`. + let data = event_str + .lines() + .find_map(|line| line.strip_prefix("data: "))?; + + let event: SseEvent = serde_json::from_str(data).ok()?; + + match event.event_type.as_str() { + "message_start" => event + .message + .and_then(|m| m.usage) + .and_then(|u| u.input_tokens) + .map(StreamEvent::InputTokens), + + "content_block_delta" => { + let delta = event.delta?; + if delta.delta_type.as_deref() == Some("text_delta") { + delta.text.map(StreamEvent::TextDelta) + } else { + None + } + } + + // usage lives at the top level of message_delta, not inside delta. + "message_delta" => event + .usage + .and_then(|u| u.output_tokens) + .map(StreamEvent::OutputTokens), + + "message_stop" => Some(StreamEvent::Done), + + // error, ping, content_block_start, content_block_stop — ignored or + // handled by the caller. + _ => None, + } +} + +// ── Tests ───────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use crate::core::types::Role; + + /// A minimal but complete Anthropic SSE fixture. 
+ const SSE_FIXTURE: &str = concat!( + "event: message_start\n", + "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_123\",\"type\":\"message\",", + "\"role\":\"assistant\",\"content\":[],\"model\":\"claude-opus-4-6\",", + "\"stop_reason\":null,\"stop_sequence\":null,", + "\"usage\":{\"input_tokens\":10,\"cache_creation_input_tokens\":0,", + "\"cache_read_input_tokens\":0,\"output_tokens\":1}}}\n", + "\n", + "event: content_block_start\n", + "data: {\"type\":\"content_block_start\",\"index\":0,", + "\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n", + "\n", + "event: ping\n", + "data: {\"type\":\"ping\"}\n", + "\n", + "event: content_block_delta\n", + "data: {\"type\":\"content_block_delta\",\"index\":0,", + "\"delta\":{\"type\":\"text_delta\",\"text\":\"Hello\"}}\n", + "\n", + "event: content_block_delta\n", + "data: {\"type\":\"content_block_delta\",\"index\":0,", + "\"delta\":{\"type\":\"text_delta\",\"text\":\", world!\"}}\n", + "\n", + "event: content_block_stop\n", + "data: {\"type\":\"content_block_stop\",\"index\":0}\n", + "\n", + "event: message_delta\n", + "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\",", + "\"stop_sequence\":null},\"usage\":{\"output_tokens\":5}}\n", + "\n", + "event: message_stop\n", + "data: {\"type\":\"message_stop\"}\n", + "\n", + ); + + #[test] + fn test_parse_sse_events_from_fixture() { + let events: Vec = SSE_FIXTURE + .split("\n\n") + .filter(|s| !s.trim().is_empty()) + .filter_map(parse_sse_event) + .collect(); + + // content_block_start, ping, content_block_stop → None (filtered out) + assert_eq!(events.len(), 5); + assert!(matches!(events[0], StreamEvent::InputTokens(10))); + assert!(matches!(&events[1], StreamEvent::TextDelta(s) if s == "Hello")); + assert!(matches!(&events[2], StreamEvent::TextDelta(s) if s == ", world!")); + assert!(matches!(events[3], StreamEvent::OutputTokens(5))); + assert!(matches!(events[4], StreamEvent::Done)); + } + + #[test] + fn 
test_parse_message_stop_yields_done() { + let event_str = "event: message_stop\ndata: {\"type\":\"message_stop\"}\n"; + assert!(matches!(parse_sse_event(event_str), Some(StreamEvent::Done))); + } + + #[test] + fn test_parse_ping_yields_none() { + let event_str = "event: ping\ndata: {\"type\":\"ping\"}\n"; + assert!(parse_sse_event(event_str).is_none()); + } + + #[test] + fn test_parse_content_block_start_yields_none() { + let event_str = concat!( + "event: content_block_start\n", + "data: {\"type\":\"content_block_start\",\"index\":0,", + "\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n", + ); + assert!(parse_sse_event(event_str).is_none()); + } + + #[test] + fn test_messages_serialize_to_anthropic_format() { + let messages = vec![ + ConversationMessage { + role: Role::User, + content: "Hello".to_string(), + }, + ConversationMessage { + role: Role::Assistant, + content: "Hi there!".to_string(), + }, + ]; + + let json = serde_json::json!({ + "model": "claude-opus-4-6", + "max_tokens": 8192, + "stream": true, + "messages": messages, + }); + + assert_eq!(json["messages"][0]["role"], "user"); + assert_eq!(json["messages"][0]["content"], "Hello"); + assert_eq!(json["messages"][1]["role"], "assistant"); + assert_eq!(json["messages"][1]["content"], "Hi there!"); + assert_eq!(json["stream"], true); + assert!(json["max_tokens"].as_u64().unwrap() > 0); + } + + #[test] + fn test_find_double_newline() { + assert_eq!(find_double_newline(b"abc\n\ndef"), Some(3)); + assert_eq!(find_double_newline(b"abc\ndef"), None); + assert_eq!(find_double_newline(b"\n\n"), Some(0)); + assert_eq!(find_double_newline(b""), None); + } +} diff --git a/src/provider/mod.rs b/src/provider/mod.rs index 8b13789..cf37d66 100644 --- a/src/provider/mod.rs +++ b/src/provider/mod.rs @@ -1 +1,21 @@ +#![allow(dead_code, unused_imports)] +mod claude; + +pub use claude::{ClaudeProvider, ClaudeProviderError}; + +use futures::Stream; + +use crate::core::types::{ConversationMessage, StreamEvent}; + +/// 
Trait for model providers that can stream conversation responses. +/// +/// Implementors take a conversation history and return a stream of [`StreamEvent`]s. +/// The trait is provider-agnostic — no Claude-specific types appear here. +pub trait ModelProvider: Send + Sync { + /// Stream a response from the model given the conversation history. + fn stream<'a>( + &'a self, + messages: &'a [ConversationMessage], + ) -> impl Stream + Send + 'a; +}