539 lines
20 KiB
Rust
539 lines
20 KiB
Rust
use futures::{SinkExt, Stream, StreamExt};
|
|
use reqwest::Client;
|
|
use serde::Deserialize;
|
|
|
|
use crate::core::types::{ConversationMessage, StreamEvent, ToolDefinition};
|
|
|
|
use super::ModelProvider;
|
|
|
|
/// Errors that can occur when constructing or using a [`ClaudeProvider`].
|
|
#[derive(Debug, thiserror::Error)]
|
|
pub enum ClaudeProviderError {
|
|
/// The `ANTHROPIC_API_KEY` environment variable is not set.
|
|
#[error("ANTHROPIC_API_KEY environment variable not set")]
|
|
MissingApiKey,
|
|
/// An HTTP-level error from the reqwest client.
|
|
#[error("HTTP request failed: {0}")]
|
|
Http(#[from] reqwest::Error),
|
|
}
|
|
|
|
/// [`ModelProvider`] implementation that streams responses from the Anthropic Messages API.
|
|
///
|
|
/// Calls `POST /v1/messages` with `"stream": true` and parses the resulting
|
|
/// [Server-Sent Events][sse] stream into [`StreamEvent`]s.
|
|
///
|
|
/// # Authentication
|
|
///
|
|
/// Reads the API key from the `ANTHROPIC_API_KEY` environment variable.
|
|
/// See the [Anthropic authentication docs][auth] for how to obtain a key.
|
|
///
|
|
/// # API version
|
|
///
|
|
/// Sends the `anthropic-version: 2023-06-01` header on every request, which is
|
|
/// the stable baseline version required by the API. See the
|
|
/// [versioning docs][versioning] for details on how Anthropic handles API versions.
|
|
///
|
|
/// [sse]: https://html.spec.whatwg.org/multipage/server-sent-events.html
|
|
/// [auth]: https://docs.anthropic.com/en/api/getting-started#authentication
|
|
/// [versioning]: https://docs.anthropic.com/en/api/versioning
|
|
pub struct ClaudeProvider {
|
|
api_key: String,
|
|
client: Client,
|
|
model: String,
|
|
}
|
|
|
|
impl ClaudeProvider {
|
|
/// Create a `ClaudeProvider` reading `ANTHROPIC_API_KEY` from the environment.
|
|
/// The caller must supply the model ID (e.g. `"claude-opus-4-6"`).
|
|
///
|
|
/// See the [models overview][models] for available model IDs.
|
|
///
|
|
/// [models]: https://docs.anthropic.com/en/docs/about-claude/models/overview
|
|
pub fn from_env(model: impl Into<String>) -> Result<Self, ClaudeProviderError> {
|
|
let api_key =
|
|
std::env::var("ANTHROPIC_API_KEY").map_err(|_| ClaudeProviderError::MissingApiKey)?;
|
|
Ok(Self {
|
|
api_key,
|
|
client: Client::new(),
|
|
model: model.into(),
|
|
})
|
|
}
|
|
}
|
|
|
|
impl ModelProvider for ClaudeProvider {
|
|
fn stream<'a>(
|
|
&'a self,
|
|
messages: &'a [ConversationMessage],
|
|
tools: &'a [ToolDefinition],
|
|
) -> impl Stream<Item = StreamEvent> + Send + 'a {
|
|
let (mut tx, rx) = futures::channel::mpsc::channel(32);
|
|
let client = self.client.clone();
|
|
let api_key = self.api_key.clone();
|
|
let model = self.model.clone();
|
|
let messages = messages.to_vec();
|
|
let tools = tools.to_vec();
|
|
|
|
tokio::spawn(async move {
|
|
run_stream(client, api_key, model, messages, tools, &mut tx).await;
|
|
});
|
|
|
|
rx
|
|
}
|
|
}
|
|
|
|
/// POST to `/v1/messages` with `stream: true`, then parse the SSE response into
|
|
/// [`StreamEvent`]s and forward them to `tx`.
|
|
///
|
|
/// # Request shape
|
|
///
|
|
/// ```json
|
|
/// {
|
|
/// "model": "<model-id>",
|
|
/// "max_tokens": 8192,
|
|
/// "stream": true,
|
|
/// "messages": [{ "role": "user"|"assistant", "content": "<text>" }, ...]
|
|
/// }
|
|
/// ```
|
|
///
|
|
/// See the [Messages API reference][messages-api] for the full schema.
|
|
///
|
|
/// # SSE stream lifecycle
|
|
///
|
|
/// With streaming enabled the API sends a sequence of
|
|
/// [Server-Sent Events][sse] separated by blank lines (`\n\n`). Each event
|
|
/// has an `event:` line naming its type and a `data:` line containing a JSON
|
|
/// object. The full event sequence for a successful turn is:
|
|
///
|
|
/// ```text
|
|
/// event: message_start -> InputTokens(n)
|
|
/// event: content_block_start -> (ignored -- signals a new content block)
|
|
/// event: ping -> (ignored -- keepalive)
|
|
/// event: content_block_delta -> TextDelta(chunk) (repeated)
|
|
/// event: content_block_stop -> (ignored -- signals end of content block)
|
|
/// event: message_delta -> OutputTokens(n)
|
|
/// event: message_stop -> Done
|
|
/// ```
|
|
///
|
|
/// We stop reading as soon as `Done` is emitted; any bytes arriving after
|
|
/// `message_stop` are discarded.
|
|
///
|
|
/// See the [streaming reference][streaming] for the authoritative description
|
|
/// of each event type and its JSON payload.
|
|
///
|
|
/// [messages-api]: https://docs.anthropic.com/en/api/messages
|
|
/// [streaming]: https://docs.anthropic.com/en/api/messages-streaming
|
|
/// [sse]: https://html.spec.whatwg.org/multipage/server-sent-events.html
|
|
async fn run_stream(
|
|
client: Client,
|
|
api_key: String,
|
|
model: String,
|
|
messages: Vec<ConversationMessage>,
|
|
tools: Vec<ToolDefinition>,
|
|
tx: &mut futures::channel::mpsc::Sender<StreamEvent>,
|
|
) {
|
|
let mut body = serde_json::json!({
|
|
"model": model,
|
|
"max_tokens": 8192,
|
|
"stream": true,
|
|
"messages": messages,
|
|
});
|
|
if !tools.is_empty() {
|
|
body["tools"] = serde_json::to_value(&tools).unwrap_or_default();
|
|
}
|
|
|
|
let response = match client
|
|
.post("https://api.anthropic.com/v1/messages")
|
|
.header("x-api-key", &api_key)
|
|
.header("anthropic-version", "2023-06-01")
|
|
.header("content-type", "application/json")
|
|
.json(&body)
|
|
.send()
|
|
.await
|
|
{
|
|
Ok(r) => r,
|
|
Err(e) => {
|
|
let _ = tx.send(StreamEvent::Error(e.to_string())).await;
|
|
return;
|
|
}
|
|
};
|
|
|
|
if !response.status().is_success() {
|
|
let status = response.status();
|
|
let body_text = match response.text().await {
|
|
Ok(t) => t,
|
|
Err(e) => format!("(failed to read error body: {e})"),
|
|
};
|
|
let _ = tx
|
|
.send(StreamEvent::Error(format!("HTTP {status}: {body_text}")))
|
|
.await;
|
|
return;
|
|
}
|
|
|
|
let mut stream = response.bytes_stream();
|
|
let mut buffer: Vec<u8> = Vec::new();
|
|
|
|
while let Some(chunk) = stream.next().await {
|
|
match chunk {
|
|
Err(e) => {
|
|
let _ = tx.send(StreamEvent::Error(e.to_string())).await;
|
|
return;
|
|
}
|
|
Ok(bytes) => {
|
|
buffer.extend_from_slice(&bytes);
|
|
// Drain complete SSE events (delimited by blank lines).
|
|
loop {
|
|
match find_double_newline(&buffer) {
|
|
None => break,
|
|
Some(pos) => {
|
|
let event_bytes: Vec<u8> = buffer.drain(..pos + 2).collect();
|
|
let event_str = String::from_utf8_lossy(&event_bytes);
|
|
if let Some(event) = parse_sse_event(&event_str) {
|
|
let is_done = matches!(event, StreamEvent::Done);
|
|
let _ = tx.send(event).await;
|
|
if is_done {
|
|
return;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
let _ = tx.send(StreamEvent::Done).await;
|
|
}
|
|
|
|
/// Locate the first blank-line separator (`\n\n`) in `buf`.
///
/// Returns the index of the first of the two newline bytes, or `None` when
/// `buf` contains no two consecutive `\n` bytes. SSE uses a blank line as the
/// event boundary; see the [event stream interpretation][sse-dispatch] section
/// of the spec.
///
/// [sse-dispatch]: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation
fn find_double_newline(buf: &[u8]) -> Option<usize> {
    // Pair every byte with its successor and look for two newlines in a row.
    buf.iter()
        .zip(buf.iter().skip(1))
        .position(|(&first, &second)| first == b'\n' && second == b'\n')
}
|
|
|
|
// -- SSE JSON types -----------------------------------------------------------
//
// These structs mirror only the subset of the Anthropic SSE payload that this
// module actually consumes; serde silently ignores any unknown fields. The
// full schemas are documented in the streaming reference:
// https://docs.anthropic.com/en/api/messages-streaming
|
|
|
|
/// Top-level SSE data object. The `type` field selects which other fields
|
|
/// are present; we use `Option` for all of them so a single struct covers
|
|
/// every event type without needing an enum.
|
|
#[derive(Deserialize, Debug)]
|
|
struct SseEvent {
|
|
#[serde(rename = "type")]
|
|
event_type: String,
|
|
/// Present on `content_block_start` events; describes the new block.
|
|
content_block: Option<SseContentBlock>,
|
|
/// Present on `content_block_delta` events.
|
|
delta: Option<SseDelta>,
|
|
/// Present on `message_start` events; carries initial token usage.
|
|
message: Option<SseMessageStart>,
|
|
/// Present on `message_delta` events; carries final output token count.
|
|
usage: Option<SseUsage>,
|
|
}
|
|
|
|
/// The `content_block` object inside a `content_block_start` event.
|
|
#[derive(Deserialize, Debug)]
|
|
struct SseContentBlock {
|
|
#[serde(rename = "type")]
|
|
block_type: String,
|
|
/// Tool-use block ID; present when `block_type == "tool_use"`.
|
|
id: Option<String>,
|
|
/// Tool name; present when `block_type == "tool_use"`.
|
|
name: Option<String>,
|
|
}
|
|
|
|
/// The `delta` object inside a `content_block_delta` event.
|
|
///
|
|
/// `type` is `"text_delta"` for plain text chunks, or `"input_json_delta"`
|
|
/// for streaming tool-use input JSON.
|
|
#[derive(Deserialize, Debug)]
|
|
struct SseDelta {
|
|
#[serde(rename = "type")]
|
|
delta_type: Option<String>,
|
|
/// The text chunk; present when `delta_type == "text_delta"`.
|
|
text: Option<String>,
|
|
/// Partial JSON for tool input; present when `delta_type == "input_json_delta"`.
|
|
partial_json: Option<String>,
|
|
}
|
|
|
|
/// The `message` object inside a `message_start` event.
|
|
#[derive(Deserialize, Debug)]
|
|
struct SseMessageStart {
|
|
usage: Option<SseUsage>,
|
|
}
|
|
|
|
/// Token counts reported at the start and end of a turn.
|
|
///
|
|
/// `input_tokens` is set in the `message_start` event;
|
|
/// `output_tokens` is set in the `message_delta` event.
|
|
/// Both fields are `Option` so the same struct works for both events.
|
|
#[derive(Deserialize, Debug)]
|
|
struct SseUsage {
|
|
input_tokens: Option<u32>,
|
|
output_tokens: Option<u32>,
|
|
}
|
|
|
|
/// Parse a single SSE event string into a [`StreamEvent`], returning `None` for
|
|
/// event types we don't care about (`ping`, `content_block_start`,
|
|
/// `content_block_stop`).
|
|
///
|
|
/// # SSE format
|
|
///
|
|
/// Each event is a block of `field: value` lines. We only read the `data:`
|
|
/// field; the `event:` line is redundant with the `type` key inside the JSON
|
|
/// payload so we ignore it. See the [SSE spec][sse-fields] for the full field
|
|
/// grammar.
|
|
///
|
|
/// # Mapping to [`StreamEvent`]
|
|
///
|
|
/// | API event type | JSON path | Emits |
|
|
/// |------------------------|------------------------------------|------------------------------|
|
|
/// | `message_start` | `.message.usage.input_tokens` | `InputTokens(n)` |
|
|
/// | `content_block_start` | `.content_block.type == "tool_use"`| `ToolUseStart { id, name }` |
|
|
/// | `content_block_delta` | `.delta.type == "text_delta"` | `TextDelta(chunk)` |
|
|
/// | `content_block_delta` | `.delta.type == "input_json_delta"`| `ToolUseInputDelta(json)` |
|
|
/// | `content_block_stop` | n/a | `ToolUseDone` |
|
|
/// | `message_delta` | `.usage.output_tokens` | `OutputTokens(n)` |
|
|
/// | `message_stop` | n/a | `Done` |
|
|
/// | everything else | n/a | `None` (caller skips) |
|
|
///
|
|
/// Note: `content_block_stop` always emits `ToolUseDone`. The caller is
|
|
/// responsible for tracking whether a tool-use block is active and ignoring
|
|
/// `ToolUseDone` when it follows a text block.
|
|
///
|
|
/// [sse-fields]: https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream
|
|
fn parse_sse_event(event_str: &str) -> Option<StreamEvent> {
|
|
// SSE events may have multiple fields; we only need `data:`.
|
|
let data = event_str
|
|
.lines()
|
|
.find_map(|line| line.strip_prefix("data: "))?;
|
|
|
|
let event: SseEvent = serde_json::from_str(data).ok()?;
|
|
|
|
match event.event_type.as_str() {
|
|
"message_start" => event
|
|
.message
|
|
.and_then(|m| m.usage)
|
|
.and_then(|u| u.input_tokens)
|
|
.map(StreamEvent::InputTokens),
|
|
|
|
"content_block_start" => {
|
|
let block = event.content_block?;
|
|
if block.block_type == "tool_use" {
|
|
Some(StreamEvent::ToolUseStart {
|
|
id: block.id.unwrap_or_default(),
|
|
name: block.name.unwrap_or_default(),
|
|
})
|
|
} else {
|
|
None
|
|
}
|
|
}
|
|
|
|
"content_block_delta" => {
|
|
let delta = event.delta?;
|
|
match delta.delta_type.as_deref() {
|
|
Some("text_delta") => delta.text.map(StreamEvent::TextDelta),
|
|
Some("input_json_delta") => delta.partial_json.map(StreamEvent::ToolUseInputDelta),
|
|
_ => None,
|
|
}
|
|
}
|
|
|
|
"content_block_stop" => Some(StreamEvent::ToolUseDone),
|
|
|
|
// usage lives at the top level of message_delta, not inside delta.
|
|
"message_delta" => event
|
|
.usage
|
|
.and_then(|u| u.output_tokens)
|
|
.map(StreamEvent::OutputTokens),
|
|
|
|
"message_stop" => Some(StreamEvent::Done),
|
|
|
|
// error, ping -- ignored.
|
|
_ => None,
|
|
}
|
|
}
|
|
|
|
// -- Tests --------------------------------------------------------------------
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::types::Role;

    /// A minimal but complete Anthropic SSE fixture for a plain-text turn.
    const SSE_FIXTURE: &str = concat!(
        "event: message_start\n",
        "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_123\",\"type\":\"message\",",
        "\"role\":\"assistant\",\"content\":[],\"model\":\"claude-opus-4-6\",",
        "\"stop_reason\":null,\"stop_sequence\":null,",
        "\"usage\":{\"input_tokens\":10,\"cache_creation_input_tokens\":0,",
        "\"cache_read_input_tokens\":0,\"output_tokens\":1}}}\n",
        "\n",
        "event: content_block_start\n",
        "data: {\"type\":\"content_block_start\",\"index\":0,",
        "\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n",
        "\n",
        "event: ping\n",
        "data: {\"type\":\"ping\"}\n",
        "\n",
        "event: content_block_delta\n",
        "data: {\"type\":\"content_block_delta\",\"index\":0,",
        "\"delta\":{\"type\":\"text_delta\",\"text\":\"Hello\"}}\n",
        "\n",
        "event: content_block_delta\n",
        "data: {\"type\":\"content_block_delta\",\"index\":0,",
        "\"delta\":{\"type\":\"text_delta\",\"text\":\", world!\"}}\n",
        "\n",
        "event: content_block_stop\n",
        "data: {\"type\":\"content_block_stop\",\"index\":0}\n",
        "\n",
        "event: message_delta\n",
        "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\",",
        "\"stop_sequence\":null},\"usage\":{\"output_tokens\":5}}\n",
        "\n",
        "event: message_stop\n",
        "data: {\"type\":\"message_stop\"}\n",
        "\n",
    );

    /// SSE fixture for a response that contains a tool-use block.
    ///
    /// Sequence: message_start -> content_block_start (tool_use) ->
    /// content_block_delta (input_json_delta, twice) -> content_block_stop ->
    /// message_delta -> message_stop
    const TOOL_USE_SSE_FIXTURE: &str = concat!(
        "event: message_start\n",
        "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_456\",\"type\":\"message\",",
        "\"role\":\"assistant\",\"content\":[],\"model\":\"claude-opus-4-6\",",
        "\"stop_reason\":null,\"stop_sequence\":null,",
        "\"usage\":{\"input_tokens\":25,\"cache_creation_input_tokens\":0,",
        "\"cache_read_input_tokens\":0,\"output_tokens\":1}}}\n",
        "\n",
        "event: content_block_start\n",
        "data: {\"type\":\"content_block_start\",\"index\":0,",
        "\"content_block\":{\"type\":\"tool_use\",\"id\":\"toolu_01ABC\",\"name\":\"read_file\",\"input\":{}}}\n",
        "\n",
        "event: content_block_delta\n",
        "data: {\"type\":\"content_block_delta\",\"index\":0,",
        "\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"path\\\": \\\"src/\"}}\n",
        "\n",
        "event: content_block_delta\n",
        "data: {\"type\":\"content_block_delta\",\"index\":0,",
        "\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"main.rs\\\"}\"}}\n",
        "\n",
        "event: content_block_stop\n",
        "data: {\"type\":\"content_block_stop\",\"index\":0}\n",
        "\n",
        "event: message_delta\n",
        "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\",",
        "\"stop_sequence\":null},\"usage\":{\"output_tokens\":15}}\n",
        "\n",
        "event: message_stop\n",
        "data: {\"type\":\"message_stop\"}\n",
        "\n",
    );

    /// Split `fixture` on blank lines, parse every non-empty chunk, and keep
    /// only the chunks that map to a [`StreamEvent`].
    fn parse_fixture(fixture: &str) -> Vec<StreamEvent> {
        let mut parsed = Vec::new();
        for raw in fixture.split("\n\n") {
            if raw.trim().is_empty() {
                continue;
            }
            if let Some(event) = parse_sse_event(raw) {
                parsed.push(event);
            }
        }
        parsed
    }

    #[test]
    fn test_parse_sse_events_from_fixture() {
        let events = parse_fixture(SSE_FIXTURE);

        // content_block_start (text) and ping -> None (filtered out)
        // content_block_stop -> ToolUseDone (always emitted; caller filters)
        assert_eq!(events.len(), 6);

        let mut seen = events.iter();
        assert!(matches!(seen.next(), Some(StreamEvent::InputTokens(10))));
        assert!(matches!(seen.next(), Some(StreamEvent::TextDelta(s)) if s == "Hello"));
        assert!(matches!(seen.next(), Some(StreamEvent::TextDelta(s)) if s == ", world!"));
        assert!(matches!(seen.next(), Some(StreamEvent::ToolUseDone)));
        assert!(matches!(seen.next(), Some(StreamEvent::OutputTokens(5))));
        assert!(matches!(seen.next(), Some(StreamEvent::Done)));
    }

    #[test]
    fn test_parse_message_stop_yields_done() {
        let parsed = parse_sse_event("event: message_stop\ndata: {\"type\":\"message_stop\"}\n");
        assert!(matches!(parsed, Some(StreamEvent::Done)));
    }

    #[test]
    fn test_parse_ping_yields_none() {
        let parsed = parse_sse_event("event: ping\ndata: {\"type\":\"ping\"}\n");
        assert!(parsed.is_none());
    }

    #[test]
    fn test_parse_content_block_start_yields_none() {
        let text_block_start = concat!(
            "event: content_block_start\n",
            "data: {\"type\":\"content_block_start\",\"index\":0,",
            "\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n",
        );
        assert!(parse_sse_event(text_block_start).is_none());
    }

    #[test]
    fn test_messages_serialize_to_anthropic_format() {
        let user_turn = ConversationMessage::text(Role::User, "Hello");
        let assistant_turn = ConversationMessage::text(Role::Assistant, "Hi there!");
        let messages = vec![user_turn, assistant_turn];

        let json = serde_json::json!({
            "model": "claude-opus-4-6",
            "max_tokens": 8192,
            "stream": true,
            "messages": messages,
        });

        let serialized = &json["messages"];
        assert_eq!(serialized[0]["role"], "user");
        assert_eq!(serialized[0]["content"], "Hello");
        assert_eq!(serialized[1]["role"], "assistant");
        assert_eq!(serialized[1]["content"], "Hi there!");
        assert_eq!(json["stream"], true);
        assert!(json["max_tokens"].as_u64().unwrap() > 0);
    }

    #[test]
    fn test_parse_tool_use_sse_fixture() {
        let events = parse_fixture(TOOL_USE_SSE_FIXTURE);
        assert_eq!(events.len(), 7);

        let mut seen = events.iter();
        assert!(matches!(seen.next(), Some(StreamEvent::InputTokens(25))));
        assert!(matches!(
            seen.next(),
            Some(StreamEvent::ToolUseStart { id, name })
                if id == "toolu_01ABC" && name == "read_file"
        ));
        assert!(matches!(
            seen.next(),
            Some(StreamEvent::ToolUseInputDelta(s)) if s == "{\"path\": \"src/"
        ));
        assert!(matches!(
            seen.next(),
            Some(StreamEvent::ToolUseInputDelta(s)) if s == "main.rs\"}"
        ));
        assert!(matches!(seen.next(), Some(StreamEvent::ToolUseDone)));
        assert!(matches!(seen.next(), Some(StreamEvent::OutputTokens(15))));
        assert!(matches!(seen.next(), Some(StreamEvent::Done)));
    }

    #[test]
    fn test_find_double_newline() {
        // (input, expected offset of the first `\n\n`)
        let cases: [(&[u8], Option<usize>); 4] = [
            (&b"abc\n\ndef"[..], Some(3)),
            (&b"abc\ndef"[..], None),
            (&b"\n\n"[..], Some(0)),
            (&b""[..], None),
        ];
        for (input, expected) in cases {
            assert_eq!(find_double_newline(input), expected);
        }
    }
}
|