From 0c1c928498ddb6ded9b597f1c8247d7000471058 Mon Sep 17 00:00:00 2001 From: Drew Galbraith Date: Tue, 24 Feb 2026 18:31:26 -0800 Subject: [PATCH] Tool Use. --- Cargo.lock | 32 ++ Cargo.toml | 4 + PLAN.md | 76 +++- TODO.md | 2 + src/app/mod.rs | 12 +- src/core/history.rs | 26 +- src/core/orchestrator.rs | 695 +++++++++++++++++++++++++++++++++--- src/core/types.rs | 152 +++++++- src/main.rs | 1 + src/provider/claude.rs | 156 ++++++-- src/provider/mod.rs | 6 +- src/tools/list_directory.rs | 87 +++++ src/tools/mod.rs | 220 ++++++++++++ src/tools/read_file.rs | 92 +++++ src/tools/shell_exec.rs | 108 ++++++ src/tools/write_file.rs | 98 +++++ src/tui/events.rs | 86 ++++- src/tui/input.rs | 25 ++ src/tui/mod.rs | 14 + src/tui/render.rs | 59 ++- 20 files changed, 1822 insertions(+), 129 deletions(-) create mode 100644 src/tools/list_directory.rs create mode 100644 src/tools/mod.rs create mode 100644 src/tools/read_file.rs create mode 100644 src/tools/shell_exec.rs create mode 100644 src/tools/write_file.rs diff --git a/Cargo.lock b/Cargo.lock index d032920..7abfbd9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -23,6 +23,17 @@ version = "1.0.102" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" +[[package]] +name = "async-trait" +version = "0.1.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + [[package]] name = "atomic" version = "0.6.1" @@ -447,6 +458,12 @@ dependencies = [ "regex", ] +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + [[package]] name = "filedescriptor" version = "0.8.3" @@ -2024,12 +2041,14 @@ name = "skate" version = "0.1.0" dependencies = [ "anyhow", + 
"async-trait", "crossterm", "futures", "ratatui", "reqwest", "serde", "serde_json", + "tempfile", "thiserror 2.0.18", "tokio", "tracing", @@ -2166,6 +2185,19 @@ dependencies = [ "libc", ] +[[package]] +name = "tempfile" +version = "3.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82a72c767771b47409d2345987fda8628641887d5466101319899796367354a0" +dependencies = [ + "fastrand", + "getrandom 0.4.1", + "once_cell", + "rustix", + "windows-sys 0.61.2", +] + [[package]] name = "terminfo" version = "0.9.0" diff --git a/Cargo.toml b/Cargo.toml index 00c37e0..ff66a81 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,3 +15,7 @@ tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } reqwest = { version = "0.13", features = ["stream", "json"] } futures = "0.3" +async-trait = "0.1" + +[dev-dependencies] +tempfile = "3.26.0" diff --git a/PLAN.md b/PLAN.md index 03059df..703e7a0 100644 --- a/PLAN.md +++ b/PLAN.md @@ -1,11 +1,77 @@ # Implementation Plan ## Phase 3: Tool Execution -- `Tool` trait, `ToolRegistry`, core tools (`read_file`, `write_file`, `shell_exec`) -- Tool definitions in API requests, parse tool-use responses -- Approval gate: core -> TUI pending event -> user approve/deny -> result back -- Working directory confinement + path validation (no Landlock yet) -- **Done when:** Claude can read, modify files, and run commands with user approval + +### Step 3.1: Enrich the content model +- Replace `ConversationMessage { role, content: String }` with content-block model +- Define `ContentBlock` enum: `Text(String)`, `ToolUse { id, name, input: Value }`, `ToolResult { tool_use_id, content: String, is_error: bool }` +- Change `ConversationMessage.content` from `String` to `Vec` +- Add `ConversationMessage::text(role, s)` helper to keep existing call sites clean +- Update serialization, orchestrator, tests, TUI display +- **Files:** `src/core/types.rs`, `src/core/history.rs` +- **Done when:** `cargo test` passes 
with new model; all existing tests updated + +### Step 3.2: Send tool definitions in API requests +- Add `ToolDefinition { name, description, input_schema: Value }` (provider-agnostic) +- Extend `ModelProvider::stream` to accept `&[ToolDefinition]` +- Include `"tools"` array in Claude provider request body +- **Files:** `src/provider/mod.rs`, `src/provider/claude.rs` +- **Done when:** API responses contain `tool_use` content blocks in raw SSE stream + +### Step 3.3: Parse tool-use blocks from SSE stream +- Add `StreamEvent::ToolUseStart { id, name }`, `ToolUseInputDelta(String)`, `ToolUseDone` +- Handle `content_block_start` (type "tool_use"), `content_block_delta` (type "input_json_delta"), `content_block_stop` for tool blocks +- Track current block type state in SSE parser +- **Files:** `src/provider/claude.rs`, `src/core/types.rs` +- **Done when:** Unit test with recorded tool-use SSE fixture asserts correct StreamEvent sequence + +### Step 3.4: Orchestrator accumulates tool-use blocks +- Accumulate `ToolUseInputDelta` fragments into JSON buffer per tool-use id +- On `ToolUseDone`, parse JSON into `ContentBlock::ToolUse` +- After `StreamEvent::Done`, if assistant message contains ToolUse blocks, enter tool-execution phase +- **Files:** `src/core/orchestrator.rs` +- **Done when:** Unit test with mock provider emitting tool-use events produces correct ContentBlocks + +### Step 3.5: Tool trait, registry, and core tools +- `Tool` trait: `name()`, `description()`, `input_schema() -> Value`, `execute(input: Value, working_dir: &Path) -> Result` +- `ToolOutput { content: String, is_error: bool }` +- `ToolRegistry`: stores tools, provides `get(name)` and `definitions() -> Vec` +- Risk level: `AutoApprove` (reads), `RequiresApproval` (writes/shell) +- Implement: `read_file` (auto), `list_directory` (auto), `write_file` (approval), `shell_exec` (approval) +- Path validation: `canonicalize` + `starts_with` check, reject paths outside working dir (no Landlock yet) +- 
**Files:** New `src/tools/` module: `mod.rs`, `read_file.rs`, `write_file.rs`, `list_directory.rs`, `shell_exec.rs` +- **Done when:** Unit tests pass for each tool in temp dirs; path traversal rejected + +### Step 3.6: Approval gate (TUI <-> core) +- New `UIEvent::ToolApprovalRequest { tool_use_id, tool_name, input_summary }` +- New `UserAction::ToolApprovalResponse { tool_use_id, approved: bool }` +- Orchestrator: check risk level -> auto-approve or send approval request and await response +- Denied tools return `ToolResult { is_error: true }` with denial message +- TUI: render approval prompt overlay with y/n keybindings +- **Files:** `src/core/types.rs`, `src/core/orchestrator.rs`, `src/tui/events.rs`, `src/tui/input.rs`, `src/tui/render.rs` +- **Done when:** Integration test: mock provider + mock TUI channel verifies approval flow + +### Step 3.7: Tool results fed back to the model +- After executing tool calls: append assistant message (with ToolUse blocks) to history, append user message with ToolResult blocks, re-call provider +- Loop: model may respond with more tool calls or text +- Cap at max iterations (25) to prevent runaway +- **Files:** `src/core/orchestrator.rs` +- **Done when:** Integration test: mock provider returns tool-use then text; orchestrator makes two calls. Max-iteration cap tested. + +### Step 3.8: TUI display for tool activity +- New `UIEvent::ToolExecuting { tool_name, input_summary }`, `UIEvent::ToolResult { tool_name, output_summary, is_error }` +- Render tool calls as distinct visual blocks in conversation view +- Render tool results inline (truncated if long) +- **Files:** `src/tui/render.rs`, `src/tui/events.rs` +- **Done when:** Visual check with `cargo run`; TestBackend test for tool block rendering + +### Phase 3 verification (end-to-end) +1. `cargo test` -- all tests pass +2. `cargo clippy -- -D warnings` -- zero warnings +3. `cargo run -- --project-dir .` -- ask Claude to read a file, approve, see contents +4. 
Ask Claude to write a file -- approve, verify written +5. Ask Claude to run a shell command -- approve, verify output +6. Deny an approval -- Claude gets denial and responds gracefully ## Phase 4: Sandboxing - Landlock: read-only system, read-write project dir, network blocked diff --git a/TODO.md b/TODO.md index 71d90dc..945cd7f 100644 --- a/TODO.md +++ b/TODO.md @@ -1,5 +1,7 @@ # Cleanups +- Parallelize tool-use execution in `run_turn` -- requires refactoring `Orchestrator` to use `&self` + interior mutability (`Arc>` around `event_tx`, `action_rx`, `history`) so multiple futures can borrow self simultaneously via `futures::future::join_all`. + - Move keyboard/event reads in the TUI to a separate thread or async/io loop - Keep UI and orchestrator in sync (i.e. messages display out of order if you queue up many.) - `update_scroll` auto-follows in Insert mode, yanking viewport to bottom on mode switch. Only auto-follow when new content arrives (in `drain_ui_events`), not every frame. diff --git a/src/app/mod.rs b/src/app/mod.rs index 69660e4..d2c7ff0 100644 --- a/src/app/mod.rs +++ b/src/app/mod.rs @@ -27,6 +27,7 @@ use tokio::sync::mpsc; use crate::core::orchestrator::Orchestrator; use crate::core::types::{UIEvent, UserAction}; use crate::provider::ClaudeProvider; +use crate::tools::ToolRegistry; /// Model ID sent on every request. 
/// @@ -68,8 +69,15 @@ pub async fn run(project_dir: &Path) -> anyhow::Result<()> { let (action_tx, action_rx) = mpsc::channel::(CHANNEL_CAP); let (event_tx, event_rx) = mpsc::channel::(CHANNEL_CAP); - // -- Orchestrator (background task) ------------------------------------------- - let orch = Orchestrator::new(provider, action_rx, event_tx); + // -- Tools & Orchestrator (background task) ------------------------------------ + let tool_registry = ToolRegistry::default_tools(); + let orch = Orchestrator::new( + provider, + tool_registry, + project_dir.to_path_buf(), + action_rx, + event_tx, + ); tokio::spawn(orch.run()); // -- TUI (foreground task) ---------------------------------------------------- diff --git a/src/core/history.rs b/src/core/history.rs index fb2598a..9b8099c 100644 --- a/src/core/history.rs +++ b/src/core/history.rs @@ -61,30 +61,21 @@ mod tests { #[test] fn push_and_read_roundtrip() { let mut history = ConversationHistory::new(); - history.push(ConversationMessage { - role: Role::User, - content: "hello".to_string(), - }); - history.push(ConversationMessage { - role: Role::Assistant, - content: "hi there".to_string(), - }); + history.push(ConversationMessage::text(Role::User, "hello")); + history.push(ConversationMessage::text(Role::Assistant, "hi there")); let msgs = history.messages(); assert_eq!(msgs.len(), 2); assert_eq!(msgs[0].role, Role::User); - assert_eq!(msgs[0].content, "hello"); + assert_eq!(msgs[0].text_content(), "hello"); assert_eq!(msgs[1].role, Role::Assistant); - assert_eq!(msgs[1].content, "hi there"); + assert_eq!(msgs[1].text_content(), "hi there"); } #[test] fn clear_empties_history() { let mut history = ConversationHistory::new(); - history.push(ConversationMessage { - role: Role::User, - content: "hello".to_string(), - }); + history.push(ConversationMessage::text(Role::User, "hello")); history.clear(); assert!(history.messages().is_empty()); } @@ -93,13 +84,10 @@ mod tests { fn messages_preserves_insertion_order() { let 
mut history = ConversationHistory::new(); for i in 0u32..5 { - history.push(ConversationMessage { - role: Role::User, - content: format!("msg {i}"), - }); + history.push(ConversationMessage::text(Role::User, format!("msg {i}"))); } for (i, msg) in history.messages().iter().enumerate() { - assert_eq!(msg.content, format!("msg {i}")); + assert_eq!(msg.text_content(), format!("msg {i}")); } } } diff --git a/src/core/orchestrator.rs b/src/core/orchestrator.rs index de33be7..472f4bd 100644 --- a/src/core/orchestrator.rs +++ b/src/core/orchestrator.rs @@ -1,10 +1,53 @@ use futures::StreamExt; use tokio::sync::mpsc; -use tracing::debug; +use tracing::{debug, warn}; use crate::core::history::ConversationHistory; -use crate::core::types::{ConversationMessage, Role, StreamEvent, UIEvent, UserAction}; +use crate::core::types::{ + ContentBlock, ConversationMessage, Role, StreamEvent, ToolDefinition, UIEvent, UserAction, +}; use crate::provider::ModelProvider; +use crate::tools::{RiskLevel, ToolOutput, ToolRegistry}; + +/// Accumulates data for a single tool-use block while it is being streamed. +/// +/// Created on `ToolUseStart`, populated by `ToolUseInputDelta` fragments, and +/// consumed on `ToolUseDone` via `TryFrom for ContentBlock`. +struct ActiveToolUse { + id: String, + name: String, + json_buf: String, +} + +impl ActiveToolUse { + /// Start accumulating a new tool-use block identified by `id` and `name`. + fn new(id: String, name: String) -> Self { + Self { + id, + name, + json_buf: String::new(), + } + } + + /// Append a JSON input fragment received from a `ToolUseInputDelta` event. + fn append(&mut self, chunk: &str) { + self.json_buf.push_str(chunk); + } +} + +impl TryFrom for ContentBlock { + type Error = serde_json::Error; + + /// Parse the accumulated JSON buffer and produce a `ContentBlock::ToolUse`. 
+ fn try_from(t: ActiveToolUse) -> Result { + let input = serde_json::from_str(&t.json_buf)?; + Ok(ContentBlock::ToolUse { + id: t.id, + name: t.name, + input, + }) + } +} /// Drives the conversation loop between the TUI frontend and the model provider. /// @@ -38,28 +81,321 @@ use crate::provider::ModelProvider; /// OutputTokens -> log at debug level /// 3. Quit -> return /// ``` +/// The result of consuming one provider stream (one assistant turn). +enum StreamResult { + /// The stream completed successfully with these content blocks. + Done(Vec), + /// The stream produced an error. + Error(String), +} + +/// Maximum number of tool-use loop iterations per user message. +const MAX_TOOL_ITERATIONS: usize = 25; + +/// Truncate a string to `max_len` Unicode scalar values, appending "..." if +/// truncated. Uses `char_indices` so that multibyte characters are never split +/// at a byte boundary (which would panic on a bare slice). +fn truncate(s: &str, max_len: usize) -> String { + match s.char_indices().nth(max_len) { + Some((i, _)) => format!("{}...", &s[..i]), + None => s.to_string(), + } +} + pub struct Orchestrator

{ history: ConversationHistory, provider: P, + tool_registry: ToolRegistry, + working_dir: std::path::PathBuf, action_rx: mpsc::Receiver, event_tx: mpsc::Sender, + /// Messages typed by the user while an approval prompt is open. They are + /// queued here and replayed as new turns once the current turn completes. + queued_messages: Vec, } impl Orchestrator

{ /// Construct an orchestrator using the given provider and channel endpoints. pub fn new( provider: P, + tool_registry: ToolRegistry, + working_dir: std::path::PathBuf, action_rx: mpsc::Receiver, event_tx: mpsc::Sender, ) -> Self { Self { history: ConversationHistory::new(), provider, + tool_registry, + working_dir, action_rx, event_tx, + queued_messages: Vec::new(), } } + /// Consume the provider's stream for one turn, returning the content blocks + /// produced by the assistant (text and/or tool-use) or an error. + /// + /// Tool-use input JSON is accumulated from `ToolUseInputDelta` fragments and + /// parsed into `ContentBlock::ToolUse` on `ToolUseDone`. + async fn consume_stream( + &self, + messages: &[ConversationMessage], + tools: &[ToolDefinition], + ) -> StreamResult { + let mut blocks: Vec = Vec::new(); + let mut text_buf = String::new(); + let mut active_tool: Option = None; + + let mut stream = Box::pin(self.provider.stream(messages, tools)); + + while let Some(event) = stream.next().await { + match event { + StreamEvent::TextDelta(chunk) => { + text_buf.push_str(&chunk); + let _ = self.event_tx.send(UIEvent::StreamDelta(chunk)).await; + } + StreamEvent::ToolUseStart { id, name } => { + // Flush any accumulated text before starting tool block. 
+ if !text_buf.is_empty() { + blocks.push(ContentBlock::Text { + text: std::mem::take(&mut text_buf), + }); + } + active_tool = Some(ActiveToolUse::new(id, name)); + } + StreamEvent::ToolUseInputDelta(chunk) => { + if let Some(t) = &mut active_tool { + t.append(&chunk); + } else { + warn!( + "received ToolUseInputDelta outside of an active tool-use block -- ignoring" + ); + } + } + StreamEvent::ToolUseDone => { + if let Some(t) = active_tool.take() { + match ContentBlock::try_from(t) { + Ok(block) => blocks.push(block), + Err(e) => { + return StreamResult::Error(format!( + "failed to parse tool input JSON: {e}" + )); + } + } + } else { + warn!( + "received ToolUseDone outside of an active tool-use block -- ignoring" + ); + } + } + StreamEvent::Done => { + // Flush trailing text. + if !text_buf.is_empty() { + blocks.push(ContentBlock::Text { text: text_buf }); + } + return StreamResult::Done(blocks); + } + StreamEvent::Error(msg) => { + return StreamResult::Error(msg); + } + StreamEvent::InputTokens(n) => { + debug!(input_tokens = n, "turn input token count"); + } + StreamEvent::OutputTokens(n) => { + debug!(output_tokens = n, "turn output token count"); + } + } + } + + // Stream ended without Done -- treat as error. + StreamResult::Error("stream ended unexpectedly".to_string()) + } + + /// Execute one complete turn: stream from provider, execute tools if needed, + /// loop back until the model produces a text-only response or we hit the + /// iteration limit. + async fn run_turn(&mut self) { + let tool_defs = self.tool_registry.definitions(); + + for _ in 0..MAX_TOOL_ITERATIONS { + let messages = self.history.messages().to_vec(); + let result = self.consume_stream(&messages, &tool_defs).await; + + match result { + StreamResult::Error(msg) => { + let _ = self.event_tx.send(UIEvent::Error(msg)).await; + return; + } + StreamResult::Done(blocks) => { + let has_tool_use = blocks + .iter() + .any(|b| matches!(b, ContentBlock::ToolUse { .. 
})); + + // Append the assistant message to history. + self.history.push(ConversationMessage { + role: Role::Assistant, + content: blocks.clone(), + }); + + if !has_tool_use { + let _ = self.event_tx.send(UIEvent::TurnComplete).await; + return; + } + + // TODO: execute tool-use blocks in parallel -- requires + // refactoring Orchestrator to use &self + interior + // mutability (Arc> around event_tx, action_rx, + // history) so that multiple futures can borrow self + // simultaneously via futures::future::join_all. + // Execute each tool-use block and collect results. + let mut tool_results: Vec = Vec::new(); + + for block in &blocks { + if let ContentBlock::ToolUse { id, name, input } = block { + let result = self.execute_tool_with_approval(id, name, input).await; + tool_results.push(ContentBlock::ToolResult { + tool_use_id: id.clone(), + content: result.content, + is_error: result.is_error, + }); + } + } + + // Append tool results as a user message and loop. + self.history.push(ConversationMessage { + role: Role::User, + content: tool_results, + }); + } + } + } + warn!("tool-use loop reached max iterations ({MAX_TOOL_ITERATIONS})"); + let _ = self + .event_tx + .send(UIEvent::Error( + "tool-use loop reached maximum iterations".to_string(), + )) + .await; + } + + /// Execute a single tool, handling approval if needed. + /// + /// For auto-approve tools, executes immediately. For tools requiring + /// approval, sends a request to the TUI and waits for a response. + async fn execute_tool_with_approval( + &mut self, + tool_use_id: &str, + tool_name: &str, + input: &serde_json::Value, + ) -> ToolOutput { + // Extract tool info upfront to avoid holding a borrow on self across + // the mutable wait_for_approval call. 
+ let risk = match self.tool_registry.get(tool_name) { + Some(t) => t.risk_level(), + None => { + return ToolOutput { + content: format!("unknown tool: {tool_name}"), + is_error: true, + }; + } + }; + + let input_summary = serde_json::to_string(input).unwrap_or_default(); + + // Check approval. + let approved = match risk { + RiskLevel::AutoApprove => { + let _ = self + .event_tx + .send(UIEvent::ToolExecuting { + tool_name: tool_name.to_string(), + input_summary: input_summary.clone(), + }) + .await; + true + } + RiskLevel::RequiresApproval => { + let _ = self + .event_tx + .send(UIEvent::ToolApprovalRequest { + tool_use_id: tool_use_id.to_string(), + tool_name: tool_name.to_string(), + input_summary: input_summary.clone(), + }) + .await; + + // Wait for approval response from TUI. + self.wait_for_approval(tool_use_id).await + } + }; + + if !approved { + return ToolOutput { + content: "tool execution denied by user".to_string(), + is_error: true, + }; + } + + // Re-fetch tool for execution (borrow was released above). + let tool = self.tool_registry.get(tool_name).unwrap(); + match tool.execute(input, &self.working_dir).await { + Ok(output) => { + let _ = self + .event_tx + .send(UIEvent::ToolResult { + tool_name: tool_name.to_string(), + output_summary: truncate(&output.content, 200), + is_error: output.is_error, + }) + .await; + output + } + Err(e) => { + let msg = e.to_string(); + let _ = self + .event_tx + .send(UIEvent::ToolResult { + tool_name: tool_name.to_string(), + output_summary: msg.clone(), + is_error: true, + }) + .await; + ToolOutput { + content: msg, + is_error: true, + } + } + } + } + + /// Wait for a `ToolApprovalResponse` matching `tool_use_id` from the TUI. + /// + /// Any `SendMessage` actions that arrive during the wait are pushed onto + /// `self.queued_messages` rather than discarded. They are replayed as new + /// turns once the current tool-use loop completes (see `run()`). 
+ /// + /// Returns `true` if approved, `false` if denied or the channel closes. + async fn wait_for_approval(&mut self, tool_use_id: &str) -> bool { + while let Some(action) = self.action_rx.recv().await { + match action { + UserAction::ToolApprovalResponse { + tool_use_id: id, + approved, + } if id == tool_use_id => { + return approved; + } + UserAction::Quit => return false, + UserAction::SendMessage(text) => { + self.queued_messages.push(text); + } + _ => {} // discard stale approvals / ClearHistory during wait + } + } + false + } + /// Run the orchestrator until the user quits or the `action_rx` channel closes. pub async fn run(mut self) { while let Some(action) = self.action_rx.recv().await { @@ -69,62 +405,24 @@ impl Orchestrator

{ self.history.clear(); } + // Approval responses are handled inline during tool execution, + // not in the main action loop. If one arrives here it's stale. + UserAction::ToolApprovalResponse { .. } => {} + UserAction::SendMessage(text) => { - // Push the user message before snapshotting, so providers - // see the full conversation including the new message. - self.history.push(ConversationMessage { - role: Role::User, - content: text, - }); + self.history + .push(ConversationMessage::text(Role::User, text)); + self.run_turn().await; - // Snapshot history into an owned Vec so the stream does not - // borrow from `self.history` -- this lets us mutably update - // `self.history` once the stream loop finishes. - let messages: Vec = self.history.messages().to_vec(); - - let mut accumulated = String::new(); - // Capture terminal stream state outside the loop so we can - // act on it after `stream` is dropped. - let mut turn_done = false; - let mut turn_error: Option = None; - - { - let mut stream = Box::pin(self.provider.stream(&messages)); - - while let Some(event) = stream.next().await { - match event { - StreamEvent::TextDelta(chunk) => { - accumulated.push_str(&chunk); - let _ = self.event_tx.send(UIEvent::StreamDelta(chunk)).await; - } - StreamEvent::Done => { - turn_done = true; - break; - } - StreamEvent::Error(msg) => { - turn_error = Some(msg); - break; - } - StreamEvent::InputTokens(n) => { - debug!(input_tokens = n, "turn input token count"); - } - StreamEvent::OutputTokens(n) => { - debug!(output_tokens = n, "turn output token count"); - } - } + // Drain any messages queued while an approval prompt was + // open. Each queued message is a full turn in sequence. 
+ while !self.queued_messages.is_empty() { + let queued = std::mem::take(&mut self.queued_messages); + for msg in queued { + self.history + .push(ConversationMessage::text(Role::User, msg)); + self.run_turn().await; } - // `stream` is dropped here, releasing the borrow on - // `self.provider` and `messages`. - } - - if turn_done { - self.history.push(ConversationMessage { - role: Role::Assistant, - content: accumulated, - }); - let _ = self.event_tx.send(UIEvent::TurnComplete).await; - } else if let Some(msg) = turn_error { - let _ = self.event_tx.send(UIEvent::Error(msg)).await; } } } @@ -135,6 +433,8 @@ impl Orchestrator

{ #[cfg(test)] mod tests { use super::*; + use crate::core::types::ToolDefinition; + use crate::tools::ToolRegistry; use futures::Stream; use tokio::sync::mpsc; @@ -155,11 +455,27 @@ mod tests { fn stream<'a>( &'a self, _messages: &'a [ConversationMessage], + _tools: &'a [ToolDefinition], ) -> impl Stream + Send + 'a { futures::stream::iter(self.events.clone()) } } + /// Create an Orchestrator with no tools for testing text-only flows. + fn test_orchestrator( + provider: P, + action_rx: mpsc::Receiver, + event_tx: mpsc::Sender, + ) -> Orchestrator

{ + Orchestrator::new( + provider, + ToolRegistry::empty(), + std::path::PathBuf::from("/tmp"), + action_rx, + event_tx, + ) + } + /// Collect all UIEvents that arrive within one orchestrator turn, stopping /// when the channel is drained after a `TurnComplete` or `Error`. async fn collect_events(rx: &mut mpsc::Receiver) -> Vec { @@ -195,7 +511,7 @@ mod tests { let (action_tx, action_rx) = mpsc::channel::(8); let (event_tx, mut event_rx) = mpsc::channel::(16); - let orch = Orchestrator::new(provider, action_rx, event_tx); + let orch = test_orchestrator(provider, action_rx, event_tx); let handle = tokio::spawn(orch.run()); action_tx @@ -233,7 +549,7 @@ mod tests { let (action_tx, action_rx) = mpsc::channel::(8); let (event_tx, mut event_rx) = mpsc::channel::(16); - let orch = Orchestrator::new(provider, action_rx, event_tx); + let orch = test_orchestrator(provider, action_rx, event_tx); let handle = tokio::spawn(orch.run()); action_tx @@ -264,6 +580,7 @@ mod tests { fn stream<'a>( &'a self, _messages: &'a [ConversationMessage], + _tools: &'a [ToolDefinition], ) -> impl Stream + Send + 'a { panic!("stream() must not be called after Quit"); #[allow(unreachable_code)] @@ -274,7 +591,7 @@ mod tests { let (action_tx, action_rx) = mpsc::channel::(8); let (event_tx, _event_rx) = mpsc::channel::(8); - let orch = Orchestrator::new(NeverCalledProvider, action_rx, event_tx); + let orch = test_orchestrator(NeverCalledProvider, action_rx, event_tx); let handle = tokio::spawn(orch.run()); action_tx.send(UserAction::Quit).await.unwrap(); @@ -311,6 +628,7 @@ mod tests { fn stream<'a>( &'a self, _messages: &'a [ConversationMessage], + _tools: &'a [ToolDefinition], ) -> impl Stream + Send + 'a { let events = self.turns.lock().unwrap().pop_front().unwrap_or_default(); futures::stream::iter(events) @@ -326,7 +644,7 @@ mod tests { let (action_tx, action_rx) = mpsc::channel::(8); let (event_tx, mut event_rx) = mpsc::channel::(32); - let orch = Orchestrator::new(provider, action_rx, 
event_tx); + let orch = test_orchestrator(provider, action_rx, event_tx); let handle = tokio::spawn(orch.run()); // First turn. @@ -350,4 +668,263 @@ mod tests { action_tx.send(UserAction::Quit).await.unwrap(); handle.await.unwrap(); } + + // -- tool-use accumulation ---------------------------------------------------- + + /// When the provider emits tool-use events, the orchestrator executes the + /// tool (auto-approve since read_file is AutoApprove), feeds the result back, + /// and the provider's second call returns text. + #[tokio::test] + async fn tool_use_loop_executes_and_feeds_back() { + use std::collections::VecDeque; + use std::sync::{Arc, Mutex}; + + struct MultiCallMock { + turns: Arc>>>, + } + impl ModelProvider for MultiCallMock { + fn stream<'a>( + &'a self, + _messages: &'a [ConversationMessage], + _tools: &'a [ToolDefinition], + ) -> impl Stream + Send + 'a { + let events = self.turns.lock().unwrap().pop_front().unwrap_or_default(); + futures::stream::iter(events) + } + } + + let turns = Arc::new(Mutex::new(VecDeque::from([ + // First call: tool use + vec![ + StreamEvent::TextDelta("Let me read that.".to_string()), + StreamEvent::ToolUseDone, // spurious stop for text block + StreamEvent::ToolUseStart { + id: "toolu_01".to_string(), + name: "read_file".to_string(), + }, + StreamEvent::ToolUseInputDelta("{\"path\":\"Cargo.toml\"}".to_string()), + StreamEvent::ToolUseDone, + StreamEvent::Done, + ], + // Second call: text response after tool result + vec![ + StreamEvent::TextDelta("Here's the file.".to_string()), + StreamEvent::Done, + ], + ]))); + + let (action_tx, action_rx) = mpsc::channel::(8); + let (event_tx, mut event_rx) = mpsc::channel::(32); + + // Use a real ToolRegistry so read_file works. 
+ let dir = tempfile::TempDir::new().unwrap(); + std::fs::write(dir.path().join("Cargo.toml"), "[package]\nname = \"test\"").unwrap(); + let orch = Orchestrator::new( + MultiCallMock { turns }, + ToolRegistry::default_tools(), + dir.path().to_path_buf(), + action_rx, + event_tx, + ); + let handle = tokio::spawn(orch.run()); + + action_tx + .send(UserAction::SendMessage("read Cargo.toml".to_string())) + .await + .unwrap(); + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + + let events = collect_events(&mut event_rx).await; + // Should see text deltas from both calls, tool events, and TurnComplete + assert!( + events + .iter() + .any(|e| matches!(e, UIEvent::StreamDelta(s) if s == "Let me read that.")), + "expected first text delta" + ); + assert!( + events + .iter() + .any(|e| matches!(e, UIEvent::StreamDelta(s) if s == "Here's the file.")), + "expected second text delta" + ); + assert!( + matches!(events.last(), Some(UIEvent::TurnComplete)), + "expected TurnComplete, got: {events:?}" + ); + + action_tx.send(UserAction::Quit).await.unwrap(); + handle.await.unwrap(); + } + + // -- truncate ----------------------------------------------------------------- + + /// Truncating at a boundary inside a multibyte character must not panic. + /// "hello " is 6 bytes; the emoji is 4 bytes. Requesting max_len=8 would + /// previously slice at byte 8, splitting the emoji. The fixed version uses + /// char_indices so the cut falls on a char boundary. + #[test] + fn truncate_multibyte_does_not_panic() { + let s = "hello \u{1F30D} world"; // globe emoji is 4 bytes + // Should not panic and the result should end with "..." + let result = truncate(s, 8); + assert!( + result.ends_with("..."), + "expected truncation suffix: {result}" + ); + // The string before "..." must be valid UTF-8 (guaranteed by the fix). 
+ let prefix = result.trim_end_matches("..."); + assert!(std::str::from_utf8(prefix.as_bytes()).is_ok()); + } + + // -- JSON parse error --------------------------------------------------------- + + /// When the provider emits malformed tool-input JSON, the orchestrator must + /// surface a descriptive UIEvent::Error rather than silently using Null. + #[tokio::test] + async fn malformed_tool_json_surfaces_error() { + let provider = MockProvider::new(vec![ + StreamEvent::ToolUseStart { + id: "toolu_bad".to_string(), + name: "read_file".to_string(), + }, + StreamEvent::ToolUseInputDelta("not valid json{{".to_string()), + StreamEvent::ToolUseDone, + StreamEvent::Done, + ]); + + let (action_tx, action_rx) = mpsc::channel::(8); + let (event_tx, mut event_rx) = mpsc::channel::(16); + + let orch = test_orchestrator(provider, action_rx, event_tx); + let handle = tokio::spawn(orch.run()); + + action_tx + .send(UserAction::SendMessage("go".to_string())) + .await + .unwrap(); + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + + let events = collect_events(&mut event_rx).await; + assert!( + events + .iter() + .any(|e| matches!(e, UIEvent::Error(msg) if msg.contains("tool input JSON"))), + "expected JSON parse error event, got: {events:?}" + ); + + action_tx.send(UserAction::Quit).await.unwrap(); + handle.await.unwrap(); + } + + // -- queued messages ---------------------------------------------------------- + + /// A SendMessage sent while an approval prompt is open must be processed + /// after the current turn completes, not silently dropped. + /// + /// This test uses a RequiresApproval tool (write_file) so that the + /// orchestrator blocks in wait_for_approval. While blocked, we send a second + /// message. After approving (denied here for simplicity), the queued message + /// should still be processed. 
+ #[tokio::test] + async fn send_message_during_approval_is_queued_and_processed() { + use std::collections::VecDeque; + use std::sync::{Arc, Mutex}; + + struct MultiCallMock { + turns: Arc>>>, + } + impl ModelProvider for MultiCallMock { + fn stream<'a>( + &'a self, + _messages: &'a [ConversationMessage], + _tools: &'a [ToolDefinition], + ) -> impl Stream + Send + 'a { + let events = self.turns.lock().unwrap().pop_front().unwrap_or_default(); + futures::stream::iter(events) + } + } + + let turns = Arc::new(Mutex::new(VecDeque::from([ + // Turn 1: requests write_file (RequiresApproval) + vec![ + StreamEvent::ToolUseStart { + id: "toolu_w".to_string(), + name: "write_file".to_string(), + }, + StreamEvent::ToolUseInputDelta( + "{\"path\":\"x.txt\",\"content\":\"hi\"}".to_string(), + ), + StreamEvent::ToolUseDone, + StreamEvent::Done, + ], + // Turn 1 second iteration after tool result (denied) + vec![ + StreamEvent::TextDelta("ok denied".to_string()), + StreamEvent::Done, + ], + // Turn 2: the queued message + vec![ + StreamEvent::TextDelta("queued reply".to_string()), + StreamEvent::Done, + ], + ]))); + + let (action_tx, action_rx) = mpsc::channel::(16); + let (event_tx, mut event_rx) = mpsc::channel::(64); + + let dir = tempfile::TempDir::new().unwrap(); + let orch = Orchestrator::new( + MultiCallMock { turns }, + ToolRegistry::default_tools(), + dir.path().to_path_buf(), + action_rx, + event_tx, + ); + let handle = tokio::spawn(orch.run()); + + // Start turn 1 -- orchestrator will block on approval. + action_tx + .send(UserAction::SendMessage("turn1".to_string())) + .await + .unwrap(); + + // Let the orchestrator reach wait_for_approval. + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + + // Send a message while blocked -- should be queued. + action_tx + .send(UserAction::SendMessage("queued".to_string())) + .await + .unwrap(); + + // Deny the tool -- unblocks wait_for_approval. 
+ action_tx + .send(UserAction::ToolApprovalResponse { + tool_use_id: "toolu_w".to_string(), + approved: false, + }) + .await + .unwrap(); + + // Wait for both turns to complete. + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + + // Collect everything. + let mut all_events = Vec::new(); + while let Ok(ev) = event_rx.try_recv() { + all_events.push(ev); + } + + // The queued message must have produced "queued reply". + assert!( + all_events + .iter() + .any(|e| matches!(e, UIEvent::StreamDelta(s) if s == "queued reply")), + "queued message was not processed; events: {all_events:?}" + ); + + action_tx.send(UserAction::Quit).await.unwrap(); + handle.await.unwrap(); + } } diff --git a/src/core/types.rs b/src/core/types.rs index b0afd0f..65e3f43 100644 --- a/src/core/types.rs +++ b/src/core/types.rs @@ -1,8 +1,16 @@ +use serde::{Deserialize, Serialize}; + /// A streaming event emitted by the model provider. #[derive(Debug, Clone)] pub enum StreamEvent { /// A text chunk from the assistant's response. TextDelta(String), + /// A new tool-use content block has started. + ToolUseStart { id: String, name: String }, + /// A chunk of the tool-use input JSON (streamed incrementally). + ToolUseInputDelta(String), + /// The current tool-use content block has ended. + ToolUseDone, /// Number of input tokens used in this request. InputTokens(u32), /// Number of output tokens generated so far. @@ -18,6 +26,8 @@ pub enum StreamEvent { pub enum UserAction { /// The user has submitted a message. SendMessage(String), + /// The user has responded to a tool approval request. + ToolApprovalResponse { tool_use_id: String, approved: bool }, /// The user has requested to quit. Quit, /// The user has requested to clear conversation history. @@ -29,6 +39,23 @@ pub enum UserAction { pub enum UIEvent { /// A text chunk to append to the current assistant message. StreamDelta(String), + /// A tool requires user approval before execution. 
+ ToolApprovalRequest { + tool_use_id: String, + tool_name: String, + input_summary: String, + }, + /// A tool is being executed (informational, after approval or auto-approve). + ToolExecuting { + tool_name: String, + input_summary: String, + }, + /// A tool has finished executing. + ToolResult { + tool_name: String, + output_summary: String, + is_error: bool, + }, /// The current assistant turn has completed. TurnComplete, /// An error to display to the user. @@ -36,7 +63,7 @@ pub enum UIEvent { } /// The role of a participant in a conversation. -#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "lowercase")] pub enum Role { /// A message from the human user. @@ -45,11 +72,128 @@ pub enum Role { Assistant, } +/// A tool definition sent to the model so it knows which tools are available. +/// +/// This is provider-agnostic -- the provider serializes it into the +/// format required by its API (e.g. the `tools` array for Anthropic). +#[derive(Debug, Clone, Serialize)] +pub struct ToolDefinition { + /// The tool name the model will use in `tool_use` blocks. + pub name: String, + /// Human-readable description of what the tool does. + pub description: String, + /// JSON Schema describing the tool's input parameters. + pub input_schema: serde_json::Value, +} + +/// A typed content block within a conversation message. +/// +/// The Anthropic Messages API represents message content as an array of typed +/// blocks. A single assistant message can contain interleaved text and tool-use +/// blocks; a user message following tool execution contains tool-result blocks. +/// +/// See the [Messages API content blocks reference][content-blocks]. +/// +/// [content-blocks]: https://docs.anthropic.com/en/api/messages +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ContentBlock { + /// Plain text content. 
+ Text { text: String }, + /// A tool invocation requested by the assistant. + ToolUse { + id: String, + name: String, + input: serde_json::Value, + }, + /// The result of executing a tool, sent back in a user-role message. + ToolResult { + tool_use_id: String, + content: String, + #[serde(default)] + is_error: bool, + }, +} + /// A single message in the conversation history. -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +/// +/// Content is stored as a `Vec<ContentBlock>` to support mixed text and +/// tool-use blocks in a single message. For simple text-only messages, use +/// the [`ConversationMessage::text`] constructor. +/// +/// # Serialization +/// +/// Serializes `content` as a plain string when the message contains exactly +/// one `Text` block (the common case), and as an array of typed blocks +/// otherwise. This matches what the Anthropic API expects for both simple +/// and tool-use messages. +#[derive(Debug, Clone)] pub struct ConversationMessage { /// The role of the message author. pub role: Role, - /// The text content of the message. - pub content: String, + /// The content blocks of this message. + pub content: Vec, +} + +impl ConversationMessage { + /// Create a simple text-only message. + pub fn text(role: Role, s: impl Into) -> Self { + Self { + role, + content: vec![ContentBlock::Text { text: s.into() }], + } + } + + /// Extract the concatenated text content, ignoring non-text blocks. + #[allow(dead_code)] + pub fn text_content(&self) -> String { + let mut out = String::new(); + for block in &self.content { + if let ContentBlock::Text { text } = block { + out.push_str(text); + } + } + out + } +} + +impl Serialize for ConversationMessage { + fn serialize(&self, serializer: S) -> Result { + use serde::ser::SerializeMap; + let mut map = serializer.serialize_map(Some(2))?; + map.serialize_entry("role", &self.role)?; + // Single text block -> serialize as plain string (common case). + // Otherwise -> serialize as array of content blocks.
+ match self.content.as_slice() { + [ContentBlock::Text { text }] => map.serialize_entry("content", text)?, + blocks => map.serialize_entry("content", blocks)?, + } + map.end() + } +} + +impl<'de> Deserialize<'de> for ConversationMessage { + fn deserialize>(deserializer: D) -> Result { + #[derive(Deserialize)] + struct Raw { + role: Role, + content: serde_json::Value, + } + let raw = Raw::deserialize(deserializer)?; + let content = match raw.content { + serde_json::Value::String(s) => vec![ContentBlock::Text { text: s }], + serde_json::Value::Array(_) => { + serde_json::from_value(raw.content).map_err(serde::de::Error::custom)? + } + other => { + return Err(serde::de::Error::custom(format!( + "expected string or array for content, got {other}" + ))); + } + }; + Ok(ConversationMessage { + role: raw.role, + content, + }) + } } diff --git a/src/main.rs b/src/main.rs index 6e88668..237a581 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,7 @@ mod app; mod core; mod provider; +mod tools; mod tui; use std::path::PathBuf; diff --git a/src/provider/claude.rs b/src/provider/claude.rs index 46810dd..72ef1be 100644 --- a/src/provider/claude.rs +++ b/src/provider/claude.rs @@ -2,7 +2,7 @@ use futures::{SinkExt, Stream, StreamExt}; use reqwest::Client; use serde::Deserialize; -use crate::core::types::{ConversationMessage, StreamEvent}; +use crate::core::types::{ConversationMessage, StreamEvent, ToolDefinition}; use super::ModelProvider; @@ -64,15 +64,17 @@ impl ModelProvider for ClaudeProvider { fn stream<'a>( &'a self, messages: &'a [ConversationMessage], + tools: &'a [ToolDefinition], ) -> impl Stream + Send + 'a { let (mut tx, rx) = futures::channel::mpsc::channel(32); let client = self.client.clone(); let api_key = self.api_key.clone(); let model = self.model.clone(); let messages = messages.to_vec(); + let tools = tools.to_vec(); tokio::spawn(async move { - run_stream(client, api_key, model, messages, &mut tx).await; + run_stream(client, api_key, model, messages, 
tools, &mut tx).await; }); rx @@ -126,14 +128,18 @@ async fn run_stream( api_key: String, model: String, messages: Vec, + tools: Vec, tx: &mut futures::channel::mpsc::Sender, ) { - let body = serde_json::json!({ + let mut body = serde_json::json!({ "model": model, "max_tokens": 8192, "stream": true, "messages": messages, }); + if !tools.is_empty() { + body["tools"] = serde_json::to_value(&tools).unwrap_or_default(); + } let response = match client .post("https://api.anthropic.com/v1/messages") @@ -223,6 +229,8 @@ fn find_double_newline(buf: &[u8]) -> Option { struct SseEvent { #[serde(rename = "type")] event_type: String, + /// Present on `content_block_start` events; describes the new block. + content_block: Option, /// Present on `content_block_delta` events. delta: Option, /// Present on `message_start` events; carries initial token usage. @@ -231,16 +239,29 @@ struct SseEvent { usage: Option, } +/// The `content_block` object inside a `content_block_start` event. +#[derive(Deserialize, Debug)] +struct SseContentBlock { + #[serde(rename = "type")] + block_type: String, + /// Tool-use block ID; present when `block_type == "tool_use"`. + id: Option, + /// Tool name; present when `block_type == "tool_use"`. + name: Option, +} + /// The `delta` object inside a `content_block_delta` event. /// -/// `type` is `"text_delta"` for plain text chunks; other delta types -/// (e.g. `"input_json_delta"` for tool-use blocks) are not yet handled. +/// `type` is `"text_delta"` for plain text chunks, or `"input_json_delta"` +/// for streaming tool-use input JSON. #[derive(Deserialize, Debug)] struct SseDelta { #[serde(rename = "type")] delta_type: Option, /// The text chunk; present when `delta_type == "text_delta"`. text: Option, + /// Partial JSON for tool input; present when `delta_type == "input_json_delta"`. + partial_json: Option, } /// The `message` object inside a `message_start` event. 
@@ -273,13 +294,20 @@ struct SseUsage { /// /// # Mapping to [`StreamEvent`] /// -/// | API event type | JSON path | Emits | -/// |----------------------|------------------------------------|------------------------------| -/// | `message_start` | `.message.usage.input_tokens` | `InputTokens(n)` | -/// | `content_block_delta`| `.delta.type == "text_delta"` | `TextDelta(chunk)` | -/// | `message_delta` | `.usage.output_tokens` | `OutputTokens(n)` | -/// | `message_stop` | n/a | `Done` | -/// | everything else | n/a | `None` (caller skips) | +/// | API event type | JSON path | Emits | +/// |------------------------|------------------------------------|------------------------------| +/// | `message_start` | `.message.usage.input_tokens` | `InputTokens(n)` | +/// | `content_block_start` | `.content_block.type == "tool_use"`| `ToolUseStart { id, name }` | +/// | `content_block_delta` | `.delta.type == "text_delta"` | `TextDelta(chunk)` | +/// | `content_block_delta` | `.delta.type == "input_json_delta"`| `ToolUseInputDelta(json)` | +/// | `content_block_stop` | n/a | `ToolUseDone` | +/// | `message_delta` | `.usage.output_tokens` | `OutputTokens(n)` | +/// | `message_stop` | n/a | `Done` | +/// | everything else | n/a | `None` (caller skips) | +/// +/// Note: `content_block_stop` always emits `ToolUseDone`. The caller is +/// responsible for tracking whether a tool-use block is active and ignoring +/// `ToolUseDone` when it follows a text block. 
/// /// [sse-fields]: https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream fn parse_sse_event(event_str: &str) -> Option { @@ -297,15 +325,29 @@ fn parse_sse_event(event_str: &str) -> Option { .and_then(|u| u.input_tokens) .map(StreamEvent::InputTokens), - "content_block_delta" => { - let delta = event.delta?; - if delta.delta_type.as_deref() == Some("text_delta") { - delta.text.map(StreamEvent::TextDelta) + "content_block_start" => { + let block = event.content_block?; + if block.block_type == "tool_use" { + Some(StreamEvent::ToolUseStart { + id: block.id.unwrap_or_default(), + name: block.name.unwrap_or_default(), + }) } else { None } } + "content_block_delta" => { + let delta = event.delta?; + match delta.delta_type.as_deref() { + Some("text_delta") => delta.text.map(StreamEvent::TextDelta), + Some("input_json_delta") => delta.partial_json.map(StreamEvent::ToolUseInputDelta), + _ => None, + } + } + + "content_block_stop" => Some(StreamEvent::ToolUseDone), + // usage lives at the top level of message_delta, not inside delta. "message_delta" => event .usage @@ -314,8 +356,7 @@ fn parse_sse_event(event_str: &str) -> Option { "message_stop" => Some(StreamEvent::Done), - // error, ping, content_block_start, content_block_stop -- ignored or - // handled by the caller. + // error, ping -- ignored. 
_ => None, } } @@ -371,13 +412,15 @@ mod tests { .filter_map(parse_sse_event) .collect(); - // content_block_start, ping, content_block_stop -> None (filtered out) - assert_eq!(events.len(), 5); + // content_block_start (text) and ping -> None (filtered out) + // content_block_stop -> ToolUseDone (always emitted; caller filters) + assert_eq!(events.len(), 6); assert!(matches!(events[0], StreamEvent::InputTokens(10))); assert!(matches!(&events[1], StreamEvent::TextDelta(s) if s == "Hello")); assert!(matches!(&events[2], StreamEvent::TextDelta(s) if s == ", world!")); - assert!(matches!(events[3], StreamEvent::OutputTokens(5))); - assert!(matches!(events[4], StreamEvent::Done)); + assert!(matches!(events[3], StreamEvent::ToolUseDone)); + assert!(matches!(events[4], StreamEvent::OutputTokens(5))); + assert!(matches!(events[5], StreamEvent::Done)); } #[test] @@ -408,14 +451,8 @@ mod tests { #[test] fn test_messages_serialize_to_anthropic_format() { let messages = vec![ - ConversationMessage { - role: Role::User, - content: "Hello".to_string(), - }, - ConversationMessage { - role: Role::Assistant, - content: "Hi there!".to_string(), - }, + ConversationMessage::text(Role::User, "Hello"), + ConversationMessage::text(Role::Assistant, "Hi there!"), ]; let json = serde_json::json!({ @@ -433,6 +470,65 @@ mod tests { assert!(json["max_tokens"].as_u64().unwrap() > 0); } + /// SSE fixture for a response that contains a tool-use block. 
+ /// + /// Sequence: message_start -> content_block_start (tool_use) -> + /// content_block_delta (input_json_delta) -> content_block_stop -> + /// message_delta -> message_stop + const TOOL_USE_SSE_FIXTURE: &str = concat!( + "event: message_start\n", + "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_456\",\"type\":\"message\",", + "\"role\":\"assistant\",\"content\":[],\"model\":\"claude-opus-4-6\",", + "\"stop_reason\":null,\"stop_sequence\":null,", + "\"usage\":{\"input_tokens\":25,\"cache_creation_input_tokens\":0,", + "\"cache_read_input_tokens\":0,\"output_tokens\":1}}}\n", + "\n", + "event: content_block_start\n", + "data: {\"type\":\"content_block_start\",\"index\":0,", + "\"content_block\":{\"type\":\"tool_use\",\"id\":\"toolu_01ABC\",\"name\":\"read_file\",\"input\":{}}}\n", + "\n", + "event: content_block_delta\n", + "data: {\"type\":\"content_block_delta\",\"index\":0,", + "\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"path\\\": \\\"src/\"}}\n", + "\n", + "event: content_block_delta\n", + "data: {\"type\":\"content_block_delta\",\"index\":0,", + "\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"main.rs\\\"}\"}}\n", + "\n", + "event: content_block_stop\n", + "data: {\"type\":\"content_block_stop\",\"index\":0}\n", + "\n", + "event: message_delta\n", + "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\",", + "\"stop_sequence\":null},\"usage\":{\"output_tokens\":15}}\n", + "\n", + "event: message_stop\n", + "data: {\"type\":\"message_stop\"}\n", + "\n", + ); + + #[test] + fn test_parse_tool_use_sse_fixture() { + let events: Vec = TOOL_USE_SSE_FIXTURE + .split("\n\n") + .filter(|s| !s.trim().is_empty()) + .filter_map(parse_sse_event) + .collect(); + + assert_eq!(events.len(), 7); + assert!(matches!(events[0], StreamEvent::InputTokens(25))); + assert!( + matches!(&events[1], StreamEvent::ToolUseStart { id, name } if id == "toolu_01ABC" && name == "read_file") + ); + assert!( + 
matches!(&events[2], StreamEvent::ToolUseInputDelta(s) if s == "{\"path\": \"src/") + ); + assert!(matches!(&events[3], StreamEvent::ToolUseInputDelta(s) if s == "main.rs\"}")); + assert!(matches!(events[4], StreamEvent::ToolUseDone)); + assert!(matches!(events[5], StreamEvent::OutputTokens(15))); + assert!(matches!(events[6], StreamEvent::Done)); + } + #[test] fn test_find_double_newline() { assert_eq!(find_double_newline(b"abc\n\ndef"), Some(3)); diff --git a/src/provider/mod.rs b/src/provider/mod.rs index d2cba08..844e8c2 100644 --- a/src/provider/mod.rs +++ b/src/provider/mod.rs @@ -4,16 +4,18 @@ pub use claude::ClaudeProvider; use futures::Stream; -use crate::core::types::{ConversationMessage, StreamEvent}; +use crate::core::types::{ConversationMessage, StreamEvent, ToolDefinition}; /// Trait for model providers that can stream conversation responses. /// /// Implementors take a conversation history and return a stream of [`StreamEvent`]s. /// The trait is provider-agnostic -- no Claude-specific types appear here. pub trait ModelProvider: Send + Sync { - /// Stream a response from the model given the conversation history. + /// Stream a response from the model given the conversation history and + /// available tool definitions. Pass an empty slice if no tools are available. fn stream<'a>( &'a self, messages: &'a [ConversationMessage], + tools: &'a [ToolDefinition], ) -> impl Stream + Send + 'a; } diff --git a/src/tools/list_directory.rs b/src/tools/list_directory.rs new file mode 100644 index 0000000..5be154c --- /dev/null +++ b/src/tools/list_directory.rs @@ -0,0 +1,87 @@ +//! `list_directory` tool: lists entries in a directory within the working directory. + +use std::path::Path; + +use async_trait::async_trait; + +use super::{RiskLevel, Tool, ToolError, ToolOutput, validate_path}; + +/// Lists directory contents. Auto-approved (read-only). 
+pub struct ListDirectory; + +#[async_trait] +impl Tool for ListDirectory { + fn name(&self) -> &str { + "list_directory" + } + + fn description(&self) -> &str { + "List the files and subdirectories in a directory. The path is relative to the working directory." + } + + fn input_schema(&self) -> serde_json::Value { + serde_json::json!({ + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "The directory path to list, relative to the working directory. Use '.' for the working directory itself." + } + }, + "required": ["path"] + }) + } + + fn risk_level(&self) -> RiskLevel { + RiskLevel::AutoApprove + } + + async fn execute( + &self, + input: &serde_json::Value, + working_dir: &Path, + ) -> Result { + let path_str = input["path"] + .as_str() + .ok_or_else(|| ToolError::InvalidInput("missing 'path' string".to_string()))?; + + let canonical = validate_path(working_dir, path_str)?; + + let mut entries: Vec = Vec::new(); + let mut dir = tokio::fs::read_dir(&canonical).await?; + while let Some(entry) = dir.next_entry().await? 
{ + let name = entry.file_name().to_string_lossy().to_string(); + let suffix = if entry.file_type().await?.is_dir() { + "/" + } else { + "" + }; + entries.push(format!("{name}{suffix}")); + } + entries.sort(); + + Ok(ToolOutput { + content: entries.join("\n"), + is_error: false, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fs; + use tempfile::TempDir; + + #[tokio::test] + async fn list_directory_contents() { + let dir = TempDir::new().unwrap(); + fs::write(dir.path().join("a.txt"), "").unwrap(); + fs::create_dir(dir.path().join("subdir")).unwrap(); + let tool = ListDirectory; + let input = serde_json::json!({"path": "."}); + let out = tool.execute(&input, dir.path()).await.unwrap(); + assert!(out.content.contains("a.txt")); + assert!(out.content.contains("subdir/")); + } +} diff --git a/src/tools/mod.rs b/src/tools/mod.rs new file mode 100644 index 0000000..656c935 --- /dev/null +++ b/src/tools/mod.rs @@ -0,0 +1,220 @@ +//! Tool system: trait, registry, risk classification, and built-in tools. +//! +//! All tools implement the [`Tool`] trait. The [`ToolRegistry`] collects them +//! and provides lookup by name plus generation of [`ToolDefinition`]s for the +//! model provider. + +mod list_directory; +mod read_file; +mod shell_exec; +mod write_file; + +use std::path::{Path, PathBuf}; + +use async_trait::async_trait; + +use crate::core::types::ToolDefinition; + +/// The output of a tool execution. +#[derive(Debug)] +pub struct ToolOutput { + /// The text content returned to the model. + pub content: String, + /// Whether the tool encountered an error. + pub is_error: bool, +} + +/// Risk classification for tool approval gating. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum RiskLevel { + /// Safe to execute without user confirmation (e.g. read-only operations). + AutoApprove, + /// Requires explicit user approval before execution (e.g. writes, shell). + RequiresApproval, +} + +/// A tool that the model can invoke. 
+/// +/// The `execute` method is async so that tool implementations can use +/// `tokio::fs` and `tokio::process` without blocking a Tokio worker thread. +/// `#[async_trait]` desugars the async fn to a boxed future, which is required +/// for `dyn Tool` to remain object-safe. +#[async_trait] +pub trait Tool: Send + Sync { + /// The name the model uses to invoke this tool. + fn name(&self) -> &str; + /// Human-readable description for the model. + fn description(&self) -> &str; + /// JSON Schema for the tool's input parameters. + fn input_schema(&self) -> serde_json::Value; + /// The risk level of this tool. + fn risk_level(&self) -> RiskLevel; + /// Execute the tool with the given input, confined to `working_dir`. + async fn execute( + &self, + input: &serde_json::Value, + working_dir: &Path, + ) -> Result; +} + +/// Errors from tool execution. +#[derive(Debug, thiserror::Error)] +pub enum ToolError { + /// The requested path escapes the working directory. + #[error("path escapes working directory: {0}")] + PathEscape(PathBuf), + /// An I/O error during tool execution. + #[error("I/O error: {0}")] + Io(#[from] std::io::Error), + /// A required input field is missing or has the wrong type. + #[error("invalid input: {0}")] + InvalidInput(String), +} + +/// Validate that `requested` resolves to a path inside `working_dir`. +/// +/// Joins `working_dir` with `requested`, canonicalizes the result (resolving +/// symlinks and `..` components), and checks that the canonical path starts +/// with the canonical working directory. +/// +/// Returns the canonical path on success, or [`ToolError::PathEscape`] if the +/// path would escape the working directory. +pub fn validate_path(working_dir: &Path, requested: &str) -> Result { + let candidate = if Path::new(requested).is_absolute() { + PathBuf::from(requested) + } else { + working_dir.join(requested) + }; + + // For paths that don't exist yet (e.g. 
write_file creating a new file), + // canonicalize the parent directory and append the filename. + let canonical = if candidate.exists() { + candidate + .canonicalize() + .map_err(|_| ToolError::PathEscape(candidate.clone()))? + } else { + let parent = candidate + .parent() + .ok_or_else(|| ToolError::PathEscape(candidate.clone()))?; + let file_name = candidate + .file_name() + .ok_or_else(|| ToolError::PathEscape(candidate.clone()))?; + let canonical_parent = parent + .canonicalize() + .map_err(|_| ToolError::PathEscape(candidate.clone()))?; + canonical_parent.join(file_name) + }; + + let canonical_root = working_dir + .canonicalize() + .map_err(|_| ToolError::PathEscape(candidate.clone()))?; + + if canonical.starts_with(&canonical_root) { + Ok(canonical) + } else { + Err(ToolError::PathEscape(candidate)) + } +} + +/// Collection of available tools with name-based lookup. +pub struct ToolRegistry { + tools: Vec>, +} + +impl ToolRegistry { + /// Create an empty registry with no tools. + #[allow(dead_code)] + pub fn empty() -> Self { + Self { tools: Vec::new() } + } + + /// Create a registry with the default built-in tools. + pub fn default_tools() -> Self { + Self { + tools: vec![ + Box::new(read_file::ReadFile), + Box::new(list_directory::ListDirectory), + Box::new(write_file::WriteFile), + Box::new(shell_exec::ShellExec), + ], + } + } + + /// Look up a tool by name. + pub fn get(&self, name: &str) -> Option<&dyn Tool> { + self.tools.iter().find(|t| t.name() == name).map(|t| &**t) + } + + /// Generate [`ToolDefinition`]s for the model provider. 
+ pub fn definitions(&self) -> Vec { + self.tools + .iter() + .map(|t| ToolDefinition { + name: t.name().to_string(), + description: t.description().to_string(), + input_schema: t.input_schema(), + }) + .collect() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fs; + use tempfile::TempDir; + + #[test] + fn validate_path_allows_subpath() { + let dir = TempDir::new().unwrap(); + let sub = dir.path().join("sub"); + fs::create_dir(&sub).unwrap(); + let result = validate_path(dir.path(), "sub"); + assert!(result.is_ok()); + assert!( + result + .unwrap() + .starts_with(dir.path().canonicalize().unwrap()) + ); + } + + #[test] + fn validate_path_rejects_traversal() { + let dir = TempDir::new().unwrap(); + let result = validate_path(dir.path(), "../../../etc/passwd"); + assert!(result.is_err()); + assert!(matches!(result, Err(ToolError::PathEscape(_)))); + } + + #[test] + fn validate_path_rejects_absolute_outside() { + let dir = TempDir::new().unwrap(); + let result = validate_path(dir.path(), "/etc/passwd"); + assert!(result.is_err()); + } + + #[test] + fn validate_path_allows_new_file_in_working_dir() { + let dir = TempDir::new().unwrap(); + let result = validate_path(dir.path(), "new_file.txt"); + assert!(result.is_ok()); + } + + #[test] + fn registry_default_has_all_tools() { + let reg = ToolRegistry::default_tools(); + assert!(reg.get("read_file").is_some()); + assert!(reg.get("list_directory").is_some()); + assert!(reg.get("write_file").is_some()); + assert!(reg.get("shell_exec").is_some()); + assert!(reg.get("nonexistent").is_none()); + } + + #[test] + fn registry_definitions_match_tools() { + let reg = ToolRegistry::default_tools(); + let defs = reg.definitions(); + assert_eq!(defs.len(), 4); + assert!(defs.iter().any(|d| d.name == "read_file")); + } +} diff --git a/src/tools/read_file.rs b/src/tools/read_file.rs new file mode 100644 index 0000000..da2bbf2 --- /dev/null +++ b/src/tools/read_file.rs @@ -0,0 +1,92 @@ +//! 
`read_file` tool: reads the contents of a file within the working directory. + +use std::path::Path; + +use async_trait::async_trait; + +use super::{RiskLevel, Tool, ToolError, ToolOutput, validate_path}; + +/// Reads file contents. Auto-approved (read-only). +pub struct ReadFile; + +#[async_trait] +impl Tool for ReadFile { + fn name(&self) -> &str { + "read_file" + } + + fn description(&self) -> &str { + "Read the contents of a file. The path is relative to the working directory." + } + + fn input_schema(&self) -> serde_json::Value { + serde_json::json!({ + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "The file path to read, relative to the working directory." + } + }, + "required": ["path"] + }) + } + + fn risk_level(&self) -> RiskLevel { + RiskLevel::AutoApprove + } + + async fn execute( + &self, + input: &serde_json::Value, + working_dir: &Path, + ) -> Result { + let path_str = input["path"] + .as_str() + .ok_or_else(|| ToolError::InvalidInput("missing 'path' string".to_string()))?; + + let canonical = validate_path(working_dir, path_str)?; + let content = tokio::fs::read_to_string(&canonical).await?; + + Ok(ToolOutput { + content, + is_error: false, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fs; + use tempfile::TempDir; + + #[tokio::test] + async fn read_existing_file() { + let dir = TempDir::new().unwrap(); + fs::write(dir.path().join("hello.txt"), "world").unwrap(); + let tool = ReadFile; + let input = serde_json::json!({"path": "hello.txt"}); + let out = tool.execute(&input, dir.path()).await.unwrap(); + assert_eq!(out.content, "world"); + assert!(!out.is_error); + } + + #[tokio::test] + async fn read_nonexistent_file_errors() { + let dir = TempDir::new().unwrap(); + let tool = ReadFile; + let input = serde_json::json!({"path": "nope.txt"}); + let result = tool.execute(&input, dir.path()).await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn read_path_traversal_rejected() { + 
let dir = TempDir::new().unwrap(); + let tool = ReadFile; + let input = serde_json::json!({"path": "../../../etc/passwd"}); + let result = tool.execute(&input, dir.path()).await; + assert!(matches!(result, Err(ToolError::PathEscape(_)))); + } +} diff --git a/src/tools/shell_exec.rs b/src/tools/shell_exec.rs new file mode 100644 index 0000000..a8c78ce --- /dev/null +++ b/src/tools/shell_exec.rs @@ -0,0 +1,108 @@ +//! `shell_exec` tool: runs a shell command within the working directory. + +use std::path::Path; + +use async_trait::async_trait; + +use super::{RiskLevel, Tool, ToolError, ToolOutput}; + +/// Executes a shell command. Requires user approval. +pub struct ShellExec; + +#[async_trait] +impl Tool for ShellExec { + fn name(&self) -> &str { + "shell_exec" + } + + fn description(&self) -> &str { + "Execute a shell command in the working directory. Returns stdout and stderr." + } + + fn input_schema(&self) -> serde_json::Value { + serde_json::json!({ + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "The shell command to execute." 
+ } + }, + "required": ["command"] + }) + } + + fn risk_level(&self) -> RiskLevel { + RiskLevel::RequiresApproval + } + + async fn execute( + &self, + input: &serde_json::Value, + working_dir: &Path, + ) -> Result { + let command = input["command"] + .as_str() + .ok_or_else(|| ToolError::InvalidInput("missing 'command' string".to_string()))?; + + let output = tokio::process::Command::new("sh") + .arg("-c") + .arg(command) + .current_dir(working_dir) + .output() + .await?; + + let stdout = String::from_utf8_lossy(&output.stdout); + let stderr = String::from_utf8_lossy(&output.stderr); + + let mut content = String::new(); + if !stdout.is_empty() { + content.push_str(&stdout); + } + if !stderr.is_empty() { + if !content.is_empty() { + content.push('\n'); + } + content.push_str("[stderr]\n"); + content.push_str(&stderr); + } + if content.is_empty() { + content.push_str("(no output)"); + } + + let is_error = !output.status.success(); + if is_error { + content.push_str(&format!( + "\n[exit code: {}]", + output.status.code().unwrap_or(-1) + )); + } + + Ok(ToolOutput { content, is_error }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::TempDir; + + #[tokio::test] + async fn shell_exec_echo() { + let dir = TempDir::new().unwrap(); + let tool = ShellExec; + let input = serde_json::json!({"command": "echo hello"}); + let out = tool.execute(&input, dir.path()).await.unwrap(); + assert!(out.content.contains("hello")); + assert!(!out.is_error); + } + + #[tokio::test] + async fn shell_exec_failing_command() { + let dir = TempDir::new().unwrap(); + let tool = ShellExec; + let input = serde_json::json!({"command": "false"}); + let out = tool.execute(&input, dir.path()).await.unwrap(); + assert!(out.is_error); + } +} diff --git a/src/tools/write_file.rs b/src/tools/write_file.rs new file mode 100644 index 0000000..315d477 --- /dev/null +++ b/src/tools/write_file.rs @@ -0,0 +1,98 @@ +//! `write_file` tool: writes content to a file within the working directory. 
+ +use std::path::Path; + +use async_trait::async_trait; + +use super::{RiskLevel, Tool, ToolError, ToolOutput, validate_path}; + +/// Writes content to a file. Requires user approval. +pub struct WriteFile; + +#[async_trait] +impl Tool for WriteFile { + fn name(&self) -> &str { + "write_file" + } + + fn description(&self) -> &str { + "Write content to a file. Creates the file if it doesn't exist, overwrites if it does. The path is relative to the working directory." + } + + fn input_schema(&self) -> serde_json::Value { + serde_json::json!({ + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "The file path to write to, relative to the working directory." + }, + "content": { + "type": "string", + "description": "The content to write to the file." + } + }, + "required": ["path", "content"] + }) + } + + fn risk_level(&self) -> RiskLevel { + RiskLevel::RequiresApproval + } + + async fn execute( + &self, + input: &serde_json::Value, + working_dir: &Path, + ) -> Result { + let path_str = input["path"] + .as_str() + .ok_or_else(|| ToolError::InvalidInput("missing 'path' string".to_string()))?; + let content = input["content"] + .as_str() + .ok_or_else(|| ToolError::InvalidInput("missing 'content' string".to_string()))?; + + let canonical = validate_path(working_dir, path_str)?; + + // Create parent directories if needed. 
+ if let Some(parent) = canonical.parent() { + tokio::fs::create_dir_all(parent).await?; + } + + tokio::fs::write(&canonical, content).await?; + + Ok(ToolOutput { + content: format!("Wrote {} bytes to {path_str}", content.len()), + is_error: false, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fs; + use tempfile::TempDir; + + #[tokio::test] + async fn write_creates_file() { + let dir = TempDir::new().unwrap(); + let tool = WriteFile; + let input = serde_json::json!({"path": "out.txt", "content": "hello"}); + let out = tool.execute(&input, dir.path()).await.unwrap(); + assert!(!out.is_error); + assert_eq!( + fs::read_to_string(dir.path().join("out.txt")).unwrap(), + "hello" + ); + } + + #[tokio::test] + async fn write_path_traversal_rejected() { + let dir = TempDir::new().unwrap(); + let tool = WriteFile; + let input = serde_json::json!({"path": "../../evil.txt", "content": "bad"}); + let result = tool.execute(&input, dir.path()).await; + assert!(matches!(result, Err(ToolError::PathEscape(_)))); + } +} diff --git a/src/tui/events.rs b/src/tui/events.rs index 74c7842..8d6da6a 100644 --- a/src/tui/events.rs +++ b/src/tui/events.rs @@ -11,11 +11,14 @@ use crate::core::types::{Role, UIEvent}; /// This is non-blocking: it processes all currently-available events and returns /// immediately when the channel is empty. 
/// -/// | Event | Effect | -/// |--------------------|------------------------------------------------------------| -/// | `StreamDelta(s)` | Append `s` to last message if it's `Assistant`; else push new | -/// | `TurnComplete` | No structural change; logged at debug level | -/// | `Error(msg)` | Push `(Assistant, "[error] {msg}")` | +/// | Event | Effect | +/// |------------------------|------------------------------------------------------------| +/// | `StreamDelta(s)` | Append `s` to last message if it's `Assistant`; else push | +/// | `ToolApprovalRequest` | Set `pending_approval` in state | +/// | `ToolExecuting` | Display tool execution info | +/// | `ToolResult` | Display tool result | +/// | `TurnComplete` | No structural change; logged at debug level | +/// | `Error(msg)` | Push `(Assistant, "[error] {msg}")` | pub(super) fn drain_ui_events(event_rx: &mut mpsc::Receiver, state: &mut AppState) { while let Ok(event) = event_rx.try_recv() { match event { @@ -26,6 +29,36 @@ pub(super) fn drain_ui_events(event_rx: &mut mpsc::Receiver, state: &mu state.messages.push((Role::Assistant, chunk)); } } + UIEvent::ToolApprovalRequest { + tool_use_id, + tool_name, + input_summary, + } => { + state.pending_approval = Some(PendingApproval { + tool_use_id, + tool_name, + input_summary, + }); + } + UIEvent::ToolExecuting { + tool_name, + input_summary, + } => { + state + .messages + .push((Role::Assistant, format!("[{tool_name}] {input_summary}"))); + } + UIEvent::ToolResult { + tool_name, + output_summary, + is_error, + } => { + let prefix = if is_error { "error" } else { "result" }; + state.messages.push(( + Role::Assistant, + format!("[{tool_name} {prefix}] {output_summary}"), + )); + } UIEvent::TurnComplete => { debug!("turn complete"); } @@ -38,6 +71,14 @@ pub(super) fn drain_ui_events(event_rx: &mut mpsc::Receiver, state: &mu } } +/// A pending tool approval request waiting for user input. 
+#[derive(Debug, Clone)] +pub struct PendingApproval { + pub tool_use_id: String, + pub tool_name: String, + pub input_summary: String, +} + #[cfg(test)] mod tests { use super::*; @@ -69,4 +110,39 @@ mod tests { assert_eq!(state.messages[1].0, Role::Assistant); assert_eq!(state.messages[1].1, "hello"); } + + #[tokio::test] + async fn drain_tool_approval_sets_pending() { + let (tx, mut rx) = tokio::sync::mpsc::channel(8); + let mut state = AppState::new(); + tx.send(UIEvent::ToolApprovalRequest { + tool_use_id: "t1".to_string(), + tool_name: "write_file".to_string(), + input_summary: "path: foo.txt".to_string(), + }) + .await + .unwrap(); + drop(tx); + drain_ui_events(&mut rx, &mut state); + assert!(state.pending_approval.is_some()); + let approval = state.pending_approval.unwrap(); + assert_eq!(approval.tool_name, "write_file"); + } + + #[tokio::test] + async fn drain_tool_result_adds_message() { + let (tx, mut rx) = tokio::sync::mpsc::channel(8); + let mut state = AppState::new(); + tx.send(UIEvent::ToolResult { + tool_name: "read_file".to_string(), + output_summary: "file contents...".to_string(), + is_error: false, + }) + .await + .unwrap(); + drop(tx); + drain_ui_events(&mut rx, &mut state); + assert_eq!(state.messages.len(), 1); + assert!(state.messages[0].1.contains("read_file result")); + } } diff --git a/src/tui/input.rs b/src/tui/input.rs index 59770ef..2b60794 100644 --- a/src/tui/input.rs +++ b/src/tui/input.rs @@ -13,6 +13,8 @@ pub(super) enum LoopControl { Quit, /// The user ran `:clear`; wipe the conversation. ClearHistory, + /// The user responded to a tool approval prompt. + ToolApproval { tool_use_id: String, approved: bool }, } /// Map a key event to a [`LoopControl`] signal, mutating `state` as a side-effect. @@ -23,6 +25,29 @@ pub(super) fn handle_key(key: Option, state: &mut AppState) -> Option< let key = key?; // Clear any transient status error on the next keypress. 
state.status_error = None; + + // If a tool approval is pending, intercept y/n before normal key handling. + if let Some(approval) = &state.pending_approval { + let tool_use_id = approval.tool_use_id.clone(); + match key.code { + KeyCode::Char('y') | KeyCode::Char('Y') => { + state.pending_approval = None; + return Some(LoopControl::ToolApproval { + tool_use_id, + approved: true, + }); + } + KeyCode::Char('n') | KeyCode::Char('N') => { + state.pending_approval = None; + return Some(LoopControl::ToolApproval { + tool_use_id, + approved: false, + }); + } + _ => return None, // ignore other keys while approval pending + } + } + // Ctrl+C quits from any mode. if key.modifiers.contains(KeyModifiers::CONTROL) && key.code == KeyCode::Char('c') { return Some(LoopControl::Quit); diff --git a/src/tui/mod.rs b/src/tui/mod.rs index 78c5f21..7a21385 100644 --- a/src/tui/mod.rs +++ b/src/tui/mod.rs @@ -72,6 +72,8 @@ pub struct AppState { pub viewport_height: u16, /// Transient error message shown in the status bar, cleared on next keypress. pub status_error: Option, + /// A tool approval request waiting for user input (y/n). 
+ pub pending_approval: Option, } impl AppState { @@ -85,6 +87,7 @@ impl AppState { pending_keys: Vec::new(), viewport_height: 0, status_error: None, + pending_approval: None, } } } @@ -185,6 +188,17 @@ pub async fn run( Some(input::LoopControl::ClearHistory) => { let _ = action_tx.send(UserAction::ClearHistory).await; } + Some(input::LoopControl::ToolApproval { + tool_use_id, + approved, + }) => { + let _ = action_tx + .send(UserAction::ToolApprovalResponse { + tool_use_id, + approved, + }) + .await; + } None => {} } } diff --git a/src/tui/render.rs b/src/tui/render.rs index 82561e1..e12a636 100644 --- a/src/tui/render.rs +++ b/src/tui/render.rs @@ -105,11 +105,37 @@ pub(super) fn render(frame: &mut Frame, state: &AppState) { let output = Paragraph::new(lines) .wrap(Wrap { trim: false }) .scroll((state.scroll, 0)); - frame.render_widget(output, chunks[0]); + let output_area = chunks[0]; + frame.render_widget(output, output_area); + + // --- Tool approval overlay --- + if let Some(ref approval) = state.pending_approval { + let overlay_w = (output_area.width / 2).max(60).min(output_area.width); + let overlay_h: u16 = 5; + let overlay_x = output_area.x + (output_area.width.saturating_sub(overlay_w)) / 2; + let overlay_y = output_area.y + output_area.height.saturating_sub(overlay_h) / 2; + let overlay_area = Rect { + x: overlay_x, + y: overlay_y, + width: overlay_w, + height: overlay_h.min(output_area.height), + }; + frame.render_widget(Clear, overlay_area); + let text = format!( + "{}: {}\n\ny = approve, n = deny", + approval.tool_name, approval.input_summary + ); + let overlay = Paragraph::new(text).block( + Block::bordered() + .border_style(Style::default().fg(Color::Yellow)) + .title("Tool Approval"), + ); + frame.render_widget(overlay, overlay_area); + } // --- Command overlay (floating box centered on output pane) --- if state.mode == Mode::Command { - let overlay_area = command_overlay_rect(chunks[0]); + let overlay_area = command_overlay_rect(output_area); // 
Clear the area behind the overlay so it appears floating. frame.render_widget(Clear, overlay_area); let overlay = Paragraph::new(format!(":{}", state.command_buffer)).block( @@ -146,7 +172,7 @@ pub(super) fn render(frame: &mut Frame, state: &AppState) { } Mode::Command => { // Cursor in the floating overlay - let overlay = command_overlay_rect(chunks[0]); + let overlay = command_overlay_rect(output_area); // border(1) + ":" (1) + buf len let cursor_x = overlay.x + 1 + 1 + state.command_buffer.len() as u16; let cursor_y = overlay.y + 1; // inside the border @@ -386,4 +412,31 @@ mod tests { "expected error in status bar" ); } + + #[test] + fn render_approval_overlay_visible() { + let backend = TestBackend::new(80, 24); + let mut terminal = Terminal::new(backend).unwrap(); + let mut state = AppState::new(); + state.pending_approval = Some(super::super::events::PendingApproval { + tool_use_id: "t1".to_string(), + tool_name: "write_file".to_string(), + input_summary: "path: foo.txt".to_string(), + }); + terminal.draw(|frame| render(frame, &state)).unwrap(); + let buf = terminal.backend().buffer().clone(); + let all_text: String = buf + .content() + .iter() + .map(|c| c.symbol().to_string()) + .collect(); + assert!( + all_text.contains("Tool Approval"), + "expected 'Tool Approval' overlay" + ); + assert!( + all_text.contains("write_file"), + "expected tool name in overlay" + ); + } } -- 2.49.1