use futures::{SinkExt, Stream, StreamExt};
use reqwest::Client;
use serde::Deserialize;

use crate::core::types::{ConversationMessage, StreamEvent};

use super::ModelProvider;

/// Errors that can occur when constructing or using a [`ClaudeProvider`].
#[derive(Debug, thiserror::Error)]
pub enum ClaudeProviderError {
    /// The `ANTHROPIC_API_KEY` environment variable is not set.
    #[error("ANTHROPIC_API_KEY environment variable not set")]
    MissingApiKey,

    /// An HTTP-level error from the reqwest client.
    #[error("HTTP request failed: {0}")]
    Http(#[from] reqwest::Error),
}

/// [`ModelProvider`] implementation that streams responses from the Anthropic Messages API.
///
/// Calls `POST /v1/messages` with `"stream": true` and parses the resulting
/// [Server-Sent Events][sse] stream into [`StreamEvent`]s.
///
/// # Authentication
///
/// Reads the API key from the `ANTHROPIC_API_KEY` environment variable.
/// See the [Anthropic authentication docs][auth] for how to obtain a key.
///
/// # API version
///
/// Sends the `anthropic-version: 2023-06-01` header on every request, which is
/// the stable baseline version required by the API. See the
/// [versioning docs][versioning] for details on how Anthropic handles API versions.
///
/// [sse]: https://html.spec.whatwg.org/multipage/server-sent-events.html
/// [auth]: https://docs.anthropic.com/en/api/getting-started#authentication
/// [versioning]: https://docs.anthropic.com/en/api/versioning
pub struct ClaudeProvider {
    /// Value sent in the `x-api-key` request header.
    api_key: String,
    /// Shared HTTP client; cloned into each spawned streaming task.
    client: Client,
    /// Model ID placed in the `model` field of every request body.
    model: String,
}

impl ClaudeProvider {
    /// Create a `ClaudeProvider` reading `ANTHROPIC_API_KEY` from the environment.
    /// The caller must supply the model ID (e.g. `"claude-opus-4-6"`).
    ///
    /// See the [models overview][models] for available model IDs.
    ///
    /// # Errors
    ///
    /// Returns [`ClaudeProviderError::MissingApiKey`] when the variable cannot
    /// be read. Every `std::env::var` failure (unset or not valid Unicode) is
    /// mapped to that one variant.
    ///
    /// [models]: https://docs.anthropic.com/en/docs/about-claude/models/overview
    pub fn from_env(model: impl Into<String>) -> Result<Self, ClaudeProviderError> {
        let api_key = std::env::var("ANTHROPIC_API_KEY")
            .map_err(|_| ClaudeProviderError::MissingApiKey)?;

        Ok(Self {
            api_key,
            client: Client::new(),
            model: model.into(),
        })
    }
}

impl ModelProvider for ClaudeProvider {
    /// Start a streaming completion for `messages`.
    ///
    /// Spawns a Tokio task that runs the request and returns the receiving
    /// half of a bounded channel (capacity 32) on which the task publishes
    /// [`StreamEvent`]s.
    ///
    /// # Panics
    ///
    /// `tokio::spawn` panics when called outside a Tokio runtime.
    fn stream<'a>(
        &'a self,
        messages: &'a [ConversationMessage],
    ) -> impl Stream<Item = StreamEvent> + Send + 'a {
        let (mut tx, rx) = futures::channel::mpsc::channel(32);

        // The spawned task must be `'static`, so it gets owned copies of
        // everything it needs instead of borrowing from `self`.
        let client = self.client.clone();
        let api_key = self.api_key.clone();
        let model = self.model.clone();
        let messages = messages.to_vec();

        tokio::spawn(async move {
            run_stream(client, api_key, model, messages, &mut tx).await;
        });

        rx
    }
}

/// POST to `/v1/messages` with `stream: true`, then parse the SSE response into
/// [`StreamEvent`]s and forward them to `tx`.
///
/// # Request shape
///
/// ```json
/// {
///   "model": "<model-id>",
///   "max_tokens": 8192,
///   "stream": true,
///   "messages": [{ "role": "user"|"assistant", "content": "<text>" }, ...]
/// }
/// ```
///
/// See the [Messages API reference][messages-api] for the full schema.
///
/// # SSE stream lifecycle
///
/// With streaming enabled the API sends a sequence of
/// [Server-Sent Events][sse] separated by blank lines (`\n\n`). Each event
/// has an `event:` line naming its type and a `data:` line containing a JSON
/// object. The full event sequence for a successful turn is:
///
/// ```text
/// event: message_start        -> InputTokens(n)
/// event: content_block_start  -> (ignored -- signals a new content block)
/// event: ping                 -> (ignored -- keepalive)
/// event: content_block_delta  -> TextDelta(chunk)   (repeated)
/// event: content_block_stop   -> (ignored -- signals end of content block)
/// event: message_delta        -> OutputTokens(n)
/// event: message_stop         -> Done
/// ```
///
/// We stop reading as soon as `Done` is emitted; any bytes arriving after
/// `message_stop` are discarded.
///
/// See the [streaming reference][streaming] for the authoritative description
/// of each event type and its JSON payload.
///
/// # Termination
///
/// The function returns, dropping the HTTP response, as soon as any of the
/// following happens:
///
/// * the request fails, the status is not 2xx, or a body chunk cannot be read
///   -- a single `StreamEvent::Error` is sent first;
/// * a parsed event is `StreamEvent::Done` or `StreamEvent::Error` -- it is
///   forwarded and nothing is sent after it;
/// * sending on `tx` fails, which means the receiver is gone and nobody would
///   see further events;
/// * the body ends without a terminal event -- a final `StreamEvent::Done` is
///   sent so the consumer is not left waiting.
///
/// [messages-api]: https://docs.anthropic.com/en/api/messages
/// [streaming]: https://docs.anthropic.com/en/api/messages-streaming
/// [sse]: https://html.spec.whatwg.org/multipage/server-sent-events.html
async fn run_stream(
    client: Client,
    api_key: String,
    model: String,
    messages: Vec<ConversationMessage>,
    tx: &mut futures::channel::mpsc::Sender<StreamEvent>,
) {
    let body = serde_json::json!({
        "model": model,
        "max_tokens": 8192,
        "stream": true,
        "messages": messages,
    });

    let response = match client
        .post("https://api.anthropic.com/v1/messages")
        .header("x-api-key", &api_key)
        .header("anthropic-version", "2023-06-01")
        .header("content-type", "application/json")
        .json(&body)
        .send()
        .await
    {
        Ok(response) => response,
        Err(e) => {
            let _ = tx.send(StreamEvent::Error(e.to_string())).await;
            return;
        }
    };

    let status = response.status();
    if !status.is_success() {
        // Surface the response body: it normally carries the API's JSON error
        // description, which is far more useful than the bare status code.
        let body_text = match response.text().await {
            Ok(text) => text,
            Err(e) => format!("(failed to read error body: {e})"),
        };
        let _ = tx
            .send(StreamEvent::Error(format!("HTTP {status}: {body_text}")))
            .await;
        return;
    }

    let mut stream = response.bytes_stream();
    // Raw bytes are buffered (rather than a `String`) so that a multi-byte
    // UTF-8 sequence split across two network chunks is decoded intact once
    // the full event has arrived.
    let mut buffer: Vec<u8> = Vec::new();

    while let Some(chunk) = stream.next().await {
        let bytes = match chunk {
            Ok(bytes) => bytes,
            Err(e) => {
                let _ = tx.send(StreamEvent::Error(e.to_string())).await;
                return;
            }
        };
        buffer.extend_from_slice(&bytes);

        // Drain complete SSE events (delimited by blank lines).
        while let Some(pos) = find_double_newline(&buffer) {
            let event_bytes: Vec<u8> = buffer.drain(..pos + 2).collect();
            let event_str = String::from_utf8_lossy(&event_bytes);

            let Some(event) = parse_sse_event(&event_str) else {
                continue;
            };

            let is_terminal = matches!(event, StreamEvent::Done | StreamEvent::Error(_));
            if tx.send(event).await.is_err() {
                // Receiver dropped: stop pulling data nobody will consume.
                return;
            }
            if is_terminal {
                return;
            }
        }
    }

    // The body ended without `message_stop`; still signal completion.
    let _ = tx.send(StreamEvent::Done).await;
}

/// Return the byte offset of the first `\n\n` in `buf`, or `None`.
///
/// SSE uses a blank line (two consecutive newlines) as the event boundary.
/// See [Section 9.2.6 of the SSE spec][sse-dispatch].
///
/// Only the LF-LF form is recognised; a `\r\n\r\n` boundary is not matched
/// because its two `\n` bytes are not adjacent.
///
/// [sse-dispatch]: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation
fn find_double_newline(buf: &[u8]) -> Option<usize> {
    buf.windows(2).position(|w| w == b"\n\n")
}

// -- SSE JSON types -----------------------------------------------------------
//
// These structs mirror the subset of the Anthropic SSE payload we actually
// consume. Unknown fields are silently ignored by serde. Full schemas are
// documented in the [streaming reference][streaming].
//
// [streaming]: https://docs.anthropic.com/en/api/messages-streaming

/// Top-level SSE data object. The `type` field selects which other fields
/// are present; we use `Option` for all of them so a single struct covers
/// every event type without needing an enum.
#[derive(Deserialize, Debug)]
struct SseEvent {
    /// Event discriminator, e.g. `"message_start"` or `"content_block_delta"`.
    #[serde(rename = "type")]
    event_type: String,
    /// Present on `content_block_delta` events.
    delta: Option<SseDelta>,
    /// Present on `message_start` events; carries initial token usage.
    message: Option<SseMessageStart>,
    /// Present on `message_delta` events; carries final output token count.
    usage: Option<SseUsage>,
}

/// The `delta` object inside a `content_block_delta` event.
///
/// `type` is `"text_delta"` for plain text chunks; other delta types
/// (e.g. `"input_json_delta"` for tool-use blocks) are not yet handled.
#[derive(Deserialize, Debug)]
struct SseDelta {
    /// Optional because the `delta` object of a `message_delta` event (which
    /// deserializes into this same struct) carries no `type` key.
    #[serde(rename = "type")]
    delta_type: Option<String>,
    /// The text chunk; present when `delta_type == "text_delta"`.
    text: Option<String>,
}

/// The `message` object inside a `message_start` event.
#[derive(Deserialize, Debug)]
struct SseMessageStart {
    /// Token usage reported when the message starts.
    usage: Option<SseUsage>,
}

/// Token counts reported at the start and end of a turn.
///
/// `input_tokens` is set in the `message_start` event;
/// `output_tokens` is set in the `message_delta` event.
/// Both fields are `Option` so the same struct works for both events.
#[derive(Deserialize, Debug)] struct SseUsage { input_tokens: Option, output_tokens: Option, } /// Parse a single SSE event string into a [`StreamEvent`], returning `None` for /// event types we don't care about (`ping`, `content_block_start`, /// `content_block_stop`). /// /// # SSE format /// /// Each event is a block of `field: value` lines. We only read the `data:` /// field; the `event:` line is redundant with the `type` key inside the JSON /// payload so we ignore it. See the [SSE spec][sse-fields] for the full field /// grammar. /// /// # Mapping to [`StreamEvent`] /// /// | API event type | JSON path | Emits | /// |----------------------|------------------------------------|------------------------------| /// | `message_start` | `.message.usage.input_tokens` | `InputTokens(n)` | /// | `content_block_delta`| `.delta.type == "text_delta"` | `TextDelta(chunk)` | /// | `message_delta` | `.usage.output_tokens` | `OutputTokens(n)` | /// | `message_stop` | n/a | `Done` | /// | everything else | n/a | `None` (caller skips) | /// /// [sse-fields]: https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream fn parse_sse_event(event_str: &str) -> Option { // SSE events may have multiple fields; we only need `data:`. let data = event_str .lines() .find_map(|line| line.strip_prefix("data: "))?; let event: SseEvent = serde_json::from_str(data).ok()?; match event.event_type.as_str() { "message_start" => event .message .and_then(|m| m.usage) .and_then(|u| u.input_tokens) .map(StreamEvent::InputTokens), "content_block_delta" => { let delta = event.delta?; if delta.delta_type.as_deref() == Some("text_delta") { delta.text.map(StreamEvent::TextDelta) } else { None } } // usage lives at the top level of message_delta, not inside delta. 
"message_delta" => event .usage .and_then(|u| u.output_tokens) .map(StreamEvent::OutputTokens), "message_stop" => Some(StreamEvent::Done), // error, ping, content_block_start, content_block_stop -- ignored or // handled by the caller. _ => None, } } // -- Tests -------------------------------------------------------------------- #[cfg(test)] mod tests { use super::*; use crate::core::types::Role; /// A minimal but complete Anthropic SSE fixture. const SSE_FIXTURE: &str = concat!( "event: message_start\n", "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_123\",\"type\":\"message\",", "\"role\":\"assistant\",\"content\":[],\"model\":\"claude-opus-4-6\",", "\"stop_reason\":null,\"stop_sequence\":null,", "\"usage\":{\"input_tokens\":10,\"cache_creation_input_tokens\":0,", "\"cache_read_input_tokens\":0,\"output_tokens\":1}}}\n", "\n", "event: content_block_start\n", "data: {\"type\":\"content_block_start\",\"index\":0,", "\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n", "\n", "event: ping\n", "data: {\"type\":\"ping\"}\n", "\n", "event: content_block_delta\n", "data: {\"type\":\"content_block_delta\",\"index\":0,", "\"delta\":{\"type\":\"text_delta\",\"text\":\"Hello\"}}\n", "\n", "event: content_block_delta\n", "data: {\"type\":\"content_block_delta\",\"index\":0,", "\"delta\":{\"type\":\"text_delta\",\"text\":\", world!\"}}\n", "\n", "event: content_block_stop\n", "data: {\"type\":\"content_block_stop\",\"index\":0}\n", "\n", "event: message_delta\n", "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\",", "\"stop_sequence\":null},\"usage\":{\"output_tokens\":5}}\n", "\n", "event: message_stop\n", "data: {\"type\":\"message_stop\"}\n", "\n", ); #[test] fn test_parse_sse_events_from_fixture() { let events: Vec = SSE_FIXTURE .split("\n\n") .filter(|s| !s.trim().is_empty()) .filter_map(parse_sse_event) .collect(); // content_block_start, ping, content_block_stop -> None (filtered out) assert_eq!(events.len(), 5); 
assert!(matches!(events[0], StreamEvent::InputTokens(10))); assert!(matches!(&events[1], StreamEvent::TextDelta(s) if s == "Hello")); assert!(matches!(&events[2], StreamEvent::TextDelta(s) if s == ", world!")); assert!(matches!(events[3], StreamEvent::OutputTokens(5))); assert!(matches!(events[4], StreamEvent::Done)); } #[test] fn test_parse_message_stop_yields_done() { let event_str = "event: message_stop\ndata: {\"type\":\"message_stop\"}\n"; assert!(matches!( parse_sse_event(event_str), Some(StreamEvent::Done) )); } #[test] fn test_parse_ping_yields_none() { let event_str = "event: ping\ndata: {\"type\":\"ping\"}\n"; assert!(parse_sse_event(event_str).is_none()); } #[test] fn test_parse_content_block_start_yields_none() { let event_str = concat!( "event: content_block_start\n", "data: {\"type\":\"content_block_start\",\"index\":0,", "\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n", ); assert!(parse_sse_event(event_str).is_none()); } #[test] fn test_messages_serialize_to_anthropic_format() { let messages = vec![ ConversationMessage { role: Role::User, content: "Hello".to_string(), }, ConversationMessage { role: Role::Assistant, content: "Hi there!".to_string(), }, ]; let json = serde_json::json!({ "model": "claude-opus-4-6", "max_tokens": 8192, "stream": true, "messages": messages, }); assert_eq!(json["messages"][0]["role"], "user"); assert_eq!(json["messages"][0]["content"], "Hello"); assert_eq!(json["messages"][1]["role"], "assistant"); assert_eq!(json["messages"][1]["content"], "Hi there!"); assert_eq!(json["stream"], true); assert!(json["max_tokens"].as_u64().unwrap() > 0); } #[test] fn test_find_double_newline() { assert_eq!(find_double_newline(b"abc\n\ndef"), Some(3)); assert_eq!(find_double_newline(b"abc\ndef"), None); assert_eq!(find_double_newline(b"\n\n"), Some(0)); assert_eq!(find_double_newline(b""), None); } }