Add tool use to the orchestrator (#4)

Add tool use without sandboxing.

Currently available tools are list dir, read file, write file and exec bash.

Reviewed-on: #4
Co-authored-by: Drew Galbraith <drew@tiramisu.one>
Co-committed-by: Drew Galbraith <drew@tiramisu.one>
This commit is contained in:
Drew 2026-03-02 03:00:13 +00:00 committed by Drew
parent 6b85ff3cb8
commit 797d7564b7
20 changed files with 1822 additions and 129 deletions

View file

@ -2,7 +2,7 @@ use futures::{SinkExt, Stream, StreamExt};
use reqwest::Client;
use serde::Deserialize;
use crate::core::types::{ConversationMessage, StreamEvent};
use crate::core::types::{ConversationMessage, StreamEvent, ToolDefinition};
use super::ModelProvider;
@ -64,15 +64,17 @@ impl ModelProvider for ClaudeProvider {
fn stream<'a>(
&'a self,
messages: &'a [ConversationMessage],
tools: &'a [ToolDefinition],
) -> impl Stream<Item = StreamEvent> + Send + 'a {
let (mut tx, rx) = futures::channel::mpsc::channel(32);
let client = self.client.clone();
let api_key = self.api_key.clone();
let model = self.model.clone();
let messages = messages.to_vec();
let tools = tools.to_vec();
tokio::spawn(async move {
run_stream(client, api_key, model, messages, &mut tx).await;
run_stream(client, api_key, model, messages, tools, &mut tx).await;
});
rx
@ -126,14 +128,18 @@ async fn run_stream(
api_key: String,
model: String,
messages: Vec<ConversationMessage>,
tools: Vec<ToolDefinition>,
tx: &mut futures::channel::mpsc::Sender<StreamEvent>,
) {
let body = serde_json::json!({
let mut body = serde_json::json!({
"model": model,
"max_tokens": 8192,
"stream": true,
"messages": messages,
});
if !tools.is_empty() {
body["tools"] = serde_json::to_value(&tools).unwrap_or_default();
}
let response = match client
.post("https://api.anthropic.com/v1/messages")
@ -223,6 +229,8 @@ fn find_double_newline(buf: &[u8]) -> Option<usize> {
struct SseEvent {
#[serde(rename = "type")]
event_type: String,
/// Present on `content_block_start` events; describes the new block.
content_block: Option<SseContentBlock>,
/// Present on `content_block_delta` events.
delta: Option<SseDelta>,
/// Present on `message_start` events; carries initial token usage.
@ -231,16 +239,29 @@ struct SseEvent {
usage: Option<SseUsage>,
}
/// The `content_block` object inside a `content_block_start` event.
///
/// Only the block's type, ID and name are captured here; any other fields of
/// the object (for example `input`) are not represented in this struct.
/// `parse_sse_event` turns a block whose type is `"tool_use"` into
/// `StreamEvent::ToolUseStart` and yields nothing for other block types.
#[derive(Deserialize, Debug)]
struct SseContentBlock {
/// Block kind, read from the JSON `type` field (e.g. `"tool_use"`).
#[serde(rename = "type")]
block_type: String,
/// Tool-use block ID; present when `block_type == "tool_use"`.
/// `parse_sse_event` substitutes an empty string when it is absent.
id: Option<String>,
/// Tool name; present when `block_type == "tool_use"`.
/// `parse_sse_event` substitutes an empty string when it is absent.
name: Option<String>,
}
/// The `delta` object inside a `content_block_delta` event.
///
/// `type` is `"text_delta"` for plain text chunks; other delta types
/// (e.g. `"input_json_delta"` for tool-use blocks) are not yet handled.
/// `type` is `"text_delta"` for plain text chunks, or `"input_json_delta"`
/// for streaming tool-use input JSON.
#[derive(Deserialize, Debug)]
struct SseDelta {
/// Delta kind, read from the JSON `type` field. `parse_sse_event` handles
/// `"text_delta"` and `"input_json_delta"`; any other value (or a missing
/// type) produces no event.
#[serde(rename = "type")]
delta_type: Option<String>,
/// The text chunk; present when `delta_type == "text_delta"`.
text: Option<String>,
/// Partial JSON for tool input; present when `delta_type == "input_json_delta"`.
partial_json: Option<String>,
}
/// The `message` object inside a `message_start` event.
@ -273,13 +294,20 @@ struct SseUsage {
///
/// # Mapping to [`StreamEvent`]
///
/// | API event type | JSON path | Emits |
/// |----------------------|------------------------------------|------------------------------|
/// | `message_start` | `.message.usage.input_tokens` | `InputTokens(n)` |
/// | `content_block_delta`| `.delta.type == "text_delta"` | `TextDelta(chunk)` |
/// | `message_delta` | `.usage.output_tokens` | `OutputTokens(n)` |
/// | `message_stop` | n/a | `Done` |
/// | everything else | n/a | `None` (caller skips) |
/// | API event type | JSON path | Emits |
/// |------------------------|------------------------------------|------------------------------|
/// | `message_start` | `.message.usage.input_tokens` | `InputTokens(n)` |
/// | `content_block_start` | `.content_block.type == "tool_use"`| `ToolUseStart { id, name }` |
/// | `content_block_delta` | `.delta.type == "text_delta"` | `TextDelta(chunk)` |
/// | `content_block_delta` | `.delta.type == "input_json_delta"`| `ToolUseInputDelta(json)` |
/// | `content_block_stop` | n/a | `ToolUseDone` |
/// | `message_delta` | `.usage.output_tokens` | `OutputTokens(n)` |
/// | `message_stop` | n/a | `Done` |
/// | everything else | n/a | `None` (caller skips) |
///
/// Note: `content_block_stop` always emits `ToolUseDone`. The caller is
/// responsible for tracking whether a tool-use block is active and ignoring
/// `ToolUseDone` when it follows a text block.
///
/// [sse-fields]: https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream
fn parse_sse_event(event_str: &str) -> Option<StreamEvent> {
@ -297,15 +325,29 @@ fn parse_sse_event(event_str: &str) -> Option<StreamEvent> {
.and_then(|u| u.input_tokens)
.map(StreamEvent::InputTokens),
"content_block_delta" => {
let delta = event.delta?;
if delta.delta_type.as_deref() == Some("text_delta") {
delta.text.map(StreamEvent::TextDelta)
"content_block_start" => {
let block = event.content_block?;
if block.block_type == "tool_use" {
Some(StreamEvent::ToolUseStart {
id: block.id.unwrap_or_default(),
name: block.name.unwrap_or_default(),
})
} else {
None
}
}
"content_block_delta" => {
let delta = event.delta?;
match delta.delta_type.as_deref() {
Some("text_delta") => delta.text.map(StreamEvent::TextDelta),
Some("input_json_delta") => delta.partial_json.map(StreamEvent::ToolUseInputDelta),
_ => None,
}
}
"content_block_stop" => Some(StreamEvent::ToolUseDone),
// usage lives at the top level of message_delta, not inside delta.
"message_delta" => event
.usage
@ -314,8 +356,7 @@ fn parse_sse_event(event_str: &str) -> Option<StreamEvent> {
"message_stop" => Some(StreamEvent::Done),
// error, ping, content_block_start, content_block_stop -- ignored or
// handled by the caller.
// error, ping -- ignored.
_ => None,
}
}
@ -371,13 +412,15 @@ mod tests {
.filter_map(parse_sse_event)
.collect();
// content_block_start, ping, content_block_stop -> None (filtered out)
assert_eq!(events.len(), 5);
// content_block_start (text) and ping -> None (filtered out)
// content_block_stop -> ToolUseDone (always emitted; caller filters)
assert_eq!(events.len(), 6);
assert!(matches!(events[0], StreamEvent::InputTokens(10)));
assert!(matches!(&events[1], StreamEvent::TextDelta(s) if s == "Hello"));
assert!(matches!(&events[2], StreamEvent::TextDelta(s) if s == ", world!"));
assert!(matches!(events[3], StreamEvent::OutputTokens(5)));
assert!(matches!(events[4], StreamEvent::Done));
assert!(matches!(events[3], StreamEvent::ToolUseDone));
assert!(matches!(events[4], StreamEvent::OutputTokens(5)));
assert!(matches!(events[5], StreamEvent::Done));
}
#[test]
@ -408,14 +451,8 @@ mod tests {
#[test]
fn test_messages_serialize_to_anthropic_format() {
let messages = vec![
ConversationMessage {
role: Role::User,
content: "Hello".to_string(),
},
ConversationMessage {
role: Role::Assistant,
content: "Hi there!".to_string(),
},
ConversationMessage::text(Role::User, "Hello"),
ConversationMessage::text(Role::Assistant, "Hi there!"),
];
let json = serde_json::json!({
@ -433,6 +470,65 @@ mod tests {
assert!(json["max_tokens"].as_u64().unwrap() > 0);
}
/// Recorded SSE stream for a response made up of a single tool-use block.
///
/// Events, in order: `message_start`, `content_block_start` (`tool_use`),
/// two `content_block_delta` events carrying `input_json_delta` chunks,
/// `content_block_stop`, `message_delta`, `message_stop`.
///
/// Every event is an `event:` line, a `data:` line and a terminating blank
/// line. The JSON payloads are written as raw strings so they can be read
/// without escape noise; the newlines are kept as ordinary string pieces.
const TOOL_USE_SSE_FIXTURE: &str = concat!(
    // message_start: carries the initial usage (25 input tokens).
    "event: message_start\n",
    r#"data: {"type":"message_start","message":{"id":"msg_456","type":"message","#,
    r#""role":"assistant","content":[],"model":"claude-opus-4-6","#,
    r#""stop_reason":null,"stop_sequence":null,"#,
    r#""usage":{"input_tokens":25,"cache_creation_input_tokens":0,"#,
    r#""cache_read_input_tokens":0,"output_tokens":1}}}"#,
    "\n\n",
    // content_block_start: opens the tool_use block for `read_file`.
    "event: content_block_start\n",
    r#"data: {"type":"content_block_start","index":0,"#,
    r#""content_block":{"type":"tool_use","id":"toolu_01ABC","name":"read_file","input":{}}}"#,
    "\n\n",
    // content_block_delta #1: first half of the tool input JSON.
    "event: content_block_delta\n",
    r#"data: {"type":"content_block_delta","index":0,"#,
    r#""delta":{"type":"input_json_delta","partial_json":"{\"path\": \"src/"}}"#,
    "\n\n",
    // content_block_delta #2: second half of the tool input JSON.
    "event: content_block_delta\n",
    r#"data: {"type":"content_block_delta","index":0,"#,
    r#""delta":{"type":"input_json_delta","partial_json":"main.rs\"}"}}"#,
    "\n\n",
    // content_block_stop: closes the tool_use block.
    "event: content_block_stop\n",
    r#"data: {"type":"content_block_stop","index":0}"#,
    "\n\n",
    // message_delta: stop reason plus output token usage.
    "event: message_delta\n",
    r#"data: {"type":"message_delta","delta":{"stop_reason":"tool_use","#,
    r#""stop_sequence":null},"usage":{"output_tokens":15}}"#,
    "\n\n",
    // message_stop: end of the stream.
    "event: message_stop\n",
    r#"data: {"type":"message_stop"}"#,
    "\n\n",
);
#[test]
fn test_parse_tool_use_sse_fixture() {
    // Events in the fixture are separated by blank lines. Skip empty chunks
    // and keep only what `parse_sse_event` maps to a `StreamEvent`.
    let mut parsed: Vec<StreamEvent> = Vec::new();
    for chunk in TOOL_USE_SSE_FIXTURE.split("\n\n") {
        if chunk.trim().is_empty() {
            continue;
        }
        if let Some(event) = parse_sse_event(chunk) {
            parsed.push(event);
        }
    }

    // All seven events in the fixture map to a `StreamEvent`.
    assert_eq!(parsed.len(), 7);

    // Initial usage from `message_start`.
    assert!(matches!(parsed[0], StreamEvent::InputTokens(25)));

    // The tool-use block is announced with its ID and tool name.
    assert!(matches!(
        &parsed[1],
        StreamEvent::ToolUseStart { id, name } if id == "toolu_01ABC" && name == "read_file"
    ));

    // The tool input arrives as two raw JSON fragments, passed through as-is.
    assert!(matches!(
        &parsed[2],
        StreamEvent::ToolUseInputDelta(s) if s == "{\"path\": \"src/"
    ));
    assert!(matches!(
        &parsed[3],
        StreamEvent::ToolUseInputDelta(s) if s == "main.rs\"}"
    ));

    // Block end, final usage, then end of stream.
    assert!(matches!(parsed[4], StreamEvent::ToolUseDone));
    assert!(matches!(parsed[5], StreamEvent::OutputTokens(15)));
    assert!(matches!(parsed[6], StreamEvent::Done));
}
#[test]
fn test_find_double_newline() {
assert_eq!(find_double_newline(b"abc\n\ndef"), Some(3));

View file

@ -4,16 +4,18 @@ pub use claude::ClaudeProvider;
use futures::Stream;
use crate::core::types::{ConversationMessage, StreamEvent};
use crate::core::types::{ConversationMessage, StreamEvent, ToolDefinition};
/// Trait for model providers that can stream conversation responses.
///
/// Implementors take a conversation history and return a stream of [`StreamEvent`]s.
/// The trait is provider-agnostic -- no Claude-specific types appear here.
pub trait ModelProvider: Send + Sync {
/// Stream a response from the model given the conversation history.
/// Stream a response from the model given the conversation history and
/// available tool definitions. Pass an empty slice if no tools are available.
fn stream<'a>(
&'a self,
messages: &'a [ConversationMessage],
tools: &'a [ToolDefinition],
) -> impl Stream<Item = StreamEvent> + Send + 'a;
}