Tool Use.
This commit is contained in:
parent
6b85ff3cb8
commit
0c1c928498
20 changed files with 1822 additions and 129 deletions
|
|
@ -2,7 +2,7 @@ use futures::{SinkExt, Stream, StreamExt};
|
|||
use reqwest::Client;
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::core::types::{ConversationMessage, StreamEvent};
|
||||
use crate::core::types::{ConversationMessage, StreamEvent, ToolDefinition};
|
||||
|
||||
use super::ModelProvider;
|
||||
|
||||
|
|
@ -64,15 +64,17 @@ impl ModelProvider for ClaudeProvider {
|
|||
fn stream<'a>(
|
||||
&'a self,
|
||||
messages: &'a [ConversationMessage],
|
||||
tools: &'a [ToolDefinition],
|
||||
) -> impl Stream<Item = StreamEvent> + Send + 'a {
|
||||
let (mut tx, rx) = futures::channel::mpsc::channel(32);
|
||||
let client = self.client.clone();
|
||||
let api_key = self.api_key.clone();
|
||||
let model = self.model.clone();
|
||||
let messages = messages.to_vec();
|
||||
let tools = tools.to_vec();
|
||||
|
||||
tokio::spawn(async move {
|
||||
run_stream(client, api_key, model, messages, &mut tx).await;
|
||||
run_stream(client, api_key, model, messages, tools, &mut tx).await;
|
||||
});
|
||||
|
||||
rx
|
||||
|
|
@ -126,14 +128,18 @@ async fn run_stream(
|
|||
api_key: String,
|
||||
model: String,
|
||||
messages: Vec<ConversationMessage>,
|
||||
tools: Vec<ToolDefinition>,
|
||||
tx: &mut futures::channel::mpsc::Sender<StreamEvent>,
|
||||
) {
|
||||
let body = serde_json::json!({
|
||||
let mut body = serde_json::json!({
|
||||
"model": model,
|
||||
"max_tokens": 8192,
|
||||
"stream": true,
|
||||
"messages": messages,
|
||||
});
|
||||
if !tools.is_empty() {
|
||||
body["tools"] = serde_json::to_value(&tools).unwrap_or_default();
|
||||
}
|
||||
|
||||
let response = match client
|
||||
.post("https://api.anthropic.com/v1/messages")
|
||||
|
|
@ -223,6 +229,8 @@ fn find_double_newline(buf: &[u8]) -> Option<usize> {
|
|||
struct SseEvent {
|
||||
#[serde(rename = "type")]
|
||||
event_type: String,
|
||||
/// Present on `content_block_start` events; describes the new block.
|
||||
content_block: Option<SseContentBlock>,
|
||||
/// Present on `content_block_delta` events.
|
||||
delta: Option<SseDelta>,
|
||||
/// Present on `message_start` events; carries initial token usage.
|
||||
|
|
@ -231,16 +239,29 @@ struct SseEvent {
|
|||
usage: Option<SseUsage>,
|
||||
}
|
||||
|
||||
/// The `content_block` object inside a `content_block_start` event.
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct SseContentBlock {
|
||||
#[serde(rename = "type")]
|
||||
block_type: String,
|
||||
/// Tool-use block ID; present when `block_type == "tool_use"`.
|
||||
id: Option<String>,
|
||||
/// Tool name; present when `block_type == "tool_use"`.
|
||||
name: Option<String>,
|
||||
}
|
||||
|
||||
/// The `delta` object inside a `content_block_delta` event.
|
||||
///
|
||||
/// `type` is `"text_delta"` for plain text chunks; other delta types
|
||||
/// (e.g. `"input_json_delta"` for tool-use blocks) are not yet handled.
|
||||
/// `type` is `"text_delta"` for plain text chunks, or `"input_json_delta"`
|
||||
/// for streaming tool-use input JSON.
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct SseDelta {
|
||||
#[serde(rename = "type")]
|
||||
delta_type: Option<String>,
|
||||
/// The text chunk; present when `delta_type == "text_delta"`.
|
||||
text: Option<String>,
|
||||
/// Partial JSON for tool input; present when `delta_type == "input_json_delta"`.
|
||||
partial_json: Option<String>,
|
||||
}
|
||||
|
||||
/// The `message` object inside a `message_start` event.
|
||||
|
|
@ -273,13 +294,20 @@ struct SseUsage {
|
|||
///
|
||||
/// # Mapping to [`StreamEvent`]
|
||||
///
|
||||
/// | API event type | JSON path | Emits |
|
||||
/// |----------------------|------------------------------------|------------------------------|
|
||||
/// | `message_start` | `.message.usage.input_tokens` | `InputTokens(n)` |
|
||||
/// | `content_block_delta`| `.delta.type == "text_delta"` | `TextDelta(chunk)` |
|
||||
/// | `message_delta` | `.usage.output_tokens` | `OutputTokens(n)` |
|
||||
/// | `message_stop` | n/a | `Done` |
|
||||
/// | everything else | n/a | `None` (caller skips) |
|
||||
/// | API event type | JSON path | Emits |
|
||||
/// |------------------------|------------------------------------|------------------------------|
|
||||
/// | `message_start` | `.message.usage.input_tokens` | `InputTokens(n)` |
|
||||
/// | `content_block_start` | `.content_block.type == "tool_use"`| `ToolUseStart { id, name }` |
|
||||
/// | `content_block_delta` | `.delta.type == "text_delta"` | `TextDelta(chunk)` |
|
||||
/// | `content_block_delta` | `.delta.type == "input_json_delta"`| `ToolUseInputDelta(json)` |
|
||||
/// | `content_block_stop` | n/a | `ToolUseDone` |
|
||||
/// | `message_delta` | `.usage.output_tokens` | `OutputTokens(n)` |
|
||||
/// | `message_stop` | n/a | `Done` |
|
||||
/// | everything else | n/a | `None` (caller skips) |
|
||||
///
|
||||
/// Note: `content_block_stop` always emits `ToolUseDone`. The caller is
|
||||
/// responsible for tracking whether a tool-use block is active and ignoring
|
||||
/// `ToolUseDone` when it follows a text block.
|
||||
///
|
||||
/// [sse-fields]: https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream
|
||||
fn parse_sse_event(event_str: &str) -> Option<StreamEvent> {
|
||||
|
|
@ -297,15 +325,29 @@ fn parse_sse_event(event_str: &str) -> Option<StreamEvent> {
|
|||
.and_then(|u| u.input_tokens)
|
||||
.map(StreamEvent::InputTokens),
|
||||
|
||||
"content_block_delta" => {
|
||||
let delta = event.delta?;
|
||||
if delta.delta_type.as_deref() == Some("text_delta") {
|
||||
delta.text.map(StreamEvent::TextDelta)
|
||||
"content_block_start" => {
|
||||
let block = event.content_block?;
|
||||
if block.block_type == "tool_use" {
|
||||
Some(StreamEvent::ToolUseStart {
|
||||
id: block.id.unwrap_or_default(),
|
||||
name: block.name.unwrap_or_default(),
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
"content_block_delta" => {
|
||||
let delta = event.delta?;
|
||||
match delta.delta_type.as_deref() {
|
||||
Some("text_delta") => delta.text.map(StreamEvent::TextDelta),
|
||||
Some("input_json_delta") => delta.partial_json.map(StreamEvent::ToolUseInputDelta),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
"content_block_stop" => Some(StreamEvent::ToolUseDone),
|
||||
|
||||
// usage lives at the top level of message_delta, not inside delta.
|
||||
"message_delta" => event
|
||||
.usage
|
||||
|
|
@ -314,8 +356,7 @@ fn parse_sse_event(event_str: &str) -> Option<StreamEvent> {
|
|||
|
||||
"message_stop" => Some(StreamEvent::Done),
|
||||
|
||||
// error, ping, content_block_start, content_block_stop -- ignored or
|
||||
// handled by the caller.
|
||||
// error, ping -- ignored.
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
|
@ -371,13 +412,15 @@ mod tests {
|
|||
.filter_map(parse_sse_event)
|
||||
.collect();
|
||||
|
||||
// content_block_start, ping, content_block_stop -> None (filtered out)
|
||||
assert_eq!(events.len(), 5);
|
||||
// content_block_start (text) and ping -> None (filtered out)
|
||||
// content_block_stop -> ToolUseDone (always emitted; caller filters)
|
||||
assert_eq!(events.len(), 6);
|
||||
assert!(matches!(events[0], StreamEvent::InputTokens(10)));
|
||||
assert!(matches!(&events[1], StreamEvent::TextDelta(s) if s == "Hello"));
|
||||
assert!(matches!(&events[2], StreamEvent::TextDelta(s) if s == ", world!"));
|
||||
assert!(matches!(events[3], StreamEvent::OutputTokens(5)));
|
||||
assert!(matches!(events[4], StreamEvent::Done));
|
||||
assert!(matches!(events[3], StreamEvent::ToolUseDone));
|
||||
assert!(matches!(events[4], StreamEvent::OutputTokens(5)));
|
||||
assert!(matches!(events[5], StreamEvent::Done));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -408,14 +451,8 @@ mod tests {
|
|||
#[test]
|
||||
fn test_messages_serialize_to_anthropic_format() {
|
||||
let messages = vec![
|
||||
ConversationMessage {
|
||||
role: Role::User,
|
||||
content: "Hello".to_string(),
|
||||
},
|
||||
ConversationMessage {
|
||||
role: Role::Assistant,
|
||||
content: "Hi there!".to_string(),
|
||||
},
|
||||
ConversationMessage::text(Role::User, "Hello"),
|
||||
ConversationMessage::text(Role::Assistant, "Hi there!"),
|
||||
];
|
||||
|
||||
let json = serde_json::json!({
|
||||
|
|
@ -433,6 +470,65 @@ mod tests {
|
|||
assert!(json["max_tokens"].as_u64().unwrap() > 0);
|
||||
}
|
||||
|
||||
/// SSE fixture for a response that contains a tool-use block.
|
||||
///
|
||||
/// Sequence: message_start -> content_block_start (tool_use) ->
|
||||
/// content_block_delta (input_json_delta) -> content_block_stop ->
|
||||
/// message_delta -> message_stop
|
||||
const TOOL_USE_SSE_FIXTURE: &str = concat!(
|
||||
"event: message_start\n",
|
||||
"data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_456\",\"type\":\"message\",",
|
||||
"\"role\":\"assistant\",\"content\":[],\"model\":\"claude-opus-4-6\",",
|
||||
"\"stop_reason\":null,\"stop_sequence\":null,",
|
||||
"\"usage\":{\"input_tokens\":25,\"cache_creation_input_tokens\":0,",
|
||||
"\"cache_read_input_tokens\":0,\"output_tokens\":1}}}\n",
|
||||
"\n",
|
||||
"event: content_block_start\n",
|
||||
"data: {\"type\":\"content_block_start\",\"index\":0,",
|
||||
"\"content_block\":{\"type\":\"tool_use\",\"id\":\"toolu_01ABC\",\"name\":\"read_file\",\"input\":{}}}\n",
|
||||
"\n",
|
||||
"event: content_block_delta\n",
|
||||
"data: {\"type\":\"content_block_delta\",\"index\":0,",
|
||||
"\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"path\\\": \\\"src/\"}}\n",
|
||||
"\n",
|
||||
"event: content_block_delta\n",
|
||||
"data: {\"type\":\"content_block_delta\",\"index\":0,",
|
||||
"\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"main.rs\\\"}\"}}\n",
|
||||
"\n",
|
||||
"event: content_block_stop\n",
|
||||
"data: {\"type\":\"content_block_stop\",\"index\":0}\n",
|
||||
"\n",
|
||||
"event: message_delta\n",
|
||||
"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\",",
|
||||
"\"stop_sequence\":null},\"usage\":{\"output_tokens\":15}}\n",
|
||||
"\n",
|
||||
"event: message_stop\n",
|
||||
"data: {\"type\":\"message_stop\"}\n",
|
||||
"\n",
|
||||
);
|
||||
|
||||
#[test]
|
||||
fn test_parse_tool_use_sse_fixture() {
|
||||
let events: Vec<StreamEvent> = TOOL_USE_SSE_FIXTURE
|
||||
.split("\n\n")
|
||||
.filter(|s| !s.trim().is_empty())
|
||||
.filter_map(parse_sse_event)
|
||||
.collect();
|
||||
|
||||
assert_eq!(events.len(), 7);
|
||||
assert!(matches!(events[0], StreamEvent::InputTokens(25)));
|
||||
assert!(
|
||||
matches!(&events[1], StreamEvent::ToolUseStart { id, name } if id == "toolu_01ABC" && name == "read_file")
|
||||
);
|
||||
assert!(
|
||||
matches!(&events[2], StreamEvent::ToolUseInputDelta(s) if s == "{\"path\": \"src/")
|
||||
);
|
||||
assert!(matches!(&events[3], StreamEvent::ToolUseInputDelta(s) if s == "main.rs\"}"));
|
||||
assert!(matches!(events[4], StreamEvent::ToolUseDone));
|
||||
assert!(matches!(events[5], StreamEvent::OutputTokens(15)));
|
||||
assert!(matches!(events[6], StreamEvent::Done));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_find_double_newline() {
|
||||
assert_eq!(find_double_newline(b"abc\n\ndef"), Some(3));
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue