Tool Use.
This commit is contained in:
parent
6b85ff3cb8
commit
0c1c928498
20 changed files with 1822 additions and 129 deletions
32
Cargo.lock
generated
32
Cargo.lock
generated
|
|
@ -23,6 +23,17 @@ version = "1.0.102"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c"
|
checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "async-trait"
|
||||||
|
version = "0.1.89"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn 2.0.117",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "atomic"
|
name = "atomic"
|
||||||
version = "0.6.1"
|
version = "0.6.1"
|
||||||
|
|
@ -447,6 +458,12 @@ dependencies = [
|
||||||
"regex",
|
"regex",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "fastrand"
|
||||||
|
version = "2.3.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "filedescriptor"
|
name = "filedescriptor"
|
||||||
version = "0.8.3"
|
version = "0.8.3"
|
||||||
|
|
@ -2024,12 +2041,14 @@ name = "skate"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
|
"async-trait",
|
||||||
"crossterm",
|
"crossterm",
|
||||||
"futures",
|
"futures",
|
||||||
"ratatui",
|
"ratatui",
|
||||||
"reqwest",
|
"reqwest",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
"tempfile",
|
||||||
"thiserror 2.0.18",
|
"thiserror 2.0.18",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tracing",
|
"tracing",
|
||||||
|
|
@ -2166,6 +2185,19 @@ dependencies = [
|
||||||
"libc",
|
"libc",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tempfile"
|
||||||
|
version = "3.26.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "82a72c767771b47409d2345987fda8628641887d5466101319899796367354a0"
|
||||||
|
dependencies = [
|
||||||
|
"fastrand",
|
||||||
|
"getrandom 0.4.1",
|
||||||
|
"once_cell",
|
||||||
|
"rustix",
|
||||||
|
"windows-sys 0.61.2",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "terminfo"
|
name = "terminfo"
|
||||||
version = "0.9.0"
|
version = "0.9.0"
|
||||||
|
|
|
||||||
|
|
@ -15,3 +15,7 @@ tracing = "0.1"
|
||||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||||
reqwest = { version = "0.13", features = ["stream", "json"] }
|
reqwest = { version = "0.13", features = ["stream", "json"] }
|
||||||
futures = "0.3"
|
futures = "0.3"
|
||||||
|
async-trait = "0.1"
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
tempfile = "3.26.0"
|
||||||
|
|
|
||||||
76
PLAN.md
76
PLAN.md
|
|
@ -1,11 +1,77 @@
|
||||||
# Implementation Plan
|
# Implementation Plan
|
||||||
|
|
||||||
## Phase 3: Tool Execution
|
## Phase 3: Tool Execution
|
||||||
- `Tool` trait, `ToolRegistry`, core tools (`read_file`, `write_file`, `shell_exec`)
|
|
||||||
- Tool definitions in API requests, parse tool-use responses
|
### Step 3.1: Enrich the content model
|
||||||
- Approval gate: core -> TUI pending event -> user approve/deny -> result back
|
- Replace `ConversationMessage { role, content: String }` with content-block model
|
||||||
- Working directory confinement + path validation (no Landlock yet)
|
- Define `ContentBlock` enum: `Text(String)`, `ToolUse { id, name, input: Value }`, `ToolResult { tool_use_id, content: String, is_error: bool }`
|
||||||
- **Done when:** Claude can read, modify files, and run commands with user approval
|
- Change `ConversationMessage.content` from `String` to `Vec<ContentBlock>`
|
||||||
|
- Add `ConversationMessage::text(role, s)` helper to keep existing call sites clean
|
||||||
|
- Update serialization, orchestrator, tests, TUI display
|
||||||
|
- **Files:** `src/core/types.rs`, `src/core/history.rs`
|
||||||
|
- **Done when:** `cargo test` passes with new model; all existing tests updated
|
||||||
|
|
||||||
|
### Step 3.2: Send tool definitions in API requests
|
||||||
|
- Add `ToolDefinition { name, description, input_schema: Value }` (provider-agnostic)
|
||||||
|
- Extend `ModelProvider::stream` to accept `&[ToolDefinition]`
|
||||||
|
- Include `"tools"` array in Claude provider request body
|
||||||
|
- **Files:** `src/provider/mod.rs`, `src/provider/claude.rs`
|
||||||
|
- **Done when:** API responses contain `tool_use` content blocks in raw SSE stream
|
||||||
|
|
||||||
|
### Step 3.3: Parse tool-use blocks from SSE stream
|
||||||
|
- Add `StreamEvent::ToolUseStart { id, name }`, `ToolUseInputDelta(String)`, `ToolUseDone`
|
||||||
|
- Handle `content_block_start` (type "tool_use"), `content_block_delta` (type "input_json_delta"), `content_block_stop` for tool blocks
|
||||||
|
- Track current block type state in SSE parser
|
||||||
|
- **Files:** `src/provider/claude.rs`, `src/core/types.rs`
|
||||||
|
- **Done when:** Unit test with recorded tool-use SSE fixture asserts correct StreamEvent sequence
|
||||||
|
|
||||||
|
### Step 3.4: Orchestrator accumulates tool-use blocks
|
||||||
|
- Accumulate `ToolUseInputDelta` fragments into JSON buffer per tool-use id
|
||||||
|
- On `ToolUseDone`, parse JSON into `ContentBlock::ToolUse`
|
||||||
|
- After `StreamEvent::Done`, if assistant message contains ToolUse blocks, enter tool-execution phase
|
||||||
|
- **Files:** `src/core/orchestrator.rs`
|
||||||
|
- **Done when:** Unit test with mock provider emitting tool-use events produces correct ContentBlocks
|
||||||
|
|
||||||
|
### Step 3.5: Tool trait, registry, and core tools
|
||||||
|
- `Tool` trait: `name()`, `description()`, `input_schema() -> Value`, `execute(input: Value, working_dir: &Path) -> Result<ToolOutput>`
|
||||||
|
- `ToolOutput { content: String, is_error: bool }`
|
||||||
|
- `ToolRegistry`: stores tools, provides `get(name)` and `definitions() -> Vec<ToolDefinition>`
|
||||||
|
- Risk level: `AutoApprove` (reads), `RequiresApproval` (writes/shell)
|
||||||
|
- Implement: `read_file` (auto), `list_directory` (auto), `write_file` (approval), `shell_exec` (approval)
|
||||||
|
- Path validation: `canonicalize` + `starts_with` check, reject paths outside working dir (no Landlock yet)
|
||||||
|
- **Files:** New `src/tools/` module: `mod.rs`, `read_file.rs`, `write_file.rs`, `list_directory.rs`, `shell_exec.rs`
|
||||||
|
- **Done when:** Unit tests pass for each tool in temp dirs; path traversal rejected
|
||||||
|
|
||||||
|
### Step 3.6: Approval gate (TUI <-> core)
|
||||||
|
- New `UIEvent::ToolApprovalRequest { tool_use_id, tool_name, input_summary }`
|
||||||
|
- New `UserAction::ToolApprovalResponse { tool_use_id, approved: bool }`
|
||||||
|
- Orchestrator: check risk level -> auto-approve or send approval request and await response
|
||||||
|
- Denied tools return `ToolResult { is_error: true }` with denial message
|
||||||
|
- TUI: render approval prompt overlay with y/n keybindings
|
||||||
|
- **Files:** `src/core/types.rs`, `src/core/orchestrator.rs`, `src/tui/events.rs`, `src/tui/input.rs`, `src/tui/render.rs`
|
||||||
|
- **Done when:** Integration test: mock provider + mock TUI channel verifies approval flow
|
||||||
|
|
||||||
|
### Step 3.7: Tool results fed back to the model
|
||||||
|
- After executing tool calls: append assistant message (with ToolUse blocks) to history, append user message with ToolResult blocks, re-call provider
|
||||||
|
- Loop: model may respond with more tool calls or text
|
||||||
|
- Cap at max iterations (25) to prevent a runaway tool-use loop
|
||||||
|
- **Files:** `src/core/orchestrator.rs`
|
||||||
|
- **Done when:** Integration test: mock provider returns tool-use then text; orchestrator makes two calls. Max-iteration cap tested.
|
||||||
|
|
||||||
|
### Step 3.8: TUI display for tool activity
|
||||||
|
- New `UIEvent::ToolExecuting { tool_name, input_summary }`, `UIEvent::ToolResult { tool_name, output_summary, is_error }`
|
||||||
|
- Render tool calls as distinct visual blocks in conversation view
|
||||||
|
- Render tool results inline (truncated if long)
|
||||||
|
- **Files:** `src/tui/render.rs`, `src/tui/events.rs`
|
||||||
|
- **Done when:** Visual check with `cargo run`; TestBackend test for tool block rendering
|
||||||
|
|
||||||
|
### Phase 3 verification (end-to-end)
|
||||||
|
1. `cargo test` -- all tests pass
|
||||||
|
2. `cargo clippy -- -D warnings` -- zero warnings
|
||||||
|
3. `cargo run -- --project-dir .` -- ask Claude to read a file, approve, see contents
|
||||||
|
4. Ask Claude to write a file -- approve, verify written
|
||||||
|
5. Ask Claude to run a shell command -- approve, verify output
|
||||||
|
6. Deny an approval -- Claude gets denial and responds gracefully
|
||||||
|
|
||||||
## Phase 4: Sandboxing
|
## Phase 4: Sandboxing
|
||||||
- Landlock: read-only system, read-write project dir, network blocked
|
- Landlock: read-only system, read-write project dir, network blocked
|
||||||
|
|
|
||||||
2
TODO.md
2
TODO.md
|
|
@ -1,5 +1,7 @@
|
||||||
# Cleanups
|
# Cleanups
|
||||||
|
|
||||||
|
- Parallelize tool-use execution in `run_turn` -- requires refactoring `Orchestrator` to use `&self` + interior mutability (`Arc<Mutex<...>>` around `event_tx`, `action_rx`, `history`) so multiple futures can borrow self simultaneously via `futures::future::join_all`.
|
||||||
|
|
||||||
- Move keyboard/event reads in the TUI to a separate thread or async/io loop
|
- Move keyboard/event reads in the TUI to a separate thread or async/io loop
|
||||||
- Keep UI and orchestrator in sync (i.e. messages display out of order if you queue up many.)
|
- Keep UI and orchestrator in sync (i.e. messages display out of order if you queue up many.)
|
||||||
- `update_scroll` auto-follows in Insert mode, yanking viewport to bottom on mode switch. Only auto-follow when new content arrives (in `drain_ui_events`), not every frame.
|
- `update_scroll` auto-follows in Insert mode, yanking viewport to bottom on mode switch. Only auto-follow when new content arrives (in `drain_ui_events`), not every frame.
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,7 @@ use tokio::sync::mpsc;
|
||||||
use crate::core::orchestrator::Orchestrator;
|
use crate::core::orchestrator::Orchestrator;
|
||||||
use crate::core::types::{UIEvent, UserAction};
|
use crate::core::types::{UIEvent, UserAction};
|
||||||
use crate::provider::ClaudeProvider;
|
use crate::provider::ClaudeProvider;
|
||||||
|
use crate::tools::ToolRegistry;
|
||||||
|
|
||||||
/// Model ID sent on every request.
|
/// Model ID sent on every request.
|
||||||
///
|
///
|
||||||
|
|
@ -68,8 +69,15 @@ pub async fn run(project_dir: &Path) -> anyhow::Result<()> {
|
||||||
let (action_tx, action_rx) = mpsc::channel::<UserAction>(CHANNEL_CAP);
|
let (action_tx, action_rx) = mpsc::channel::<UserAction>(CHANNEL_CAP);
|
||||||
let (event_tx, event_rx) = mpsc::channel::<UIEvent>(CHANNEL_CAP);
|
let (event_tx, event_rx) = mpsc::channel::<UIEvent>(CHANNEL_CAP);
|
||||||
|
|
||||||
// -- Orchestrator (background task) -------------------------------------------
|
// -- Tools & Orchestrator (background task) ------------------------------------
|
||||||
let orch = Orchestrator::new(provider, action_rx, event_tx);
|
let tool_registry = ToolRegistry::default_tools();
|
||||||
|
let orch = Orchestrator::new(
|
||||||
|
provider,
|
||||||
|
tool_registry,
|
||||||
|
project_dir.to_path_buf(),
|
||||||
|
action_rx,
|
||||||
|
event_tx,
|
||||||
|
);
|
||||||
tokio::spawn(orch.run());
|
tokio::spawn(orch.run());
|
||||||
|
|
||||||
// -- TUI (foreground task) ----------------------------------------------------
|
// -- TUI (foreground task) ----------------------------------------------------
|
||||||
|
|
|
||||||
|
|
@ -61,30 +61,21 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn push_and_read_roundtrip() {
|
fn push_and_read_roundtrip() {
|
||||||
let mut history = ConversationHistory::new();
|
let mut history = ConversationHistory::new();
|
||||||
history.push(ConversationMessage {
|
history.push(ConversationMessage::text(Role::User, "hello"));
|
||||||
role: Role::User,
|
history.push(ConversationMessage::text(Role::Assistant, "hi there"));
|
||||||
content: "hello".to_string(),
|
|
||||||
});
|
|
||||||
history.push(ConversationMessage {
|
|
||||||
role: Role::Assistant,
|
|
||||||
content: "hi there".to_string(),
|
|
||||||
});
|
|
||||||
|
|
||||||
let msgs = history.messages();
|
let msgs = history.messages();
|
||||||
assert_eq!(msgs.len(), 2);
|
assert_eq!(msgs.len(), 2);
|
||||||
assert_eq!(msgs[0].role, Role::User);
|
assert_eq!(msgs[0].role, Role::User);
|
||||||
assert_eq!(msgs[0].content, "hello");
|
assert_eq!(msgs[0].text_content(), "hello");
|
||||||
assert_eq!(msgs[1].role, Role::Assistant);
|
assert_eq!(msgs[1].role, Role::Assistant);
|
||||||
assert_eq!(msgs[1].content, "hi there");
|
assert_eq!(msgs[1].text_content(), "hi there");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn clear_empties_history() {
|
fn clear_empties_history() {
|
||||||
let mut history = ConversationHistory::new();
|
let mut history = ConversationHistory::new();
|
||||||
history.push(ConversationMessage {
|
history.push(ConversationMessage::text(Role::User, "hello"));
|
||||||
role: Role::User,
|
|
||||||
content: "hello".to_string(),
|
|
||||||
});
|
|
||||||
history.clear();
|
history.clear();
|
||||||
assert!(history.messages().is_empty());
|
assert!(history.messages().is_empty());
|
||||||
}
|
}
|
||||||
|
|
@ -93,13 +84,10 @@ mod tests {
|
||||||
fn messages_preserves_insertion_order() {
|
fn messages_preserves_insertion_order() {
|
||||||
let mut history = ConversationHistory::new();
|
let mut history = ConversationHistory::new();
|
||||||
for i in 0u32..5 {
|
for i in 0u32..5 {
|
||||||
history.push(ConversationMessage {
|
history.push(ConversationMessage::text(Role::User, format!("msg {i}")));
|
||||||
role: Role::User,
|
|
||||||
content: format!("msg {i}"),
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
for (i, msg) in history.messages().iter().enumerate() {
|
for (i, msg) in history.messages().iter().enumerate() {
|
||||||
assert_eq!(msg.content, format!("msg {i}"));
|
assert_eq!(msg.text_content(), format!("msg {i}"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,53 @@
|
||||||
use futures::StreamExt;
|
use futures::StreamExt;
|
||||||
use tokio::sync::mpsc;
|
use tokio::sync::mpsc;
|
||||||
use tracing::debug;
|
use tracing::{debug, warn};
|
||||||
|
|
||||||
use crate::core::history::ConversationHistory;
|
use crate::core::history::ConversationHistory;
|
||||||
use crate::core::types::{ConversationMessage, Role, StreamEvent, UIEvent, UserAction};
|
use crate::core::types::{
|
||||||
|
ContentBlock, ConversationMessage, Role, StreamEvent, ToolDefinition, UIEvent, UserAction,
|
||||||
|
};
|
||||||
use crate::provider::ModelProvider;
|
use crate::provider::ModelProvider;
|
||||||
|
use crate::tools::{RiskLevel, ToolOutput, ToolRegistry};
|
||||||
|
|
||||||
|
/// In-progress state for one tool-use block that is still being streamed.
///
/// An instance is created when `ToolUseStart` arrives, grows as
/// `ToolUseInputDelta` fragments are appended, and is finally converted into a
/// `ContentBlock` (via `TryFrom<ActiveToolUse>`) when `ToolUseDone` is seen.
struct ActiveToolUse {
    /// Identifier of the tool-use block, as passed to `new`.
    id: String,
    /// Name of the requested tool, as passed to `new`.
    name: String,
    /// Raw JSON input text, concatenated from the streamed fragments.
    json_buf: String,
}

impl ActiveToolUse {
    /// Begin accumulating a tool-use block identified by `id` and `name`,
    /// starting from an empty input buffer.
    fn new(id: String, name: String) -> Self {
        let json_buf = String::new();
        Self { id, name, json_buf }
    }

    /// Add one JSON input fragment (from a `ToolUseInputDelta` event) to the
    /// end of the buffer.
    fn append(&mut self, chunk: &str) {
        self.json_buf += chunk;
    }
}
|
||||||
|
|
||||||
|
impl TryFrom<ActiveToolUse> for ContentBlock {
|
||||||
|
type Error = serde_json::Error;
|
||||||
|
|
||||||
|
/// Parse the accumulated JSON buffer and produce a `ContentBlock::ToolUse`.
|
||||||
|
fn try_from(t: ActiveToolUse) -> Result<Self, Self::Error> {
|
||||||
|
let input = serde_json::from_str(&t.json_buf)?;
|
||||||
|
Ok(ContentBlock::ToolUse {
|
||||||
|
id: t.id,
|
||||||
|
name: t.name,
|
||||||
|
input,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Drives the conversation loop between the TUI frontend and the model provider.
|
/// Drives the conversation loop between the TUI frontend and the model provider.
|
||||||
///
|
///
|
||||||
|
|
@ -38,28 +81,321 @@ use crate::provider::ModelProvider;
|
||||||
/// OutputTokens -> log at debug level
|
/// OutputTokens -> log at debug level
|
||||||
/// 3. Quit -> return
|
/// 3. Quit -> return
|
||||||
/// ```
|
/// ```
|
||||||
|
/// The result of consuming one provider stream (one assistant turn).
|
||||||
|
enum StreamResult {
|
||||||
|
/// The stream completed successfully with these content blocks.
|
||||||
|
Done(Vec<ContentBlock>),
|
||||||
|
/// The stream produced an error.
|
||||||
|
Error(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Upper bound on provider round-trips in the tool-use loop for a single user
/// message.
const MAX_TOOL_ITERATIONS: usize = 25;

/// Shorten `s` to at most `max_len` Unicode scalar values.
///
/// When `s` holds more than `max_len` characters, the first `max_len` of them
/// are kept and `"..."` is appended; otherwise `s` is returned unchanged as an
/// owned `String`. The cut position comes from `char_indices`, so it always
/// falls on a character boundary and slicing cannot panic on multibyte text.
fn truncate(s: &str, max_len: usize) -> String {
    if let Some((cut, _)) = s.char_indices().nth(max_len) {
        // `cut` is the byte offset of the first character to drop.
        let mut shortened = String::with_capacity(cut + 3);
        shortened.push_str(&s[..cut]);
        shortened.push_str("...");
        shortened
    } else {
        s.to_owned()
    }
}
|
||||||
|
|
||||||
pub struct Orchestrator<P> {
|
pub struct Orchestrator<P> {
|
||||||
history: ConversationHistory,
|
history: ConversationHistory,
|
||||||
provider: P,
|
provider: P,
|
||||||
|
tool_registry: ToolRegistry,
|
||||||
|
working_dir: std::path::PathBuf,
|
||||||
action_rx: mpsc::Receiver<UserAction>,
|
action_rx: mpsc::Receiver<UserAction>,
|
||||||
event_tx: mpsc::Sender<UIEvent>,
|
event_tx: mpsc::Sender<UIEvent>,
|
||||||
|
/// Messages typed by the user while an approval prompt is open. They are
|
||||||
|
/// queued here and replayed as new turns once the current turn completes.
|
||||||
|
queued_messages: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<P: ModelProvider> Orchestrator<P> {
|
impl<P: ModelProvider> Orchestrator<P> {
|
||||||
/// Construct an orchestrator using the given provider and channel endpoints.
|
/// Construct an orchestrator using the given provider and channel endpoints.
|
||||||
pub fn new(
|
pub fn new(
|
||||||
provider: P,
|
provider: P,
|
||||||
|
tool_registry: ToolRegistry,
|
||||||
|
working_dir: std::path::PathBuf,
|
||||||
action_rx: mpsc::Receiver<UserAction>,
|
action_rx: mpsc::Receiver<UserAction>,
|
||||||
event_tx: mpsc::Sender<UIEvent>,
|
event_tx: mpsc::Sender<UIEvent>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
history: ConversationHistory::new(),
|
history: ConversationHistory::new(),
|
||||||
provider,
|
provider,
|
||||||
|
tool_registry,
|
||||||
|
working_dir,
|
||||||
action_rx,
|
action_rx,
|
||||||
event_tx,
|
event_tx,
|
||||||
|
queued_messages: Vec::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Consume the provider's stream for one turn, returning the content blocks
|
||||||
|
/// produced by the assistant (text and/or tool-use) or an error.
|
||||||
|
///
|
||||||
|
/// Tool-use input JSON is accumulated from `ToolUseInputDelta` fragments and
|
||||||
|
/// parsed into `ContentBlock::ToolUse` on `ToolUseDone`.
|
||||||
|
async fn consume_stream(
|
||||||
|
&self,
|
||||||
|
messages: &[ConversationMessage],
|
||||||
|
tools: &[ToolDefinition],
|
||||||
|
) -> StreamResult {
|
||||||
|
let mut blocks: Vec<ContentBlock> = Vec::new();
|
||||||
|
let mut text_buf = String::new();
|
||||||
|
let mut active_tool: Option<ActiveToolUse> = None;
|
||||||
|
|
||||||
|
let mut stream = Box::pin(self.provider.stream(messages, tools));
|
||||||
|
|
||||||
|
while let Some(event) = stream.next().await {
|
||||||
|
match event {
|
||||||
|
StreamEvent::TextDelta(chunk) => {
|
||||||
|
text_buf.push_str(&chunk);
|
||||||
|
let _ = self.event_tx.send(UIEvent::StreamDelta(chunk)).await;
|
||||||
|
}
|
||||||
|
StreamEvent::ToolUseStart { id, name } => {
|
||||||
|
// Flush any accumulated text before starting tool block.
|
||||||
|
if !text_buf.is_empty() {
|
||||||
|
blocks.push(ContentBlock::Text {
|
||||||
|
text: std::mem::take(&mut text_buf),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
active_tool = Some(ActiveToolUse::new(id, name));
|
||||||
|
}
|
||||||
|
StreamEvent::ToolUseInputDelta(chunk) => {
|
||||||
|
if let Some(t) = &mut active_tool {
|
||||||
|
t.append(&chunk);
|
||||||
|
} else {
|
||||||
|
warn!(
|
||||||
|
"received ToolUseInputDelta outside of an active tool-use block -- ignoring"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
StreamEvent::ToolUseDone => {
|
||||||
|
if let Some(t) = active_tool.take() {
|
||||||
|
match ContentBlock::try_from(t) {
|
||||||
|
Ok(block) => blocks.push(block),
|
||||||
|
Err(e) => {
|
||||||
|
return StreamResult::Error(format!(
|
||||||
|
"failed to parse tool input JSON: {e}"
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
warn!(
|
||||||
|
"received ToolUseDone outside of an active tool-use block -- ignoring"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
StreamEvent::Done => {
|
||||||
|
// Flush trailing text.
|
||||||
|
if !text_buf.is_empty() {
|
||||||
|
blocks.push(ContentBlock::Text { text: text_buf });
|
||||||
|
}
|
||||||
|
return StreamResult::Done(blocks);
|
||||||
|
}
|
||||||
|
StreamEvent::Error(msg) => {
|
||||||
|
return StreamResult::Error(msg);
|
||||||
|
}
|
||||||
|
StreamEvent::InputTokens(n) => {
|
||||||
|
debug!(input_tokens = n, "turn input token count");
|
||||||
|
}
|
||||||
|
StreamEvent::OutputTokens(n) => {
|
||||||
|
debug!(output_tokens = n, "turn output token count");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stream ended without Done -- treat as error.
|
||||||
|
StreamResult::Error("stream ended unexpectedly".to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Execute one complete turn: stream from provider, execute tools if needed,
|
||||||
|
/// loop back until the model produces a text-only response or we hit the
|
||||||
|
/// iteration limit.
|
||||||
|
async fn run_turn(&mut self) {
|
||||||
|
let tool_defs = self.tool_registry.definitions();
|
||||||
|
|
||||||
|
for _ in 0..MAX_TOOL_ITERATIONS {
|
||||||
|
let messages = self.history.messages().to_vec();
|
||||||
|
let result = self.consume_stream(&messages, &tool_defs).await;
|
||||||
|
|
||||||
|
match result {
|
||||||
|
StreamResult::Error(msg) => {
|
||||||
|
let _ = self.event_tx.send(UIEvent::Error(msg)).await;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
StreamResult::Done(blocks) => {
|
||||||
|
let has_tool_use = blocks
|
||||||
|
.iter()
|
||||||
|
.any(|b| matches!(b, ContentBlock::ToolUse { .. }));
|
||||||
|
|
||||||
|
// Append the assistant message to history.
|
||||||
|
self.history.push(ConversationMessage {
|
||||||
|
role: Role::Assistant,
|
||||||
|
content: blocks.clone(),
|
||||||
|
});
|
||||||
|
|
||||||
|
if !has_tool_use {
|
||||||
|
let _ = self.event_tx.send(UIEvent::TurnComplete).await;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: execute tool-use blocks in parallel -- requires
|
||||||
|
// refactoring Orchestrator to use &self + interior
|
||||||
|
// mutability (Arc<Mutex<...>> around event_tx, action_rx,
|
||||||
|
// history) so that multiple futures can borrow self
|
||||||
|
// simultaneously via futures::future::join_all.
|
||||||
|
// Execute each tool-use block and collect results.
|
||||||
|
let mut tool_results: Vec<ContentBlock> = Vec::new();
|
||||||
|
|
||||||
|
for block in &blocks {
|
||||||
|
if let ContentBlock::ToolUse { id, name, input } = block {
|
||||||
|
let result = self.execute_tool_with_approval(id, name, input).await;
|
||||||
|
tool_results.push(ContentBlock::ToolResult {
|
||||||
|
tool_use_id: id.clone(),
|
||||||
|
content: result.content,
|
||||||
|
is_error: result.is_error,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append tool results as a user message and loop.
|
||||||
|
self.history.push(ConversationMessage {
|
||||||
|
role: Role::User,
|
||||||
|
content: tool_results,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
warn!("tool-use loop reached max iterations ({MAX_TOOL_ITERATIONS})");
|
||||||
|
let _ = self
|
||||||
|
.event_tx
|
||||||
|
.send(UIEvent::Error(
|
||||||
|
"tool-use loop reached maximum iterations".to_string(),
|
||||||
|
))
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Execute a single tool, handling approval if needed.
|
||||||
|
///
|
||||||
|
/// For auto-approve tools, executes immediately. For tools requiring
|
||||||
|
/// approval, sends a request to the TUI and waits for a response.
|
||||||
|
async fn execute_tool_with_approval(
|
||||||
|
&mut self,
|
||||||
|
tool_use_id: &str,
|
||||||
|
tool_name: &str,
|
||||||
|
input: &serde_json::Value,
|
||||||
|
) -> ToolOutput {
|
||||||
|
// Extract tool info upfront to avoid holding a borrow on self across
|
||||||
|
// the mutable wait_for_approval call.
|
||||||
|
let risk = match self.tool_registry.get(tool_name) {
|
||||||
|
Some(t) => t.risk_level(),
|
||||||
|
None => {
|
||||||
|
return ToolOutput {
|
||||||
|
content: format!("unknown tool: {tool_name}"),
|
||||||
|
is_error: true,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let input_summary = serde_json::to_string(input).unwrap_or_default();
|
||||||
|
|
||||||
|
// Check approval.
|
||||||
|
let approved = match risk {
|
||||||
|
RiskLevel::AutoApprove => {
|
||||||
|
let _ = self
|
||||||
|
.event_tx
|
||||||
|
.send(UIEvent::ToolExecuting {
|
||||||
|
tool_name: tool_name.to_string(),
|
||||||
|
input_summary: input_summary.clone(),
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
true
|
||||||
|
}
|
||||||
|
RiskLevel::RequiresApproval => {
|
||||||
|
let _ = self
|
||||||
|
.event_tx
|
||||||
|
.send(UIEvent::ToolApprovalRequest {
|
||||||
|
tool_use_id: tool_use_id.to_string(),
|
||||||
|
tool_name: tool_name.to_string(),
|
||||||
|
input_summary: input_summary.clone(),
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
|
||||||
|
// Wait for approval response from TUI.
|
||||||
|
self.wait_for_approval(tool_use_id).await
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if !approved {
|
||||||
|
return ToolOutput {
|
||||||
|
content: "tool execution denied by user".to_string(),
|
||||||
|
is_error: true,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Re-fetch tool for execution (borrow was released above).
|
||||||
|
let tool = self.tool_registry.get(tool_name).unwrap();
|
||||||
|
match tool.execute(input, &self.working_dir).await {
|
||||||
|
Ok(output) => {
|
||||||
|
let _ = self
|
||||||
|
.event_tx
|
||||||
|
.send(UIEvent::ToolResult {
|
||||||
|
tool_name: tool_name.to_string(),
|
||||||
|
output_summary: truncate(&output.content, 200),
|
||||||
|
is_error: output.is_error,
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
output
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
let msg = e.to_string();
|
||||||
|
let _ = self
|
||||||
|
.event_tx
|
||||||
|
.send(UIEvent::ToolResult {
|
||||||
|
tool_name: tool_name.to_string(),
|
||||||
|
output_summary: msg.clone(),
|
||||||
|
is_error: true,
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
ToolOutput {
|
||||||
|
content: msg,
|
||||||
|
is_error: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Wait for a `ToolApprovalResponse` matching `tool_use_id` from the TUI.
|
||||||
|
///
|
||||||
|
/// Any `SendMessage` actions that arrive during the wait are pushed onto
|
||||||
|
/// `self.queued_messages` rather than discarded. They are replayed as new
|
||||||
|
/// turns once the current tool-use loop completes (see `run()`).
|
||||||
|
///
|
||||||
|
/// Returns `true` if approved, `false` if denied or the channel closes.
|
||||||
|
async fn wait_for_approval(&mut self, tool_use_id: &str) -> bool {
|
||||||
|
while let Some(action) = self.action_rx.recv().await {
|
||||||
|
match action {
|
||||||
|
UserAction::ToolApprovalResponse {
|
||||||
|
tool_use_id: id,
|
||||||
|
approved,
|
||||||
|
} if id == tool_use_id => {
|
||||||
|
return approved;
|
||||||
|
}
|
||||||
|
UserAction::Quit => return false,
|
||||||
|
UserAction::SendMessage(text) => {
|
||||||
|
self.queued_messages.push(text);
|
||||||
|
}
|
||||||
|
_ => {} // discard stale approvals / ClearHistory during wait
|
||||||
|
}
|
||||||
|
}
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
/// Run the orchestrator until the user quits or the `action_rx` channel closes.
|
/// Run the orchestrator until the user quits or the `action_rx` channel closes.
|
||||||
pub async fn run(mut self) {
|
pub async fn run(mut self) {
|
||||||
while let Some(action) = self.action_rx.recv().await {
|
while let Some(action) = self.action_rx.recv().await {
|
||||||
|
|
@ -69,62 +405,24 @@ impl<P: ModelProvider> Orchestrator<P> {
|
||||||
self.history.clear();
|
self.history.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Approval responses are handled inline during tool execution,
|
||||||
|
// not in the main action loop. If one arrives here it's stale.
|
||||||
|
UserAction::ToolApprovalResponse { .. } => {}
|
||||||
|
|
||||||
UserAction::SendMessage(text) => {
|
UserAction::SendMessage(text) => {
|
||||||
// Push the user message before snapshotting, so providers
|
self.history
|
||||||
// see the full conversation including the new message.
|
.push(ConversationMessage::text(Role::User, text));
|
||||||
self.history.push(ConversationMessage {
|
self.run_turn().await;
|
||||||
role: Role::User,
|
|
||||||
content: text,
|
|
||||||
});
|
|
||||||
|
|
||||||
// Snapshot history into an owned Vec so the stream does not
|
// Drain any messages queued while an approval prompt was
|
||||||
// borrow from `self.history` -- this lets us mutably update
|
// open. Each queued message is a full turn in sequence.
|
||||||
// `self.history` once the stream loop finishes.
|
while !self.queued_messages.is_empty() {
|
||||||
let messages: Vec<ConversationMessage> = self.history.messages().to_vec();
|
let queued = std::mem::take(&mut self.queued_messages);
|
||||||
|
for msg in queued {
|
||||||
let mut accumulated = String::new();
|
self.history
|
||||||
// Capture terminal stream state outside the loop so we can
|
.push(ConversationMessage::text(Role::User, msg));
|
||||||
// act on it after `stream` is dropped.
|
self.run_turn().await;
|
||||||
let mut turn_done = false;
|
|
||||||
let mut turn_error: Option<String> = None;
|
|
||||||
|
|
||||||
{
|
|
||||||
let mut stream = Box::pin(self.provider.stream(&messages));
|
|
||||||
|
|
||||||
while let Some(event) = stream.next().await {
|
|
||||||
match event {
|
|
||||||
StreamEvent::TextDelta(chunk) => {
|
|
||||||
accumulated.push_str(&chunk);
|
|
||||||
let _ = self.event_tx.send(UIEvent::StreamDelta(chunk)).await;
|
|
||||||
}
|
}
|
||||||
StreamEvent::Done => {
|
|
||||||
turn_done = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
StreamEvent::Error(msg) => {
|
|
||||||
turn_error = Some(msg);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
StreamEvent::InputTokens(n) => {
|
|
||||||
debug!(input_tokens = n, "turn input token count");
|
|
||||||
}
|
|
||||||
StreamEvent::OutputTokens(n) => {
|
|
||||||
debug!(output_tokens = n, "turn output token count");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// `stream` is dropped here, releasing the borrow on
|
|
||||||
// `self.provider` and `messages`.
|
|
||||||
}
|
|
||||||
|
|
||||||
if turn_done {
|
|
||||||
self.history.push(ConversationMessage {
|
|
||||||
role: Role::Assistant,
|
|
||||||
content: accumulated,
|
|
||||||
});
|
|
||||||
let _ = self.event_tx.send(UIEvent::TurnComplete).await;
|
|
||||||
} else if let Some(msg) = turn_error {
|
|
||||||
let _ = self.event_tx.send(UIEvent::Error(msg)).await;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -135,6 +433,8 @@ impl<P: ModelProvider> Orchestrator<P> {
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use crate::core::types::ToolDefinition;
|
||||||
|
use crate::tools::ToolRegistry;
|
||||||
use futures::Stream;
|
use futures::Stream;
|
||||||
use tokio::sync::mpsc;
|
use tokio::sync::mpsc;
|
||||||
|
|
||||||
|
|
@ -155,11 +455,27 @@ mod tests {
|
||||||
fn stream<'a>(
|
fn stream<'a>(
|
||||||
&'a self,
|
&'a self,
|
||||||
_messages: &'a [ConversationMessage],
|
_messages: &'a [ConversationMessage],
|
||||||
|
_tools: &'a [ToolDefinition],
|
||||||
) -> impl Stream<Item = StreamEvent> + Send + 'a {
|
) -> impl Stream<Item = StreamEvent> + Send + 'a {
|
||||||
futures::stream::iter(self.events.clone())
|
futures::stream::iter(self.events.clone())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create an Orchestrator with no tools for testing text-only flows.
|
||||||
|
fn test_orchestrator<P: ModelProvider>(
|
||||||
|
provider: P,
|
||||||
|
action_rx: mpsc::Receiver<UserAction>,
|
||||||
|
event_tx: mpsc::Sender<UIEvent>,
|
||||||
|
) -> Orchestrator<P> {
|
||||||
|
Orchestrator::new(
|
||||||
|
provider,
|
||||||
|
ToolRegistry::empty(),
|
||||||
|
std::path::PathBuf::from("/tmp"),
|
||||||
|
action_rx,
|
||||||
|
event_tx,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
/// Collect all UIEvents that arrive within one orchestrator turn, stopping
|
/// Collect all UIEvents that arrive within one orchestrator turn, stopping
|
||||||
/// when the channel is drained after a `TurnComplete` or `Error`.
|
/// when the channel is drained after a `TurnComplete` or `Error`.
|
||||||
async fn collect_events(rx: &mut mpsc::Receiver<UIEvent>) -> Vec<UIEvent> {
|
async fn collect_events(rx: &mut mpsc::Receiver<UIEvent>) -> Vec<UIEvent> {
|
||||||
|
|
@ -195,7 +511,7 @@ mod tests {
|
||||||
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
||||||
let (event_tx, mut event_rx) = mpsc::channel::<UIEvent>(16);
|
let (event_tx, mut event_rx) = mpsc::channel::<UIEvent>(16);
|
||||||
|
|
||||||
let orch = Orchestrator::new(provider, action_rx, event_tx);
|
let orch = test_orchestrator(provider, action_rx, event_tx);
|
||||||
let handle = tokio::spawn(orch.run());
|
let handle = tokio::spawn(orch.run());
|
||||||
|
|
||||||
action_tx
|
action_tx
|
||||||
|
|
@ -233,7 +549,7 @@ mod tests {
|
||||||
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
||||||
let (event_tx, mut event_rx) = mpsc::channel::<UIEvent>(16);
|
let (event_tx, mut event_rx) = mpsc::channel::<UIEvent>(16);
|
||||||
|
|
||||||
let orch = Orchestrator::new(provider, action_rx, event_tx);
|
let orch = test_orchestrator(provider, action_rx, event_tx);
|
||||||
let handle = tokio::spawn(orch.run());
|
let handle = tokio::spawn(orch.run());
|
||||||
|
|
||||||
action_tx
|
action_tx
|
||||||
|
|
@ -264,6 +580,7 @@ mod tests {
|
||||||
fn stream<'a>(
|
fn stream<'a>(
|
||||||
&'a self,
|
&'a self,
|
||||||
_messages: &'a [ConversationMessage],
|
_messages: &'a [ConversationMessage],
|
||||||
|
_tools: &'a [ToolDefinition],
|
||||||
) -> impl Stream<Item = StreamEvent> + Send + 'a {
|
) -> impl Stream<Item = StreamEvent> + Send + 'a {
|
||||||
panic!("stream() must not be called after Quit");
|
panic!("stream() must not be called after Quit");
|
||||||
#[allow(unreachable_code)]
|
#[allow(unreachable_code)]
|
||||||
|
|
@ -274,7 +591,7 @@ mod tests {
|
||||||
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
||||||
let (event_tx, _event_rx) = mpsc::channel::<UIEvent>(8);
|
let (event_tx, _event_rx) = mpsc::channel::<UIEvent>(8);
|
||||||
|
|
||||||
let orch = Orchestrator::new(NeverCalledProvider, action_rx, event_tx);
|
let orch = test_orchestrator(NeverCalledProvider, action_rx, event_tx);
|
||||||
let handle = tokio::spawn(orch.run());
|
let handle = tokio::spawn(orch.run());
|
||||||
|
|
||||||
action_tx.send(UserAction::Quit).await.unwrap();
|
action_tx.send(UserAction::Quit).await.unwrap();
|
||||||
|
|
@ -311,6 +628,7 @@ mod tests {
|
||||||
fn stream<'a>(
|
fn stream<'a>(
|
||||||
&'a self,
|
&'a self,
|
||||||
_messages: &'a [ConversationMessage],
|
_messages: &'a [ConversationMessage],
|
||||||
|
_tools: &'a [ToolDefinition],
|
||||||
) -> impl Stream<Item = StreamEvent> + Send + 'a {
|
) -> impl Stream<Item = StreamEvent> + Send + 'a {
|
||||||
let events = self.turns.lock().unwrap().pop_front().unwrap_or_default();
|
let events = self.turns.lock().unwrap().pop_front().unwrap_or_default();
|
||||||
futures::stream::iter(events)
|
futures::stream::iter(events)
|
||||||
|
|
@ -326,7 +644,7 @@ mod tests {
|
||||||
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
||||||
let (event_tx, mut event_rx) = mpsc::channel::<UIEvent>(32);
|
let (event_tx, mut event_rx) = mpsc::channel::<UIEvent>(32);
|
||||||
|
|
||||||
let orch = Orchestrator::new(provider, action_rx, event_tx);
|
let orch = test_orchestrator(provider, action_rx, event_tx);
|
||||||
let handle = tokio::spawn(orch.run());
|
let handle = tokio::spawn(orch.run());
|
||||||
|
|
||||||
// First turn.
|
// First turn.
|
||||||
|
|
@ -350,4 +668,263 @@ mod tests {
|
||||||
action_tx.send(UserAction::Quit).await.unwrap();
|
action_tx.send(UserAction::Quit).await.unwrap();
|
||||||
handle.await.unwrap();
|
handle.await.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -- tool-use accumulation ----------------------------------------------------
|
||||||
|
|
||||||
|
/// When the provider emits tool-use events, the orchestrator executes the
|
||||||
|
/// tool (auto-approve since read_file is AutoApprove), feeds the result back,
|
||||||
|
/// and the provider's second call returns text.
|
||||||
|
#[tokio::test]
|
||||||
|
async fn tool_use_loop_executes_and_feeds_back() {
|
||||||
|
use std::collections::VecDeque;
|
||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
|
struct MultiCallMock {
|
||||||
|
turns: Arc<Mutex<VecDeque<Vec<StreamEvent>>>>,
|
||||||
|
}
|
||||||
|
impl ModelProvider for MultiCallMock {
|
||||||
|
fn stream<'a>(
|
||||||
|
&'a self,
|
||||||
|
_messages: &'a [ConversationMessage],
|
||||||
|
_tools: &'a [ToolDefinition],
|
||||||
|
) -> impl Stream<Item = StreamEvent> + Send + 'a {
|
||||||
|
let events = self.turns.lock().unwrap().pop_front().unwrap_or_default();
|
||||||
|
futures::stream::iter(events)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let turns = Arc::new(Mutex::new(VecDeque::from([
|
||||||
|
// First call: tool use
|
||||||
|
vec![
|
||||||
|
StreamEvent::TextDelta("Let me read that.".to_string()),
|
||||||
|
StreamEvent::ToolUseDone, // spurious stop for text block
|
||||||
|
StreamEvent::ToolUseStart {
|
||||||
|
id: "toolu_01".to_string(),
|
||||||
|
name: "read_file".to_string(),
|
||||||
|
},
|
||||||
|
StreamEvent::ToolUseInputDelta("{\"path\":\"Cargo.toml\"}".to_string()),
|
||||||
|
StreamEvent::ToolUseDone,
|
||||||
|
StreamEvent::Done,
|
||||||
|
],
|
||||||
|
// Second call: text response after tool result
|
||||||
|
vec![
|
||||||
|
StreamEvent::TextDelta("Here's the file.".to_string()),
|
||||||
|
StreamEvent::Done,
|
||||||
|
],
|
||||||
|
])));
|
||||||
|
|
||||||
|
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
||||||
|
let (event_tx, mut event_rx) = mpsc::channel::<UIEvent>(32);
|
||||||
|
|
||||||
|
// Use a real ToolRegistry so read_file works.
|
||||||
|
let dir = tempfile::TempDir::new().unwrap();
|
||||||
|
std::fs::write(dir.path().join("Cargo.toml"), "[package]\nname = \"test\"").unwrap();
|
||||||
|
let orch = Orchestrator::new(
|
||||||
|
MultiCallMock { turns },
|
||||||
|
ToolRegistry::default_tools(),
|
||||||
|
dir.path().to_path_buf(),
|
||||||
|
action_rx,
|
||||||
|
event_tx,
|
||||||
|
);
|
||||||
|
let handle = tokio::spawn(orch.run());
|
||||||
|
|
||||||
|
action_tx
|
||||||
|
.send(UserAction::SendMessage("read Cargo.toml".to_string()))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
|
||||||
|
|
||||||
|
let events = collect_events(&mut event_rx).await;
|
||||||
|
// Should see text deltas from both calls, tool events, and TurnComplete
|
||||||
|
assert!(
|
||||||
|
events
|
||||||
|
.iter()
|
||||||
|
.any(|e| matches!(e, UIEvent::StreamDelta(s) if s == "Let me read that.")),
|
||||||
|
"expected first text delta"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
events
|
||||||
|
.iter()
|
||||||
|
.any(|e| matches!(e, UIEvent::StreamDelta(s) if s == "Here's the file.")),
|
||||||
|
"expected second text delta"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
matches!(events.last(), Some(UIEvent::TurnComplete)),
|
||||||
|
"expected TurnComplete, got: {events:?}"
|
||||||
|
);
|
||||||
|
|
||||||
|
action_tx.send(UserAction::Quit).await.unwrap();
|
||||||
|
handle.await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
// -- truncate -----------------------------------------------------------------
|
||||||
|
|
||||||
|
/// `truncate` must not panic when the requested length lands inside a
/// multibyte character.
///
/// "hello " occupies bytes 0..6 and the globe emoji bytes 6..10, so
/// `max_len = 8` points into the middle of the emoji; a plain byte slice at
/// that index would panic.
#[test]
fn truncate_multibyte_does_not_panic() {
    let input = "hello \u{1F30D} world"; // globe emoji is 4 bytes

    // Reaching the assertions at all proves the call did not panic.
    let result = truncate(input, 8);
    assert!(
        result.ends_with("..."),
        "expected truncation suffix: {result}"
    );

    // NOTE(review): `prefix` is already a `&str`, so this check can never
    // fail; consider asserting on the expected prefix once the exact
    // contract of `truncate` is confirmed.
    let prefix = result.trim_end_matches("...");
    assert!(std::str::from_utf8(prefix.as_bytes()).is_ok());
}
|
||||||
|
|
||||||
|
// -- JSON parse error ---------------------------------------------------------
|
||||||
|
|
||||||
|
/// When the provider emits malformed tool-input JSON, the orchestrator must
|
||||||
|
/// surface a descriptive UIEvent::Error rather than silently using Null.
|
||||||
|
#[tokio::test]
|
||||||
|
async fn malformed_tool_json_surfaces_error() {
|
||||||
|
let provider = MockProvider::new(vec![
|
||||||
|
StreamEvent::ToolUseStart {
|
||||||
|
id: "toolu_bad".to_string(),
|
||||||
|
name: "read_file".to_string(),
|
||||||
|
},
|
||||||
|
StreamEvent::ToolUseInputDelta("not valid json{{".to_string()),
|
||||||
|
StreamEvent::ToolUseDone,
|
||||||
|
StreamEvent::Done,
|
||||||
|
]);
|
||||||
|
|
||||||
|
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
||||||
|
let (event_tx, mut event_rx) = mpsc::channel::<UIEvent>(16);
|
||||||
|
|
||||||
|
let orch = test_orchestrator(provider, action_rx, event_tx);
|
||||||
|
let handle = tokio::spawn(orch.run());
|
||||||
|
|
||||||
|
action_tx
|
||||||
|
.send(UserAction::SendMessage("go".to_string()))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||||
|
|
||||||
|
let events = collect_events(&mut event_rx).await;
|
||||||
|
assert!(
|
||||||
|
events
|
||||||
|
.iter()
|
||||||
|
.any(|e| matches!(e, UIEvent::Error(msg) if msg.contains("tool input JSON"))),
|
||||||
|
"expected JSON parse error event, got: {events:?}"
|
||||||
|
);
|
||||||
|
|
||||||
|
action_tx.send(UserAction::Quit).await.unwrap();
|
||||||
|
handle.await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
// -- queued messages ----------------------------------------------------------
|
||||||
|
|
||||||
|
/// A SendMessage sent while an approval prompt is open must be processed
|
||||||
|
/// after the current turn completes, not silently dropped.
|
||||||
|
///
|
||||||
|
/// This test uses a RequiresApproval tool (write_file) so that the
|
||||||
|
/// orchestrator blocks in wait_for_approval. While blocked, we send a second
|
||||||
|
/// message. After approving (denied here for simplicity), the queued message
|
||||||
|
/// should still be processed.
|
||||||
|
#[tokio::test]
|
||||||
|
async fn send_message_during_approval_is_queued_and_processed() {
|
||||||
|
use std::collections::VecDeque;
|
||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
|
struct MultiCallMock {
|
||||||
|
turns: Arc<Mutex<VecDeque<Vec<StreamEvent>>>>,
|
||||||
|
}
|
||||||
|
impl ModelProvider for MultiCallMock {
|
||||||
|
fn stream<'a>(
|
||||||
|
&'a self,
|
||||||
|
_messages: &'a [ConversationMessage],
|
||||||
|
_tools: &'a [ToolDefinition],
|
||||||
|
) -> impl Stream<Item = StreamEvent> + Send + 'a {
|
||||||
|
let events = self.turns.lock().unwrap().pop_front().unwrap_or_default();
|
||||||
|
futures::stream::iter(events)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let turns = Arc::new(Mutex::new(VecDeque::from([
|
||||||
|
// Turn 1: requests write_file (RequiresApproval)
|
||||||
|
vec![
|
||||||
|
StreamEvent::ToolUseStart {
|
||||||
|
id: "toolu_w".to_string(),
|
||||||
|
name: "write_file".to_string(),
|
||||||
|
},
|
||||||
|
StreamEvent::ToolUseInputDelta(
|
||||||
|
"{\"path\":\"x.txt\",\"content\":\"hi\"}".to_string(),
|
||||||
|
),
|
||||||
|
StreamEvent::ToolUseDone,
|
||||||
|
StreamEvent::Done,
|
||||||
|
],
|
||||||
|
// Turn 1 second iteration after tool result (denied)
|
||||||
|
vec![
|
||||||
|
StreamEvent::TextDelta("ok denied".to_string()),
|
||||||
|
StreamEvent::Done,
|
||||||
|
],
|
||||||
|
// Turn 2: the queued message
|
||||||
|
vec![
|
||||||
|
StreamEvent::TextDelta("queued reply".to_string()),
|
||||||
|
StreamEvent::Done,
|
||||||
|
],
|
||||||
|
])));
|
||||||
|
|
||||||
|
let (action_tx, action_rx) = mpsc::channel::<UserAction>(16);
|
||||||
|
let (event_tx, mut event_rx) = mpsc::channel::<UIEvent>(64);
|
||||||
|
|
||||||
|
let dir = tempfile::TempDir::new().unwrap();
|
||||||
|
let orch = Orchestrator::new(
|
||||||
|
MultiCallMock { turns },
|
||||||
|
ToolRegistry::default_tools(),
|
||||||
|
dir.path().to_path_buf(),
|
||||||
|
action_rx,
|
||||||
|
event_tx,
|
||||||
|
);
|
||||||
|
let handle = tokio::spawn(orch.run());
|
||||||
|
|
||||||
|
// Start turn 1 -- orchestrator will block on approval.
|
||||||
|
action_tx
|
||||||
|
.send(UserAction::SendMessage("turn1".to_string()))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Let the orchestrator reach wait_for_approval.
|
||||||
|
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||||
|
|
||||||
|
// Send a message while blocked -- should be queued.
|
||||||
|
action_tx
|
||||||
|
.send(UserAction::SendMessage("queued".to_string()))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Deny the tool -- unblocks wait_for_approval.
|
||||||
|
action_tx
|
||||||
|
.send(UserAction::ToolApprovalResponse {
|
||||||
|
tool_use_id: "toolu_w".to_string(),
|
||||||
|
approved: false,
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Wait for both turns to complete.
|
||||||
|
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
|
||||||
|
|
||||||
|
// Collect everything.
|
||||||
|
let mut all_events = Vec::new();
|
||||||
|
while let Ok(ev) = event_rx.try_recv() {
|
||||||
|
all_events.push(ev);
|
||||||
|
}
|
||||||
|
|
||||||
|
// The queued message must have produced "queued reply".
|
||||||
|
assert!(
|
||||||
|
all_events
|
||||||
|
.iter()
|
||||||
|
.any(|e| matches!(e, UIEvent::StreamDelta(s) if s == "queued reply")),
|
||||||
|
"queued message was not processed; events: {all_events:?}"
|
||||||
|
);
|
||||||
|
|
||||||
|
action_tx.send(UserAction::Quit).await.unwrap();
|
||||||
|
handle.await.unwrap();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,16 @@
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
/// A streaming event emitted by the model provider.
|
/// A streaming event emitted by the model provider.
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum StreamEvent {
|
pub enum StreamEvent {
|
||||||
/// A text chunk from the assistant's response.
|
/// A text chunk from the assistant's response.
|
||||||
TextDelta(String),
|
TextDelta(String),
|
||||||
|
/// A new tool-use content block has started.
|
||||||
|
ToolUseStart { id: String, name: String },
|
||||||
|
/// A chunk of the tool-use input JSON (streamed incrementally).
|
||||||
|
ToolUseInputDelta(String),
|
||||||
|
/// The current tool-use content block has ended.
|
||||||
|
ToolUseDone,
|
||||||
/// Number of input tokens used in this request.
|
/// Number of input tokens used in this request.
|
||||||
InputTokens(u32),
|
InputTokens(u32),
|
||||||
/// Number of output tokens generated so far.
|
/// Number of output tokens generated so far.
|
||||||
|
|
@ -18,6 +26,8 @@ pub enum StreamEvent {
|
||||||
pub enum UserAction {
|
pub enum UserAction {
|
||||||
/// The user has submitted a message.
|
/// The user has submitted a message.
|
||||||
SendMessage(String),
|
SendMessage(String),
|
||||||
|
/// The user has responded to a tool approval request.
|
||||||
|
ToolApprovalResponse { tool_use_id: String, approved: bool },
|
||||||
/// The user has requested to quit.
|
/// The user has requested to quit.
|
||||||
Quit,
|
Quit,
|
||||||
/// The user has requested to clear conversation history.
|
/// The user has requested to clear conversation history.
|
||||||
|
|
@ -29,6 +39,23 @@ pub enum UserAction {
|
||||||
pub enum UIEvent {
|
pub enum UIEvent {
|
||||||
/// A text chunk to append to the current assistant message.
|
/// A text chunk to append to the current assistant message.
|
||||||
StreamDelta(String),
|
StreamDelta(String),
|
||||||
|
/// A tool requires user approval before execution.
|
||||||
|
ToolApprovalRequest {
|
||||||
|
tool_use_id: String,
|
||||||
|
tool_name: String,
|
||||||
|
input_summary: String,
|
||||||
|
},
|
||||||
|
/// A tool is being executed (informational, after approval or auto-approve).
|
||||||
|
ToolExecuting {
|
||||||
|
tool_name: String,
|
||||||
|
input_summary: String,
|
||||||
|
},
|
||||||
|
/// A tool has finished executing.
|
||||||
|
ToolResult {
|
||||||
|
tool_name: String,
|
||||||
|
output_summary: String,
|
||||||
|
is_error: bool,
|
||||||
|
},
|
||||||
/// The current assistant turn has completed.
|
/// The current assistant turn has completed.
|
||||||
TurnComplete,
|
TurnComplete,
|
||||||
/// An error to display to the user.
|
/// An error to display to the user.
|
||||||
|
|
@ -36,7 +63,7 @@ pub enum UIEvent {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The role of a participant in a conversation.
|
/// The role of a participant in a conversation.
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "lowercase")]
|
#[serde(rename_all = "lowercase")]
|
||||||
pub enum Role {
|
pub enum Role {
|
||||||
/// A message from the human user.
|
/// A message from the human user.
|
||||||
|
|
@ -45,11 +72,128 @@ pub enum Role {
|
||||||
Assistant,
|
Assistant,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// A tool definition sent to the model so it knows which tools are available.
|
||||||
|
///
|
||||||
|
/// This is provider-agnostic -- the provider serializes it into the
|
||||||
|
/// format required by its API (e.g. the `tools` array for Anthropic).
|
||||||
|
#[derive(Debug, Clone, Serialize)]
|
||||||
|
pub struct ToolDefinition {
|
||||||
|
/// The tool name the model will use in `tool_use` blocks.
|
||||||
|
pub name: String,
|
||||||
|
/// Human-readable description of what the tool does.
|
||||||
|
pub description: String,
|
||||||
|
/// JSON Schema describing the tool's input parameters.
|
||||||
|
pub input_schema: serde_json::Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A typed content block within a conversation message.
|
||||||
|
///
|
||||||
|
/// The Anthropic Messages API represents message content as an array of typed
|
||||||
|
/// blocks. A single assistant message can contain interleaved text and tool-use
|
||||||
|
/// blocks; a user message following tool execution contains tool-result blocks.
|
||||||
|
///
|
||||||
|
/// See the [Messages API content blocks reference][content-blocks].
|
||||||
|
///
|
||||||
|
/// [content-blocks]: https://docs.anthropic.com/en/api/messages
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(tag = "type", rename_all = "snake_case")]
|
||||||
|
pub enum ContentBlock {
|
||||||
|
/// Plain text content.
|
||||||
|
Text { text: String },
|
||||||
|
/// A tool invocation requested by the assistant.
|
||||||
|
ToolUse {
|
||||||
|
id: String,
|
||||||
|
name: String,
|
||||||
|
input: serde_json::Value,
|
||||||
|
},
|
||||||
|
/// The result of executing a tool, sent back in a user-role message.
|
||||||
|
ToolResult {
|
||||||
|
tool_use_id: String,
|
||||||
|
content: String,
|
||||||
|
#[serde(default)]
|
||||||
|
is_error: bool,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
/// A single message in the conversation history.
|
/// A single message in the conversation history.
|
||||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
///
|
||||||
|
/// Content is stored as a `Vec<ContentBlock>` to support mixed text and
|
||||||
|
/// tool-use blocks in a single message. For simple text-only messages, use
|
||||||
|
/// the [`ConversationMessage::text`] constructor.
|
||||||
|
///
|
||||||
|
/// # Serialization
|
||||||
|
///
|
||||||
|
/// Serializes `content` as a plain string when the message contains exactly
|
||||||
|
/// one `Text` block (the common case), and as an array of typed blocks
|
||||||
|
/// otherwise. This matches what the Anthropic API expects for both simple
|
||||||
|
/// and tool-use messages.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct ConversationMessage {
|
pub struct ConversationMessage {
|
||||||
/// The role of the message author.
|
/// The role of the message author.
|
||||||
pub role: Role,
|
pub role: Role,
|
||||||
/// The text content of the message.
|
/// The content blocks of this message.
|
||||||
pub content: String,
|
pub content: Vec<ContentBlock>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ConversationMessage {
|
||||||
|
/// Create a simple text-only message.
|
||||||
|
pub fn text(role: Role, s: impl Into<String>) -> Self {
|
||||||
|
Self {
|
||||||
|
role,
|
||||||
|
content: vec![ContentBlock::Text { text: s.into() }],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extract the concatenated text content, ignoring non-text blocks.
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub fn text_content(&self) -> String {
|
||||||
|
let mut out = String::new();
|
||||||
|
for block in &self.content {
|
||||||
|
if let ContentBlock::Text { text } = block {
|
||||||
|
out.push_str(text);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Serialize for ConversationMessage {
|
||||||
|
fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
|
||||||
|
use serde::ser::SerializeMap;
|
||||||
|
let mut map = serializer.serialize_map(Some(2))?;
|
||||||
|
map.serialize_entry("role", &self.role)?;
|
||||||
|
// Single text block -> serialize as plain string (common case).
|
||||||
|
// Otherwise -> serialize as array of content blocks.
|
||||||
|
match self.content.as_slice() {
|
||||||
|
[ContentBlock::Text { text }] => map.serialize_entry("content", text)?,
|
||||||
|
blocks => map.serialize_entry("content", blocks)?,
|
||||||
|
}
|
||||||
|
map.end()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'de> Deserialize<'de> for ConversationMessage {
|
||||||
|
fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct Raw {
|
||||||
|
role: Role,
|
||||||
|
content: serde_json::Value,
|
||||||
|
}
|
||||||
|
let raw = Raw::deserialize(deserializer)?;
|
||||||
|
let content = match raw.content {
|
||||||
|
serde_json::Value::String(s) => vec![ContentBlock::Text { text: s }],
|
||||||
|
serde_json::Value::Array(_) => {
|
||||||
|
serde_json::from_value(raw.content).map_err(serde::de::Error::custom)?
|
||||||
|
}
|
||||||
|
other => {
|
||||||
|
return Err(serde::de::Error::custom(format!(
|
||||||
|
"expected string or array for content, got {other}"
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Ok(ConversationMessage {
|
||||||
|
role: raw.role,
|
||||||
|
content,
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
mod app;
|
mod app;
|
||||||
mod core;
|
mod core;
|
||||||
mod provider;
|
mod provider;
|
||||||
|
mod tools;
|
||||||
mod tui;
|
mod tui;
|
||||||
|
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ use futures::{SinkExt, Stream, StreamExt};
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
|
|
||||||
use crate::core::types::{ConversationMessage, StreamEvent};
|
use crate::core::types::{ConversationMessage, StreamEvent, ToolDefinition};
|
||||||
|
|
||||||
use super::ModelProvider;
|
use super::ModelProvider;
|
||||||
|
|
||||||
|
|
@ -64,15 +64,17 @@ impl ModelProvider for ClaudeProvider {
|
||||||
fn stream<'a>(
|
fn stream<'a>(
|
||||||
&'a self,
|
&'a self,
|
||||||
messages: &'a [ConversationMessage],
|
messages: &'a [ConversationMessage],
|
||||||
|
tools: &'a [ToolDefinition],
|
||||||
) -> impl Stream<Item = StreamEvent> + Send + 'a {
|
) -> impl Stream<Item = StreamEvent> + Send + 'a {
|
||||||
let (mut tx, rx) = futures::channel::mpsc::channel(32);
|
let (mut tx, rx) = futures::channel::mpsc::channel(32);
|
||||||
let client = self.client.clone();
|
let client = self.client.clone();
|
||||||
let api_key = self.api_key.clone();
|
let api_key = self.api_key.clone();
|
||||||
let model = self.model.clone();
|
let model = self.model.clone();
|
||||||
let messages = messages.to_vec();
|
let messages = messages.to_vec();
|
||||||
|
let tools = tools.to_vec();
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
run_stream(client, api_key, model, messages, &mut tx).await;
|
run_stream(client, api_key, model, messages, tools, &mut tx).await;
|
||||||
});
|
});
|
||||||
|
|
||||||
rx
|
rx
|
||||||
|
|
@ -126,14 +128,18 @@ async fn run_stream(
|
||||||
api_key: String,
|
api_key: String,
|
||||||
model: String,
|
model: String,
|
||||||
messages: Vec<ConversationMessage>,
|
messages: Vec<ConversationMessage>,
|
||||||
|
tools: Vec<ToolDefinition>,
|
||||||
tx: &mut futures::channel::mpsc::Sender<StreamEvent>,
|
tx: &mut futures::channel::mpsc::Sender<StreamEvent>,
|
||||||
) {
|
) {
|
||||||
let body = serde_json::json!({
|
let mut body = serde_json::json!({
|
||||||
"model": model,
|
"model": model,
|
||||||
"max_tokens": 8192,
|
"max_tokens": 8192,
|
||||||
"stream": true,
|
"stream": true,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
});
|
});
|
||||||
|
if !tools.is_empty() {
|
||||||
|
body["tools"] = serde_json::to_value(&tools).unwrap_or_default();
|
||||||
|
}
|
||||||
|
|
||||||
let response = match client
|
let response = match client
|
||||||
.post("https://api.anthropic.com/v1/messages")
|
.post("https://api.anthropic.com/v1/messages")
|
||||||
|
|
@ -223,6 +229,8 @@ fn find_double_newline(buf: &[u8]) -> Option<usize> {
|
||||||
struct SseEvent {
|
struct SseEvent {
|
||||||
#[serde(rename = "type")]
|
#[serde(rename = "type")]
|
||||||
event_type: String,
|
event_type: String,
|
||||||
|
/// Present on `content_block_start` events; describes the new block.
|
||||||
|
content_block: Option<SseContentBlock>,
|
||||||
/// Present on `content_block_delta` events.
|
/// Present on `content_block_delta` events.
|
||||||
delta: Option<SseDelta>,
|
delta: Option<SseDelta>,
|
||||||
/// Present on `message_start` events; carries initial token usage.
|
/// Present on `message_start` events; carries initial token usage.
|
||||||
|
|
@ -231,16 +239,29 @@ struct SseEvent {
|
||||||
usage: Option<SseUsage>,
|
usage: Option<SseUsage>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// The `content_block` object inside a `content_block_start` event.
|
||||||
|
#[derive(Deserialize, Debug)]
|
||||||
|
struct SseContentBlock {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
block_type: String,
|
||||||
|
/// Tool-use block ID; present when `block_type == "tool_use"`.
|
||||||
|
id: Option<String>,
|
||||||
|
/// Tool name; present when `block_type == "tool_use"`.
|
||||||
|
name: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
/// The `delta` object inside a `content_block_delta` event.
|
/// The `delta` object inside a `content_block_delta` event.
|
||||||
///
|
///
|
||||||
/// `type` is `"text_delta"` for plain text chunks; other delta types
|
/// `type` is `"text_delta"` for plain text chunks, or `"input_json_delta"`
|
||||||
/// (e.g. `"input_json_delta"` for tool-use blocks) are not yet handled.
|
/// for streaming tool-use input JSON.
|
||||||
#[derive(Deserialize, Debug)]
|
#[derive(Deserialize, Debug)]
|
||||||
struct SseDelta {
|
struct SseDelta {
|
||||||
#[serde(rename = "type")]
|
#[serde(rename = "type")]
|
||||||
delta_type: Option<String>,
|
delta_type: Option<String>,
|
||||||
/// The text chunk; present when `delta_type == "text_delta"`.
|
/// The text chunk; present when `delta_type == "text_delta"`.
|
||||||
text: Option<String>,
|
text: Option<String>,
|
||||||
|
/// Partial JSON for tool input; present when `delta_type == "input_json_delta"`.
|
||||||
|
partial_json: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The `message` object inside a `message_start` event.
|
/// The `message` object inside a `message_start` event.
|
||||||
|
|
@ -274,13 +295,20 @@ struct SseUsage {
|
||||||
/// # Mapping to [`StreamEvent`]
|
/// # Mapping to [`StreamEvent`]
|
||||||
///
|
///
|
||||||
/// | API event type | JSON path | Emits |
|
/// | API event type | JSON path | Emits |
|
||||||
/// |----------------------|------------------------------------|------------------------------|
|
/// |------------------------|------------------------------------|------------------------------|
|
||||||
/// | `message_start` | `.message.usage.input_tokens` | `InputTokens(n)` |
|
/// | `message_start` | `.message.usage.input_tokens` | `InputTokens(n)` |
|
||||||
|
/// | `content_block_start` | `.content_block.type == "tool_use"`| `ToolUseStart { id, name }` |
|
||||||
/// | `content_block_delta` | `.delta.type == "text_delta"` | `TextDelta(chunk)` |
|
/// | `content_block_delta` | `.delta.type == "text_delta"` | `TextDelta(chunk)` |
|
||||||
|
/// | `content_block_delta` | `.delta.type == "input_json_delta"`| `ToolUseInputDelta(json)` |
|
||||||
|
/// | `content_block_stop` | n/a | `ToolUseDone` |
|
||||||
/// | `message_delta` | `.usage.output_tokens` | `OutputTokens(n)` |
|
/// | `message_delta` | `.usage.output_tokens` | `OutputTokens(n)` |
|
||||||
/// | `message_stop` | n/a | `Done` |
|
/// | `message_stop` | n/a | `Done` |
|
||||||
/// | everything else | n/a | `None` (caller skips) |
|
/// | everything else | n/a | `None` (caller skips) |
|
||||||
///
|
///
|
||||||
|
/// Note: `content_block_stop` always emits `ToolUseDone`. The caller is
|
||||||
|
/// responsible for tracking whether a tool-use block is active and ignoring
|
||||||
|
/// `ToolUseDone` when it follows a text block.
|
||||||
|
///
|
||||||
/// [sse-fields]: https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream
|
/// [sse-fields]: https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream
|
||||||
fn parse_sse_event(event_str: &str) -> Option<StreamEvent> {
|
fn parse_sse_event(event_str: &str) -> Option<StreamEvent> {
|
||||||
// SSE events may have multiple fields; we only need `data:`.
|
// SSE events may have multiple fields; we only need `data:`.
|
||||||
|
|
@ -297,15 +325,29 @@ fn parse_sse_event(event_str: &str) -> Option<StreamEvent> {
|
||||||
.and_then(|u| u.input_tokens)
|
.and_then(|u| u.input_tokens)
|
||||||
.map(StreamEvent::InputTokens),
|
.map(StreamEvent::InputTokens),
|
||||||
|
|
||||||
"content_block_delta" => {
|
"content_block_start" => {
|
||||||
let delta = event.delta?;
|
let block = event.content_block?;
|
||||||
if delta.delta_type.as_deref() == Some("text_delta") {
|
if block.block_type == "tool_use" {
|
||||||
delta.text.map(StreamEvent::TextDelta)
|
Some(StreamEvent::ToolUseStart {
|
||||||
|
id: block.id.unwrap_or_default(),
|
||||||
|
name: block.name.unwrap_or_default(),
|
||||||
|
})
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
"content_block_delta" => {
|
||||||
|
let delta = event.delta?;
|
||||||
|
match delta.delta_type.as_deref() {
|
||||||
|
Some("text_delta") => delta.text.map(StreamEvent::TextDelta),
|
||||||
|
Some("input_json_delta") => delta.partial_json.map(StreamEvent::ToolUseInputDelta),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
"content_block_stop" => Some(StreamEvent::ToolUseDone),
|
||||||
|
|
||||||
// usage lives at the top level of message_delta, not inside delta.
|
// usage lives at the top level of message_delta, not inside delta.
|
||||||
"message_delta" => event
|
"message_delta" => event
|
||||||
.usage
|
.usage
|
||||||
|
|
@ -314,8 +356,7 @@ fn parse_sse_event(event_str: &str) -> Option<StreamEvent> {
|
||||||
|
|
||||||
"message_stop" => Some(StreamEvent::Done),
|
"message_stop" => Some(StreamEvent::Done),
|
||||||
|
|
||||||
// error, ping, content_block_start, content_block_stop -- ignored or
|
// error, ping -- ignored.
|
||||||
// handled by the caller.
|
|
||||||
_ => None,
|
_ => None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -371,13 +412,15 @@ mod tests {
|
||||||
.filter_map(parse_sse_event)
|
.filter_map(parse_sse_event)
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
// content_block_start, ping, content_block_stop -> None (filtered out)
|
// content_block_start (text) and ping -> None (filtered out)
|
||||||
assert_eq!(events.len(), 5);
|
// content_block_stop -> ToolUseDone (always emitted; caller filters)
|
||||||
|
assert_eq!(events.len(), 6);
|
||||||
assert!(matches!(events[0], StreamEvent::InputTokens(10)));
|
assert!(matches!(events[0], StreamEvent::InputTokens(10)));
|
||||||
assert!(matches!(&events[1], StreamEvent::TextDelta(s) if s == "Hello"));
|
assert!(matches!(&events[1], StreamEvent::TextDelta(s) if s == "Hello"));
|
||||||
assert!(matches!(&events[2], StreamEvent::TextDelta(s) if s == ", world!"));
|
assert!(matches!(&events[2], StreamEvent::TextDelta(s) if s == ", world!"));
|
||||||
assert!(matches!(events[3], StreamEvent::OutputTokens(5)));
|
assert!(matches!(events[3], StreamEvent::ToolUseDone));
|
||||||
assert!(matches!(events[4], StreamEvent::Done));
|
assert!(matches!(events[4], StreamEvent::OutputTokens(5)));
|
||||||
|
assert!(matches!(events[5], StreamEvent::Done));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -408,14 +451,8 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn test_messages_serialize_to_anthropic_format() {
|
fn test_messages_serialize_to_anthropic_format() {
|
||||||
let messages = vec![
|
let messages = vec![
|
||||||
ConversationMessage {
|
ConversationMessage::text(Role::User, "Hello"),
|
||||||
role: Role::User,
|
ConversationMessage::text(Role::Assistant, "Hi there!"),
|
||||||
content: "Hello".to_string(),
|
|
||||||
},
|
|
||||||
ConversationMessage {
|
|
||||||
role: Role::Assistant,
|
|
||||||
content: "Hi there!".to_string(),
|
|
||||||
},
|
|
||||||
];
|
];
|
||||||
|
|
||||||
let json = serde_json::json!({
|
let json = serde_json::json!({
|
||||||
|
|
@ -433,6 +470,65 @@ mod tests {
|
||||||
assert!(json["max_tokens"].as_u64().unwrap() > 0);
|
assert!(json["max_tokens"].as_u64().unwrap() > 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// SSE fixture for a response that contains a tool-use block.
|
||||||
|
///
|
||||||
|
/// Sequence: message_start -> content_block_start (tool_use) ->
|
||||||
|
/// content_block_delta (input_json_delta) -> content_block_stop ->
|
||||||
|
/// message_delta -> message_stop
|
||||||
|
const TOOL_USE_SSE_FIXTURE: &str = concat!(
|
||||||
|
"event: message_start\n",
|
||||||
|
"data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_456\",\"type\":\"message\",",
|
||||||
|
"\"role\":\"assistant\",\"content\":[],\"model\":\"claude-opus-4-6\",",
|
||||||
|
"\"stop_reason\":null,\"stop_sequence\":null,",
|
||||||
|
"\"usage\":{\"input_tokens\":25,\"cache_creation_input_tokens\":0,",
|
||||||
|
"\"cache_read_input_tokens\":0,\"output_tokens\":1}}}\n",
|
||||||
|
"\n",
|
||||||
|
"event: content_block_start\n",
|
||||||
|
"data: {\"type\":\"content_block_start\",\"index\":0,",
|
||||||
|
"\"content_block\":{\"type\":\"tool_use\",\"id\":\"toolu_01ABC\",\"name\":\"read_file\",\"input\":{}}}\n",
|
||||||
|
"\n",
|
||||||
|
"event: content_block_delta\n",
|
||||||
|
"data: {\"type\":\"content_block_delta\",\"index\":0,",
|
||||||
|
"\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"path\\\": \\\"src/\"}}\n",
|
||||||
|
"\n",
|
||||||
|
"event: content_block_delta\n",
|
||||||
|
"data: {\"type\":\"content_block_delta\",\"index\":0,",
|
||||||
|
"\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"main.rs\\\"}\"}}\n",
|
||||||
|
"\n",
|
||||||
|
"event: content_block_stop\n",
|
||||||
|
"data: {\"type\":\"content_block_stop\",\"index\":0}\n",
|
||||||
|
"\n",
|
||||||
|
"event: message_delta\n",
|
||||||
|
"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\",",
|
||||||
|
"\"stop_sequence\":null},\"usage\":{\"output_tokens\":15}}\n",
|
||||||
|
"\n",
|
||||||
|
"event: message_stop\n",
|
||||||
|
"data: {\"type\":\"message_stop\"}\n",
|
||||||
|
"\n",
|
||||||
|
);
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_tool_use_sse_fixture() {
|
||||||
|
let events: Vec<StreamEvent> = TOOL_USE_SSE_FIXTURE
|
||||||
|
.split("\n\n")
|
||||||
|
.filter(|s| !s.trim().is_empty())
|
||||||
|
.filter_map(parse_sse_event)
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
assert_eq!(events.len(), 7);
|
||||||
|
assert!(matches!(events[0], StreamEvent::InputTokens(25)));
|
||||||
|
assert!(
|
||||||
|
matches!(&events[1], StreamEvent::ToolUseStart { id, name } if id == "toolu_01ABC" && name == "read_file")
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
matches!(&events[2], StreamEvent::ToolUseInputDelta(s) if s == "{\"path\": \"src/")
|
||||||
|
);
|
||||||
|
assert!(matches!(&events[3], StreamEvent::ToolUseInputDelta(s) if s == "main.rs\"}"));
|
||||||
|
assert!(matches!(events[4], StreamEvent::ToolUseDone));
|
||||||
|
assert!(matches!(events[5], StreamEvent::OutputTokens(15)));
|
||||||
|
assert!(matches!(events[6], StreamEvent::Done));
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_find_double_newline() {
|
fn test_find_double_newline() {
|
||||||
assert_eq!(find_double_newline(b"abc\n\ndef"), Some(3));
|
assert_eq!(find_double_newline(b"abc\n\ndef"), Some(3));
|
||||||
|
|
|
||||||
|
|
@ -4,16 +4,18 @@ pub use claude::ClaudeProvider;
|
||||||
|
|
||||||
use futures::Stream;
|
use futures::Stream;
|
||||||
|
|
||||||
use crate::core::types::{ConversationMessage, StreamEvent};
|
use crate::core::types::{ConversationMessage, StreamEvent, ToolDefinition};
|
||||||
|
|
||||||
/// Trait for model providers that can stream conversation responses.
|
/// Trait for model providers that can stream conversation responses.
|
||||||
///
|
///
|
||||||
/// Implementors take a conversation history and return a stream of [`StreamEvent`]s.
|
/// Implementors take a conversation history and return a stream of [`StreamEvent`]s.
|
||||||
/// The trait is provider-agnostic -- no Claude-specific types appear here.
|
/// The trait is provider-agnostic -- no Claude-specific types appear here.
|
||||||
pub trait ModelProvider: Send + Sync {
|
pub trait ModelProvider: Send + Sync {
|
||||||
/// Stream a response from the model given the conversation history.
|
/// Stream a response from the model given the conversation history and
|
||||||
|
/// available tool definitions. Pass an empty slice if no tools are available.
|
||||||
fn stream<'a>(
|
fn stream<'a>(
|
||||||
&'a self,
|
&'a self,
|
||||||
messages: &'a [ConversationMessage],
|
messages: &'a [ConversationMessage],
|
||||||
|
tools: &'a [ToolDefinition],
|
||||||
) -> impl Stream<Item = StreamEvent> + Send + 'a;
|
) -> impl Stream<Item = StreamEvent> + Send + 'a;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
87
src/tools/list_directory.rs
Normal file
87
src/tools/list_directory.rs
Normal file
|
|
@ -0,0 +1,87 @@
|
||||||
|
//! `list_directory` tool: lists entries in a directory within the working directory.
|
||||||
|
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
|
||||||
|
use super::{RiskLevel, Tool, ToolError, ToolOutput, validate_path};
|
||||||
|
|
||||||
|
/// Lists directory contents. Auto-approved (read-only).
|
||||||
|
pub struct ListDirectory;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Tool for ListDirectory {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"list_directory"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"List the files and subdirectories in a directory. The path is relative to the working directory."
|
||||||
|
}
|
||||||
|
|
||||||
|
fn input_schema(&self) -> serde_json::Value {
|
||||||
|
serde_json::json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The directory path to list, relative to the working directory. Use '.' for the working directory itself."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["path"]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn risk_level(&self) -> RiskLevel {
|
||||||
|
RiskLevel::AutoApprove
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(
|
||||||
|
&self,
|
||||||
|
input: &serde_json::Value,
|
||||||
|
working_dir: &Path,
|
||||||
|
) -> Result<ToolOutput, ToolError> {
|
||||||
|
let path_str = input["path"]
|
||||||
|
.as_str()
|
||||||
|
.ok_or_else(|| ToolError::InvalidInput("missing 'path' string".to_string()))?;
|
||||||
|
|
||||||
|
let canonical = validate_path(working_dir, path_str)?;
|
||||||
|
|
||||||
|
let mut entries: Vec<String> = Vec::new();
|
||||||
|
let mut dir = tokio::fs::read_dir(&canonical).await?;
|
||||||
|
while let Some(entry) = dir.next_entry().await? {
|
||||||
|
let name = entry.file_name().to_string_lossy().to_string();
|
||||||
|
let suffix = if entry.file_type().await?.is_dir() {
|
||||||
|
"/"
|
||||||
|
} else {
|
||||||
|
""
|
||||||
|
};
|
||||||
|
entries.push(format!("{name}{suffix}"));
|
||||||
|
}
|
||||||
|
entries.sort();
|
||||||
|
|
||||||
|
Ok(ToolOutput {
|
||||||
|
content: entries.join("\n"),
|
||||||
|
is_error: false,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// A plain file is listed by name and a subdirectory gets a trailing `/`.
    #[tokio::test]
    async fn list_directory_contents() {
        let scratch = TempDir::new().unwrap();
        let root = scratch.path();
        fs::write(root.join("a.txt"), "").unwrap();
        fs::create_dir(root.join("subdir")).unwrap();

        let request = serde_json::json!({"path": "."});
        let listing = ListDirectory.execute(&request, root).await.unwrap();

        assert!(listing.content.contains("a.txt"));
        assert!(listing.content.contains("subdir/"));
    }
}
|
||||||
220
src/tools/mod.rs
Normal file
220
src/tools/mod.rs
Normal file
|
|
@ -0,0 +1,220 @@
|
||||||
|
//! Tool system: trait, registry, risk classification, and built-in tools.
|
||||||
|
//!
|
||||||
|
//! All tools implement the [`Tool`] trait. The [`ToolRegistry`] collects them
|
||||||
|
//! and provides lookup by name plus generation of [`ToolDefinition`]s for the
|
||||||
|
//! model provider.
|
||||||
|
|
||||||
|
mod list_directory;
|
||||||
|
mod read_file;
|
||||||
|
mod shell_exec;
|
||||||
|
mod write_file;
|
||||||
|
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
|
||||||
|
use crate::core::types::ToolDefinition;
|
||||||
|
|
||||||
|
/// Result of running a tool.
#[derive(Debug)]
pub struct ToolOutput {
    /// Text handed back to the model.
    pub content: String,
    /// `true` when the tool reports that it hit an error.
    pub is_error: bool,
}
|
||||||
|
|
||||||
|
/// How much user involvement a tool needs before it may run.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RiskLevel {
    /// May run without asking the user (e.g. read-only operations).
    AutoApprove,
    /// Must be explicitly approved by the user first (e.g. writes, shell).
    RequiresApproval,
}
|
||||||
|
|
||||||
|
/// A tool that the model can invoke.
|
||||||
|
///
|
||||||
|
/// The `execute` method is async so that tool implementations can use
|
||||||
|
/// `tokio::fs` and `tokio::process` without blocking a Tokio worker thread.
|
||||||
|
/// `#[async_trait]` desugars the async fn to a boxed future, which is required
|
||||||
|
/// for `dyn Tool` to remain object-safe.
|
||||||
|
#[async_trait]
|
||||||
|
pub trait Tool: Send + Sync {
|
||||||
|
/// The name the model uses to invoke this tool.
|
||||||
|
fn name(&self) -> &str;
|
||||||
|
/// Human-readable description for the model.
|
||||||
|
fn description(&self) -> &str;
|
||||||
|
/// JSON Schema for the tool's input parameters.
|
||||||
|
fn input_schema(&self) -> serde_json::Value;
|
||||||
|
/// The risk level of this tool.
|
||||||
|
fn risk_level(&self) -> RiskLevel;
|
||||||
|
/// Execute the tool with the given input, confined to `working_dir`.
|
||||||
|
async fn execute(
|
||||||
|
&self,
|
||||||
|
input: &serde_json::Value,
|
||||||
|
working_dir: &Path,
|
||||||
|
) -> Result<ToolOutput, ToolError>;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Errors from tool execution.
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum ToolError {
|
||||||
|
/// The requested path escapes the working directory.
|
||||||
|
#[error("path escapes working directory: {0}")]
|
||||||
|
PathEscape(PathBuf),
|
||||||
|
/// An I/O error during tool execution.
|
||||||
|
#[error("I/O error: {0}")]
|
||||||
|
Io(#[from] std::io::Error),
|
||||||
|
/// A required input field is missing or has the wrong type.
|
||||||
|
#[error("invalid input: {0}")]
|
||||||
|
InvalidInput(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validate that `requested` resolves to a path inside `working_dir`.
|
||||||
|
///
|
||||||
|
/// Joins `working_dir` with `requested`, canonicalizes the result (resolving
|
||||||
|
/// symlinks and `..` components), and checks that the canonical path starts
|
||||||
|
/// with the canonical working directory.
|
||||||
|
///
|
||||||
|
/// Returns the canonical path on success, or [`ToolError::PathEscape`] if the
|
||||||
|
/// path would escape the working directory.
|
||||||
|
pub fn validate_path(working_dir: &Path, requested: &str) -> Result<PathBuf, ToolError> {
|
||||||
|
let candidate = if Path::new(requested).is_absolute() {
|
||||||
|
PathBuf::from(requested)
|
||||||
|
} else {
|
||||||
|
working_dir.join(requested)
|
||||||
|
};
|
||||||
|
|
||||||
|
// For paths that don't exist yet (e.g. write_file creating a new file),
|
||||||
|
// canonicalize the parent directory and append the filename.
|
||||||
|
let canonical = if candidate.exists() {
|
||||||
|
candidate
|
||||||
|
.canonicalize()
|
||||||
|
.map_err(|_| ToolError::PathEscape(candidate.clone()))?
|
||||||
|
} else {
|
||||||
|
let parent = candidate
|
||||||
|
.parent()
|
||||||
|
.ok_or_else(|| ToolError::PathEscape(candidate.clone()))?;
|
||||||
|
let file_name = candidate
|
||||||
|
.file_name()
|
||||||
|
.ok_or_else(|| ToolError::PathEscape(candidate.clone()))?;
|
||||||
|
let canonical_parent = parent
|
||||||
|
.canonicalize()
|
||||||
|
.map_err(|_| ToolError::PathEscape(candidate.clone()))?;
|
||||||
|
canonical_parent.join(file_name)
|
||||||
|
};
|
||||||
|
|
||||||
|
let canonical_root = working_dir
|
||||||
|
.canonicalize()
|
||||||
|
.map_err(|_| ToolError::PathEscape(candidate.clone()))?;
|
||||||
|
|
||||||
|
if canonical.starts_with(&canonical_root) {
|
||||||
|
Ok(canonical)
|
||||||
|
} else {
|
||||||
|
Err(ToolError::PathEscape(candidate))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Collection of available tools with name-based lookup.
|
||||||
|
pub struct ToolRegistry {
|
||||||
|
tools: Vec<Box<dyn Tool>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ToolRegistry {
|
||||||
|
/// Create an empty registry with no tools.
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub fn empty() -> Self {
|
||||||
|
Self { tools: Vec::new() }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a registry with the default built-in tools.
|
||||||
|
pub fn default_tools() -> Self {
|
||||||
|
Self {
|
||||||
|
tools: vec![
|
||||||
|
Box::new(read_file::ReadFile),
|
||||||
|
Box::new(list_directory::ListDirectory),
|
||||||
|
Box::new(write_file::WriteFile),
|
||||||
|
Box::new(shell_exec::ShellExec),
|
||||||
|
],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Look up a tool by name.
|
||||||
|
pub fn get(&self, name: &str) -> Option<&dyn Tool> {
|
||||||
|
self.tools.iter().find(|t| t.name() == name).map(|t| &**t)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generate [`ToolDefinition`]s for the model provider.
|
||||||
|
pub fn definitions(&self) -> Vec<ToolDefinition> {
|
||||||
|
self.tools
|
||||||
|
.iter()
|
||||||
|
.map(|t| ToolDefinition {
|
||||||
|
name: t.name().to_string(),
|
||||||
|
description: t.description().to_string(),
|
||||||
|
input_schema: t.input_schema(),
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// An existing subdirectory resolves to a path under the canonical root.
    #[test]
    fn validate_path_allows_subpath() {
        let scratch = TempDir::new().unwrap();
        fs::create_dir(scratch.path().join("sub")).unwrap();

        let resolved = validate_path(scratch.path(), "sub");
        assert!(resolved.is_ok());

        let canonical_root = scratch.path().canonicalize().unwrap();
        assert!(resolved.unwrap().starts_with(canonical_root));
    }

    /// `..` components that climb out of the working directory are refused.
    #[test]
    fn validate_path_rejects_traversal() {
        let scratch = TempDir::new().unwrap();
        let outcome = validate_path(scratch.path(), "../../../etc/passwd");
        assert!(outcome.is_err());
        assert!(matches!(outcome, Err(ToolError::PathEscape(_))));
    }

    /// An absolute path outside the working directory is refused.
    #[test]
    fn validate_path_rejects_absolute_outside() {
        let scratch = TempDir::new().unwrap();
        let outcome = validate_path(scratch.path(), "/etc/passwd");
        assert!(outcome.is_err());
    }

    /// A not-yet-existing file directly in the working directory is accepted.
    #[test]
    fn validate_path_allows_new_file_in_working_dir() {
        let scratch = TempDir::new().unwrap();
        let outcome = validate_path(scratch.path(), "new_file.txt");
        assert!(outcome.is_ok());
    }

    /// Every built-in tool is reachable by name; unknown names are not.
    #[test]
    fn registry_default_has_all_tools() {
        let registry = ToolRegistry::default_tools();
        for expected in ["read_file", "list_directory", "write_file", "shell_exec"] {
            assert!(registry.get(expected).is_some());
        }
        assert!(registry.get("nonexistent").is_none());
    }

    /// `definitions()` yields one entry per built-in tool.
    #[test]
    fn registry_definitions_match_tools() {
        let registry = ToolRegistry::default_tools();
        let definitions = registry.definitions();
        assert_eq!(definitions.len(), 4);
        assert!(definitions.iter().any(|def| def.name == "read_file"));
    }
}
|
||||||
92
src/tools/read_file.rs
Normal file
92
src/tools/read_file.rs
Normal file
|
|
@ -0,0 +1,92 @@
|
||||||
|
//! `read_file` tool: reads the contents of a file within the working directory.
|
||||||
|
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
|
||||||
|
use super::{RiskLevel, Tool, ToolError, ToolOutput, validate_path};
|
||||||
|
|
||||||
|
/// Reads file contents. Auto-approved (read-only).
|
||||||
|
pub struct ReadFile;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Tool for ReadFile {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"read_file"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"Read the contents of a file. The path is relative to the working directory."
|
||||||
|
}
|
||||||
|
|
||||||
|
fn input_schema(&self) -> serde_json::Value {
|
||||||
|
serde_json::json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The file path to read, relative to the working directory."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["path"]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn risk_level(&self) -> RiskLevel {
|
||||||
|
RiskLevel::AutoApprove
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(
|
||||||
|
&self,
|
||||||
|
input: &serde_json::Value,
|
||||||
|
working_dir: &Path,
|
||||||
|
) -> Result<ToolOutput, ToolError> {
|
||||||
|
let path_str = input["path"]
|
||||||
|
.as_str()
|
||||||
|
.ok_or_else(|| ToolError::InvalidInput("missing 'path' string".to_string()))?;
|
||||||
|
|
||||||
|
let canonical = validate_path(working_dir, path_str)?;
|
||||||
|
let content = tokio::fs::read_to_string(&canonical).await?;
|
||||||
|
|
||||||
|
Ok(ToolOutput {
|
||||||
|
content,
|
||||||
|
is_error: false,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// An existing file comes back verbatim and is not flagged as an error.
    #[tokio::test]
    async fn read_existing_file() {
        let scratch = TempDir::new().unwrap();
        fs::write(scratch.path().join("hello.txt"), "world").unwrap();

        let request = serde_json::json!({"path": "hello.txt"});
        let output = ReadFile.execute(&request, scratch.path()).await.unwrap();

        assert_eq!(output.content, "world");
        assert!(!output.is_error);
    }

    /// Reading a file that is not there yields an `Err`.
    #[tokio::test]
    async fn read_nonexistent_file_errors() {
        let scratch = TempDir::new().unwrap();
        let request = serde_json::json!({"path": "nope.txt"});

        let outcome = ReadFile.execute(&request, scratch.path()).await;

        assert!(outcome.is_err());
    }

    /// A `..` path leaving the working directory is reported as `PathEscape`.
    #[tokio::test]
    async fn read_path_traversal_rejected() {
        let scratch = TempDir::new().unwrap();
        let request = serde_json::json!({"path": "../../../etc/passwd"});

        let outcome = ReadFile.execute(&request, scratch.path()).await;

        assert!(matches!(outcome, Err(ToolError::PathEscape(_))));
    }
}
|
||||||
108
src/tools/shell_exec.rs
Normal file
108
src/tools/shell_exec.rs
Normal file
|
|
@ -0,0 +1,108 @@
|
||||||
|
//! `shell_exec` tool: runs a shell command within the working directory.
|
||||||
|
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
|
||||||
|
use super::{RiskLevel, Tool, ToolError, ToolOutput};
|
||||||
|
|
||||||
|
/// Executes a shell command. Requires user approval.
|
||||||
|
pub struct ShellExec;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Tool for ShellExec {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"shell_exec"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"Execute a shell command in the working directory. Returns stdout and stderr."
|
||||||
|
}
|
||||||
|
|
||||||
|
fn input_schema(&self) -> serde_json::Value {
|
||||||
|
serde_json::json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"command": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The shell command to execute."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["command"]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn risk_level(&self) -> RiskLevel {
|
||||||
|
RiskLevel::RequiresApproval
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(
|
||||||
|
&self,
|
||||||
|
input: &serde_json::Value,
|
||||||
|
working_dir: &Path,
|
||||||
|
) -> Result<ToolOutput, ToolError> {
|
||||||
|
let command = input["command"]
|
||||||
|
.as_str()
|
||||||
|
.ok_or_else(|| ToolError::InvalidInput("missing 'command' string".to_string()))?;
|
||||||
|
|
||||||
|
let output = tokio::process::Command::new("sh")
|
||||||
|
.arg("-c")
|
||||||
|
.arg(command)
|
||||||
|
.current_dir(working_dir)
|
||||||
|
.output()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||||
|
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||||
|
|
||||||
|
let mut content = String::new();
|
||||||
|
if !stdout.is_empty() {
|
||||||
|
content.push_str(&stdout);
|
||||||
|
}
|
||||||
|
if !stderr.is_empty() {
|
||||||
|
if !content.is_empty() {
|
||||||
|
content.push('\n');
|
||||||
|
}
|
||||||
|
content.push_str("[stderr]\n");
|
||||||
|
content.push_str(&stderr);
|
||||||
|
}
|
||||||
|
if content.is_empty() {
|
||||||
|
content.push_str("(no output)");
|
||||||
|
}
|
||||||
|
|
||||||
|
let is_error = !output.status.success();
|
||||||
|
if is_error {
|
||||||
|
content.push_str(&format!(
|
||||||
|
"\n[exit code: {}]",
|
||||||
|
output.status.code().unwrap_or(-1)
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(ToolOutput { content, is_error })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// A successful command returns its stdout and is not flagged as an error.
    #[tokio::test]
    async fn shell_exec_echo() {
        let scratch = TempDir::new().unwrap();
        let request = serde_json::json!({"command": "echo hello"});

        let output = ShellExec.execute(&request, scratch.path()).await.unwrap();

        assert!(output.content.contains("hello"));
        assert!(!output.is_error);
    }

    /// A command with a non-zero exit status is flagged as an error.
    #[tokio::test]
    async fn shell_exec_failing_command() {
        let scratch = TempDir::new().unwrap();
        let request = serde_json::json!({"command": "false"});

        let output = ShellExec.execute(&request, scratch.path()).await.unwrap();

        assert!(output.is_error);
    }
}
|
||||||
98
src/tools/write_file.rs
Normal file
98
src/tools/write_file.rs
Normal file
|
|
@ -0,0 +1,98 @@
|
||||||
|
//! `write_file` tool: writes content to a file within the working directory.
|
||||||
|
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
|
||||||
|
use super::{RiskLevel, Tool, ToolError, ToolOutput, validate_path};
|
||||||
|
|
||||||
|
/// Writes content to a file. Requires user approval.
|
||||||
|
pub struct WriteFile;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Tool for WriteFile {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"write_file"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"Write content to a file. Creates the file if it doesn't exist, overwrites if it does. The path is relative to the working directory."
|
||||||
|
}
|
||||||
|
|
||||||
|
fn input_schema(&self) -> serde_json::Value {
|
||||||
|
serde_json::json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The file path to write to, relative to the working directory."
|
||||||
|
},
|
||||||
|
"content": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The content to write to the file."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["path", "content"]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn risk_level(&self) -> RiskLevel {
|
||||||
|
RiskLevel::RequiresApproval
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(
|
||||||
|
&self,
|
||||||
|
input: &serde_json::Value,
|
||||||
|
working_dir: &Path,
|
||||||
|
) -> Result<ToolOutput, ToolError> {
|
||||||
|
let path_str = input["path"]
|
||||||
|
.as_str()
|
||||||
|
.ok_or_else(|| ToolError::InvalidInput("missing 'path' string".to_string()))?;
|
||||||
|
let content = input["content"]
|
||||||
|
.as_str()
|
||||||
|
.ok_or_else(|| ToolError::InvalidInput("missing 'content' string".to_string()))?;
|
||||||
|
|
||||||
|
let canonical = validate_path(working_dir, path_str)?;
|
||||||
|
|
||||||
|
// Create parent directories if needed.
|
||||||
|
if let Some(parent) = canonical.parent() {
|
||||||
|
tokio::fs::create_dir_all(parent).await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
tokio::fs::write(&canonical, content).await?;
|
||||||
|
|
||||||
|
Ok(ToolOutput {
|
||||||
|
content: format!("Wrote {} bytes to {path_str}", content.len()),
|
||||||
|
is_error: false,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Writing to a new path creates the file with the given content.
    #[tokio::test]
    async fn write_creates_file() {
        let scratch = TempDir::new().unwrap();
        let request = serde_json::json!({"path": "out.txt", "content": "hello"});

        let output = WriteFile.execute(&request, scratch.path()).await.unwrap();

        assert!(!output.is_error);
        let on_disk = fs::read_to_string(scratch.path().join("out.txt")).unwrap();
        assert_eq!(on_disk, "hello");
    }

    /// A `..` path leaving the working directory is reported as `PathEscape`.
    #[tokio::test]
    async fn write_path_traversal_rejected() {
        let scratch = TempDir::new().unwrap();
        let request = serde_json::json!({"path": "../../evil.txt", "content": "bad"});

        let outcome = WriteFile.execute(&request, scratch.path()).await;

        assert!(matches!(outcome, Err(ToolError::PathEscape(_))));
    }
}
|
||||||
|
|
@ -12,8 +12,11 @@ use crate::core::types::{Role, UIEvent};
|
||||||
/// immediately when the channel is empty.
|
/// immediately when the channel is empty.
|
||||||
///
|
///
|
||||||
/// | Event | Effect |
|
/// | Event | Effect |
|
||||||
/// |--------------------|------------------------------------------------------------|
|
/// |------------------------|------------------------------------------------------------|
|
||||||
/// | `StreamDelta(s)` | Append `s` to last message if it's `Assistant`; else push new |
|
/// | `StreamDelta(s)` | Append `s` to last message if it's `Assistant`; else push a new message |
|
||||||
|
/// | `ToolApprovalRequest` | Set `pending_approval` in state |
|
||||||
|
/// | `ToolExecuting` | Display tool execution info |
|
||||||
|
/// | `ToolResult` | Display tool result |
|
||||||
/// | `TurnComplete` | No structural change; logged at debug level |
|
/// | `TurnComplete` | No structural change; logged at debug level |
|
||||||
/// | `Error(msg)` | Push `(Assistant, "[error] {msg}")` |
|
/// | `Error(msg)` | Push `(Assistant, "[error] {msg}")` |
|
||||||
pub(super) fn drain_ui_events(event_rx: &mut mpsc::Receiver<UIEvent>, state: &mut AppState) {
|
pub(super) fn drain_ui_events(event_rx: &mut mpsc::Receiver<UIEvent>, state: &mut AppState) {
|
||||||
|
|
@ -26,6 +29,36 @@ pub(super) fn drain_ui_events(event_rx: &mut mpsc::Receiver<UIEvent>, state: &mu
|
||||||
state.messages.push((Role::Assistant, chunk));
|
state.messages.push((Role::Assistant, chunk));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
UIEvent::ToolApprovalRequest {
|
||||||
|
tool_use_id,
|
||||||
|
tool_name,
|
||||||
|
input_summary,
|
||||||
|
} => {
|
||||||
|
state.pending_approval = Some(PendingApproval {
|
||||||
|
tool_use_id,
|
||||||
|
tool_name,
|
||||||
|
input_summary,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
UIEvent::ToolExecuting {
|
||||||
|
tool_name,
|
||||||
|
input_summary,
|
||||||
|
} => {
|
||||||
|
state
|
||||||
|
.messages
|
||||||
|
.push((Role::Assistant, format!("[{tool_name}] {input_summary}")));
|
||||||
|
}
|
||||||
|
UIEvent::ToolResult {
|
||||||
|
tool_name,
|
||||||
|
output_summary,
|
||||||
|
is_error,
|
||||||
|
} => {
|
||||||
|
let prefix = if is_error { "error" } else { "result" };
|
||||||
|
state.messages.push((
|
||||||
|
Role::Assistant,
|
||||||
|
format!("[{tool_name} {prefix}] {output_summary}"),
|
||||||
|
));
|
||||||
|
}
|
||||||
UIEvent::TurnComplete => {
|
UIEvent::TurnComplete => {
|
||||||
debug!("turn complete");
|
debug!("turn complete");
|
||||||
}
|
}
|
||||||
|
|
@ -38,6 +71,14 @@ pub(super) fn drain_ui_events(event_rx: &mut mpsc::Receiver<UIEvent>, state: &mu
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// A pending tool approval request waiting for user input.
///
/// Derives `PartialEq`/`Eq` so values can be compared directly in assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PendingApproval {
    /// Identifier of the tool call awaiting a decision; echoed back with the user's answer.
    pub tool_use_id: String,
    /// Name of the tool, shown in the approval overlay.
    pub tool_name: String,
    /// Summary of the tool input, shown next to the tool name in the overlay.
    pub input_summary: String,
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
@ -69,4 +110,39 @@ mod tests {
|
||||||
assert_eq!(state.messages[1].0, Role::Assistant);
|
assert_eq!(state.messages[1].0, Role::Assistant);
|
||||||
assert_eq!(state.messages[1].1, "hello");
|
assert_eq!(state.messages[1].1, "hello");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn drain_tool_approval_sets_pending() {
|
||||||
|
let (tx, mut rx) = tokio::sync::mpsc::channel(8);
|
||||||
|
let mut state = AppState::new();
|
||||||
|
tx.send(UIEvent::ToolApprovalRequest {
|
||||||
|
tool_use_id: "t1".to_string(),
|
||||||
|
tool_name: "write_file".to_string(),
|
||||||
|
input_summary: "path: foo.txt".to_string(),
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
drop(tx);
|
||||||
|
drain_ui_events(&mut rx, &mut state);
|
||||||
|
assert!(state.pending_approval.is_some());
|
||||||
|
let approval = state.pending_approval.unwrap();
|
||||||
|
assert_eq!(approval.tool_name, "write_file");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn drain_tool_result_adds_message() {
|
||||||
|
let (tx, mut rx) = tokio::sync::mpsc::channel(8);
|
||||||
|
let mut state = AppState::new();
|
||||||
|
tx.send(UIEvent::ToolResult {
|
||||||
|
tool_name: "read_file".to_string(),
|
||||||
|
output_summary: "file contents...".to_string(),
|
||||||
|
is_error: false,
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
drop(tx);
|
||||||
|
drain_ui_events(&mut rx, &mut state);
|
||||||
|
assert_eq!(state.messages.len(), 1);
|
||||||
|
assert!(state.messages[0].1.contains("read_file result"));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,8 @@ pub(super) enum LoopControl {
|
||||||
Quit,
|
Quit,
|
||||||
/// The user ran `:clear`; wipe the conversation.
|
/// The user ran `:clear`; wipe the conversation.
|
||||||
ClearHistory,
|
ClearHistory,
|
||||||
|
/// The user responded to a tool approval prompt.
|
||||||
|
ToolApproval { tool_use_id: String, approved: bool },
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Map a key event to a [`LoopControl`] signal, mutating `state` as a side-effect.
|
/// Map a key event to a [`LoopControl`] signal, mutating `state` as a side-effect.
|
||||||
|
|
@ -23,6 +25,29 @@ pub(super) fn handle_key(key: Option<KeyEvent>, state: &mut AppState) -> Option<
|
||||||
let key = key?;
|
let key = key?;
|
||||||
// Clear any transient status error on the next keypress.
|
// Clear any transient status error on the next keypress.
|
||||||
state.status_error = None;
|
state.status_error = None;
|
||||||
|
|
||||||
|
// If a tool approval is pending, intercept y/n before normal key handling.
|
||||||
|
if let Some(approval) = &state.pending_approval {
|
||||||
|
let tool_use_id = approval.tool_use_id.clone();
|
||||||
|
match key.code {
|
||||||
|
KeyCode::Char('y') | KeyCode::Char('Y') => {
|
||||||
|
state.pending_approval = None;
|
||||||
|
return Some(LoopControl::ToolApproval {
|
||||||
|
tool_use_id,
|
||||||
|
approved: true,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
KeyCode::Char('n') | KeyCode::Char('N') => {
|
||||||
|
state.pending_approval = None;
|
||||||
|
return Some(LoopControl::ToolApproval {
|
||||||
|
tool_use_id,
|
||||||
|
approved: false,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
_ => return None, // ignore other keys while approval pending
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Ctrl+C quits from any mode.
|
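// Ctrl+C quits from any mode.
// NOTE(review): while a tool approval is pending, the y/n intercept above returns
// before this check, so Ctrl+C is ignored until the prompt is answered.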
|
||||||
if key.modifiers.contains(KeyModifiers::CONTROL) && key.code == KeyCode::Char('c') {
|
if key.modifiers.contains(KeyModifiers::CONTROL) && key.code == KeyCode::Char('c') {
|
||||||
return Some(LoopControl::Quit);
|
return Some(LoopControl::Quit);
|
||||||
|
|
|
||||||
|
|
@ -72,6 +72,8 @@ pub struct AppState {
|
||||||
pub viewport_height: u16,
|
pub viewport_height: u16,
|
||||||
/// Transient error message shown in the status bar, cleared on next keypress.
|
/// Transient error message shown in the status bar, cleared on next keypress.
|
||||||
pub status_error: Option<String>,
|
pub status_error: Option<String>,
|
||||||
|
/// A tool approval request waiting for user input (y/n).
|
||||||
|
pub pending_approval: Option<events::PendingApproval>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AppState {
|
impl AppState {
|
||||||
|
|
@ -85,6 +87,7 @@ impl AppState {
|
||||||
pending_keys: Vec::new(),
|
pending_keys: Vec::new(),
|
||||||
viewport_height: 0,
|
viewport_height: 0,
|
||||||
status_error: None,
|
status_error: None,
|
||||||
|
pending_approval: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -185,6 +188,17 @@ pub async fn run(
|
||||||
Some(input::LoopControl::ClearHistory) => {
|
Some(input::LoopControl::ClearHistory) => {
|
||||||
let _ = action_tx.send(UserAction::ClearHistory).await;
|
let _ = action_tx.send(UserAction::ClearHistory).await;
|
||||||
}
|
}
|
||||||
|
Some(input::LoopControl::ToolApproval {
|
||||||
|
tool_use_id,
|
||||||
|
approved,
|
||||||
|
}) => {
|
||||||
|
let _ = action_tx
|
||||||
|
.send(UserAction::ToolApprovalResponse {
|
||||||
|
tool_use_id,
|
||||||
|
approved,
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
}
|
||||||
None => {}
|
None => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -105,11 +105,37 @@ pub(super) fn render(frame: &mut Frame, state: &AppState) {
|
||||||
let output = Paragraph::new(lines)
|
let output = Paragraph::new(lines)
|
||||||
.wrap(Wrap { trim: false })
|
.wrap(Wrap { trim: false })
|
||||||
.scroll((state.scroll, 0));
|
.scroll((state.scroll, 0));
|
||||||
frame.render_widget(output, chunks[0]);
|
let output_area = chunks[0];
|
||||||
|
frame.render_widget(output, output_area);
|
||||||
|
|
||||||
|
// --- Tool approval overlay ---
|
||||||
|
if let Some(ref approval) = state.pending_approval {
|
||||||
|
let overlay_w = (output_area.width / 2).max(60).min(output_area.width);
|
||||||
|
let overlay_h: u16 = 5;
|
||||||
|
let overlay_x = output_area.x + (output_area.width.saturating_sub(overlay_w)) / 2;
|
||||||
|
let overlay_y = output_area.y + output_area.height.saturating_sub(overlay_h) / 2;
|
||||||
|
let overlay_area = Rect {
|
||||||
|
x: overlay_x,
|
||||||
|
y: overlay_y,
|
||||||
|
width: overlay_w,
|
||||||
|
height: overlay_h.min(output_area.height),
|
||||||
|
};
|
||||||
|
frame.render_widget(Clear, overlay_area);
|
||||||
|
let text = format!(
|
||||||
|
"{}: {}\n\ny = approve, n = deny",
|
||||||
|
approval.tool_name, approval.input_summary
|
||||||
|
);
|
||||||
|
let overlay = Paragraph::new(text).block(
|
||||||
|
Block::bordered()
|
||||||
|
.border_style(Style::default().fg(Color::Yellow))
|
||||||
|
.title("Tool Approval"),
|
||||||
|
);
|
||||||
|
frame.render_widget(overlay, overlay_area);
|
||||||
|
}
|
||||||
|
|
||||||
// --- Command overlay (floating box centered on output pane) ---
|
// --- Command overlay (floating box centered on output pane) ---
|
||||||
if state.mode == Mode::Command {
|
if state.mode == Mode::Command {
|
||||||
let overlay_area = command_overlay_rect(chunks[0]);
|
let overlay_area = command_overlay_rect(output_area);
|
||||||
// Clear the area behind the overlay so it appears floating.
|
// Clear the area behind the overlay so it appears floating.
|
||||||
frame.render_widget(Clear, overlay_area);
|
frame.render_widget(Clear, overlay_area);
|
||||||
let overlay = Paragraph::new(format!(":{}", state.command_buffer)).block(
|
let overlay = Paragraph::new(format!(":{}", state.command_buffer)).block(
|
||||||
|
|
@ -146,7 +172,7 @@ pub(super) fn render(frame: &mut Frame, state: &AppState) {
|
||||||
}
|
}
|
||||||
Mode::Command => {
|
Mode::Command => {
|
||||||
// Cursor in the floating overlay
|
// Cursor in the floating overlay
|
||||||
let overlay = command_overlay_rect(chunks[0]);
|
let overlay = command_overlay_rect(output_area);
|
||||||
// border(1) + ":" (1) + buf len
|
// border(1) + ":" (1) + buf len
|
||||||
let cursor_x = overlay.x + 1 + 1 + state.command_buffer.len() as u16;
|
let cursor_x = overlay.x + 1 + 1 + state.command_buffer.len() as u16;
|
||||||
let cursor_y = overlay.y + 1; // inside the border
|
let cursor_y = overlay.y + 1; // inside the border
|
||||||
|
|
@ -386,4 +412,31 @@ mod tests {
|
||||||
"expected error in status bar"
|
"expected error in status bar"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// With an approval pending, the rendered frame contains the overlay title and the tool name.
#[test]
fn render_approval_overlay_visible() {
    let mut state = AppState::new();
    state.pending_approval = Some(super::super::events::PendingApproval {
        tool_use_id: "t1".to_string(),
        tool_name: "write_file".to_string(),
        input_summary: "path: foo.txt".to_string(),
    });

    let mut terminal = Terminal::new(TestBackend::new(80, 24)).unwrap();
    terminal.draw(|frame| render(frame, &state)).unwrap();

    // Flatten every cell of the test backend's buffer into one string for substring checks.
    let rendered = terminal.backend().buffer().clone();
    let mut all_text = String::new();
    for cell in rendered.content() {
        all_text.push_str(cell.symbol());
    }

    assert!(
        all_text.contains("Tool Approval"),
        "expected 'Tool Approval' overlay"
    );
    assert!(
        all_text.contains("write_file"),
        "expected tool name in overlay"
    );
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue