Tool Use.

This commit is contained in:
Drew 2026-02-24 18:31:26 -08:00
parent 6b85ff3cb8
commit 0c1c928498
20 changed files with 1822 additions and 129 deletions

32
Cargo.lock generated
View file

@ -23,6 +23,17 @@ version = "1.0.102"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c"
[[package]]
name = "async-trait"
version = "0.1.89"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.117",
]
[[package]]
name = "atomic"
version = "0.6.1"
@ -447,6 +458,12 @@ dependencies = [
"regex",
]
[[package]]
name = "fastrand"
version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
[[package]]
name = "filedescriptor"
version = "0.8.3"
@ -2024,12 +2041,14 @@ name = "skate"
version = "0.1.0"
dependencies = [
"anyhow",
"async-trait",
"crossterm",
"futures",
"ratatui",
"reqwest",
"serde",
"serde_json",
"tempfile",
"thiserror 2.0.18",
"tokio",
"tracing",
@ -2166,6 +2185,19 @@ dependencies = [
"libc",
]
[[package]]
name = "tempfile"
version = "3.26.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "82a72c767771b47409d2345987fda8628641887d5466101319899796367354a0"
dependencies = [
"fastrand",
"getrandom 0.4.1",
"once_cell",
"rustix",
"windows-sys 0.61.2",
]
[[package]]
name = "terminfo"
version = "0.9.0"

View file

@ -15,3 +15,7 @@ tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
reqwest = { version = "0.13", features = ["stream", "json"] }
futures = "0.3"
async-trait = "0.1"
[dev-dependencies]
tempfile = "3.26.0"

76
PLAN.md
View file

@ -1,11 +1,77 @@
# Implementation Plan
## Phase 3: Tool Execution
- `Tool` trait, `ToolRegistry`, core tools (`read_file`, `write_file`, `shell_exec`)
- Tool definitions in API requests, parse tool-use responses
- Approval gate: core -> TUI pending event -> user approve/deny -> result back
- Working directory confinement + path validation (no Landlock yet)
- **Done when:** Claude can read, modify files, and run commands with user approval
### Step 3.1: Enrich the content model
- Replace `ConversationMessage { role, content: String }` with content-block model
- Define `ContentBlock` enum: `Text(String)`, `ToolUse { id, name, input: Value }`, `ToolResult { tool_use_id, content: String, is_error: bool }`
- Change `ConversationMessage.content` from `String` to `Vec<ContentBlock>`
- Add `ConversationMessage::text(role, s)` helper to keep existing call sites clean
- Update serialization, orchestrator, tests, TUI display
- **Files:** `src/core/types.rs`, `src/core/history.rs`
- **Done when:** `cargo test` passes with new model; all existing tests updated
### Step 3.2: Send tool definitions in API requests
- Add `ToolDefinition { name, description, input_schema: Value }` (provider-agnostic)
- Extend `ModelProvider::stream` to accept `&[ToolDefinition]`
- Include `"tools"` array in Claude provider request body
- **Files:** `src/provider/mod.rs`, `src/provider/claude.rs`
- **Done when:** API responses contain `tool_use` content blocks in raw SSE stream
### Step 3.3: Parse tool-use blocks from SSE stream
- Add `StreamEvent::ToolUseStart { id, name }`, `ToolUseInputDelta(String)`, `ToolUseDone`
- Handle `content_block_start` (type "tool_use"), `content_block_delta` (type "input_json_delta"), `content_block_stop` for tool blocks
- Track current block type state in SSE parser
- **Files:** `src/provider/claude.rs`, `src/core/types.rs`
- **Done when:** Unit test with recorded tool-use SSE fixture asserts correct StreamEvent sequence
### Step 3.4: Orchestrator accumulates tool-use blocks
- Accumulate `ToolUseInputDelta` fragments into JSON buffer per tool-use id
- On `ToolUseDone`, parse JSON into `ContentBlock::ToolUse`
- After `StreamEvent::Done`, if assistant message contains ToolUse blocks, enter tool-execution phase
- **Files:** `src/core/orchestrator.rs`
- **Done when:** Unit test with mock provider emitting tool-use events produces correct ContentBlocks
### Step 3.5: Tool trait, registry, and core tools
- `Tool` trait: `name()`, `description()`, `input_schema() -> Value`, `execute(input: Value, working_dir: &Path) -> Result<ToolOutput>`
- `ToolOutput { content: String, is_error: bool }`
- `ToolRegistry`: stores tools, provides `get(name)` and `definitions() -> Vec<ToolDefinition>`
- Risk level: `AutoApprove` (reads), `RequiresApproval` (writes/shell)
- Implement: `read_file` (auto), `list_directory` (auto), `write_file` (approval), `shell_exec` (approval)
- Path validation: `canonicalize` + `starts_with` check, reject paths outside working dir (no Landlock yet)
- **Files:** New `src/tools/` module: `mod.rs`, `read_file.rs`, `write_file.rs`, `list_directory.rs`, `shell_exec.rs`
- **Done when:** Unit tests pass for each tool in temp dirs; path traversal rejected
### Step 3.6: Approval gate (TUI <-> core)
- New `UIEvent::ToolApprovalRequest { tool_use_id, tool_name, input_summary }`
- New `UserAction::ToolApprovalResponse { tool_use_id, approved: bool }`
- Orchestrator: check risk level -> auto-approve or send approval request and await response
- Denied tools return `ToolResult { is_error: true }` with denial message
- TUI: render approval prompt overlay with y/n keybindings
- **Files:** `src/core/types.rs`, `src/core/orchestrator.rs`, `src/tui/events.rs`, `src/tui/input.rs`, `src/tui/render.rs`
- **Done when:** Integration test: mock provider + mock TUI channel verifies approval flow
### Step 3.7: Tool results fed back to the model
- After executing tool calls: append assistant message (with ToolUse blocks) to history, append user message with ToolResult blocks, re-call provider
- Loop: model may respond with more tool calls or text
- Cap at max iterations (25) to prevent runaway
- **Files:** `src/core/orchestrator.rs`
- **Done when:** Integration test: mock provider returns tool-use then text; orchestrator makes two calls. Max-iteration cap tested.
### Step 3.8: TUI display for tool activity
- New `UIEvent::ToolExecuting { tool_name, input_summary }`, `UIEvent::ToolResult { tool_name, output_summary, is_error }`
- Render tool calls as distinct visual blocks in conversation view
- Render tool results inline (truncated if long)
- **Files:** `src/tui/render.rs`, `src/tui/events.rs`
- **Done when:** Visual check with `cargo run`; TestBackend test for tool block rendering
### Phase 3 verification (end-to-end)
1. `cargo test` -- all tests pass
2. `cargo clippy -- -D warnings` -- zero warnings
3. `cargo run -- --project-dir .` -- ask Claude to read a file, approve, see contents
4. Ask Claude to write a file -- approve, verify written
5. Ask Claude to run a shell command -- approve, verify output
6. Deny an approval -- Claude gets denial and responds gracefully
## Phase 4: Sandboxing
- Landlock: read-only system, read-write project dir, network blocked

View file

@ -1,5 +1,7 @@
# Cleanups
- Parallelize tool-use execution in `run_turn` -- requires refactoring `Orchestrator` to use `&self` + interior mutability (`Arc<Mutex<...>>` around `event_tx`, `action_rx`, `history`) so multiple futures can borrow self simultaneously via `futures::future::join_all`.
- Move keyboard/event reads in the TUI to a separate thread or async/io loop
- Keep UI and orchestrator in sync (i.e. messages display out of order if you queue up many.)
- `update_scroll` auto-follows in Insert mode, yanking viewport to bottom on mode switch. Only auto-follow when new content arrives (in `drain_ui_events`), not every frame.

View file

@ -27,6 +27,7 @@ use tokio::sync::mpsc;
use crate::core::orchestrator::Orchestrator;
use crate::core::types::{UIEvent, UserAction};
use crate::provider::ClaudeProvider;
use crate::tools::ToolRegistry;
/// Model ID sent on every request.
///
@ -68,8 +69,15 @@ pub async fn run(project_dir: &Path) -> anyhow::Result<()> {
let (action_tx, action_rx) = mpsc::channel::<UserAction>(CHANNEL_CAP);
let (event_tx, event_rx) = mpsc::channel::<UIEvent>(CHANNEL_CAP);
// -- Orchestrator (background task) -------------------------------------------
let orch = Orchestrator::new(provider, action_rx, event_tx);
// -- Tools & Orchestrator (background task) ------------------------------------
let tool_registry = ToolRegistry::default_tools();
let orch = Orchestrator::new(
provider,
tool_registry,
project_dir.to_path_buf(),
action_rx,
event_tx,
);
tokio::spawn(orch.run());
// -- TUI (foreground task) ----------------------------------------------------

View file

@ -61,30 +61,21 @@ mod tests {
#[test]
fn push_and_read_roundtrip() {
let mut history = ConversationHistory::new();
history.push(ConversationMessage {
role: Role::User,
content: "hello".to_string(),
});
history.push(ConversationMessage {
role: Role::Assistant,
content: "hi there".to_string(),
});
history.push(ConversationMessage::text(Role::User, "hello"));
history.push(ConversationMessage::text(Role::Assistant, "hi there"));
let msgs = history.messages();
assert_eq!(msgs.len(), 2);
assert_eq!(msgs[0].role, Role::User);
assert_eq!(msgs[0].content, "hello");
assert_eq!(msgs[0].text_content(), "hello");
assert_eq!(msgs[1].role, Role::Assistant);
assert_eq!(msgs[1].content, "hi there");
assert_eq!(msgs[1].text_content(), "hi there");
}
#[test]
fn clear_empties_history() {
let mut history = ConversationHistory::new();
history.push(ConversationMessage {
role: Role::User,
content: "hello".to_string(),
});
history.push(ConversationMessage::text(Role::User, "hello"));
history.clear();
assert!(history.messages().is_empty());
}
@ -93,13 +84,10 @@ mod tests {
fn messages_preserves_insertion_order() {
let mut history = ConversationHistory::new();
for i in 0u32..5 {
history.push(ConversationMessage {
role: Role::User,
content: format!("msg {i}"),
});
history.push(ConversationMessage::text(Role::User, format!("msg {i}")));
}
for (i, msg) in history.messages().iter().enumerate() {
assert_eq!(msg.content, format!("msg {i}"));
assert_eq!(msg.text_content(), format!("msg {i}"));
}
}
}

View file

@ -1,10 +1,53 @@
use futures::StreamExt;
use tokio::sync::mpsc;
use tracing::debug;
use tracing::{debug, warn};
use crate::core::history::ConversationHistory;
use crate::core::types::{ConversationMessage, Role, StreamEvent, UIEvent, UserAction};
use crate::core::types::{
ContentBlock, ConversationMessage, Role, StreamEvent, ToolDefinition, UIEvent, UserAction,
};
use crate::provider::ModelProvider;
use crate::tools::{RiskLevel, ToolOutput, ToolRegistry};
/// Accumulates data for a single tool-use block while it is being streamed.
///
/// Created on `ToolUseStart`, populated by `ToolUseInputDelta` fragments, and
/// consumed on `ToolUseDone` via `TryFrom<ActiveToolUse> for ContentBlock`.
struct ActiveToolUse {
id: String,
name: String,
json_buf: String,
}
impl ActiveToolUse {
/// Start accumulating a new tool-use block identified by `id` and `name`.
fn new(id: String, name: String) -> Self {
Self {
id,
name,
json_buf: String::new(),
}
}
/// Append a JSON input fragment received from a `ToolUseInputDelta` event.
fn append(&mut self, chunk: &str) {
self.json_buf.push_str(chunk);
}
}
impl TryFrom<ActiveToolUse> for ContentBlock {
type Error = serde_json::Error;
/// Parse the accumulated JSON buffer and produce a `ContentBlock::ToolUse`.
fn try_from(t: ActiveToolUse) -> Result<Self, Self::Error> {
let input = serde_json::from_str(&t.json_buf)?;
Ok(ContentBlock::ToolUse {
id: t.id,
name: t.name,
input,
})
}
}
/// Drives the conversation loop between the TUI frontend and the model provider.
///
@ -38,28 +81,321 @@ use crate::provider::ModelProvider;
/// OutputTokens -> log at debug level
/// 3. Quit -> return
/// ```
/// The result of consuming one provider stream (one assistant turn).
enum StreamResult {
/// The stream completed successfully with these content blocks.
Done(Vec<ContentBlock>),
/// The stream produced an error.
Error(String),
}
/// Maximum number of tool-use loop iterations per user message.
const MAX_TOOL_ITERATIONS: usize = 25;
/// Truncate a string to `max_len` Unicode scalar values, appending "..." if
/// truncated. Uses `char_indices` so that multibyte characters are never split
/// at a byte boundary (which would panic on a bare slice).
fn truncate(s: &str, max_len: usize) -> String {
match s.char_indices().nth(max_len) {
Some((i, _)) => format!("{}...", &s[..i]),
None => s.to_string(),
}
}
pub struct Orchestrator<P> {
history: ConversationHistory,
provider: P,
tool_registry: ToolRegistry,
working_dir: std::path::PathBuf,
action_rx: mpsc::Receiver<UserAction>,
event_tx: mpsc::Sender<UIEvent>,
/// Messages typed by the user while an approval prompt is open. They are
/// queued here and replayed as new turns once the current turn completes.
queued_messages: Vec<String>,
}
impl<P: ModelProvider> Orchestrator<P> {
/// Construct an orchestrator using the given provider and channel endpoints.
pub fn new(
provider: P,
tool_registry: ToolRegistry,
working_dir: std::path::PathBuf,
action_rx: mpsc::Receiver<UserAction>,
event_tx: mpsc::Sender<UIEvent>,
) -> Self {
Self {
history: ConversationHistory::new(),
provider,
tool_registry,
working_dir,
action_rx,
event_tx,
queued_messages: Vec::new(),
}
}
/// Consume the provider's stream for one turn, returning the content blocks
/// produced by the assistant (text and/or tool-use) or an error.
///
/// Tool-use input JSON is accumulated from `ToolUseInputDelta` fragments and
/// parsed into `ContentBlock::ToolUse` on `ToolUseDone`.
async fn consume_stream(
&self,
messages: &[ConversationMessage],
tools: &[ToolDefinition],
) -> StreamResult {
let mut blocks: Vec<ContentBlock> = Vec::new();
let mut text_buf = String::new();
let mut active_tool: Option<ActiveToolUse> = None;
let mut stream = Box::pin(self.provider.stream(messages, tools));
while let Some(event) = stream.next().await {
match event {
StreamEvent::TextDelta(chunk) => {
text_buf.push_str(&chunk);
let _ = self.event_tx.send(UIEvent::StreamDelta(chunk)).await;
}
StreamEvent::ToolUseStart { id, name } => {
// Flush any accumulated text before starting tool block.
if !text_buf.is_empty() {
blocks.push(ContentBlock::Text {
text: std::mem::take(&mut text_buf),
});
}
active_tool = Some(ActiveToolUse::new(id, name));
}
StreamEvent::ToolUseInputDelta(chunk) => {
if let Some(t) = &mut active_tool {
t.append(&chunk);
} else {
warn!(
"received ToolUseInputDelta outside of an active tool-use block -- ignoring"
);
}
}
StreamEvent::ToolUseDone => {
if let Some(t) = active_tool.take() {
match ContentBlock::try_from(t) {
Ok(block) => blocks.push(block),
Err(e) => {
return StreamResult::Error(format!(
"failed to parse tool input JSON: {e}"
));
}
}
} else {
warn!(
"received ToolUseDone outside of an active tool-use block -- ignoring"
);
}
}
StreamEvent::Done => {
// Flush trailing text.
if !text_buf.is_empty() {
blocks.push(ContentBlock::Text { text: text_buf });
}
return StreamResult::Done(blocks);
}
StreamEvent::Error(msg) => {
return StreamResult::Error(msg);
}
StreamEvent::InputTokens(n) => {
debug!(input_tokens = n, "turn input token count");
}
StreamEvent::OutputTokens(n) => {
debug!(output_tokens = n, "turn output token count");
}
}
}
// Stream ended without Done -- treat as error.
StreamResult::Error("stream ended unexpectedly".to_string())
}
/// Execute one complete turn: stream from provider, execute tools if needed,
/// loop back until the model produces a text-only response or we hit the
/// iteration limit.
async fn run_turn(&mut self) {
let tool_defs = self.tool_registry.definitions();
for _ in 0..MAX_TOOL_ITERATIONS {
let messages = self.history.messages().to_vec();
let result = self.consume_stream(&messages, &tool_defs).await;
match result {
StreamResult::Error(msg) => {
let _ = self.event_tx.send(UIEvent::Error(msg)).await;
return;
}
StreamResult::Done(blocks) => {
let has_tool_use = blocks
.iter()
.any(|b| matches!(b, ContentBlock::ToolUse { .. }));
// Append the assistant message to history.
self.history.push(ConversationMessage {
role: Role::Assistant,
content: blocks.clone(),
});
if !has_tool_use {
let _ = self.event_tx.send(UIEvent::TurnComplete).await;
return;
}
// TODO: execute tool-use blocks in parallel -- requires
// refactoring Orchestrator to use &self + interior
// mutability (Arc<Mutex<...>> around event_tx, action_rx,
// history) so that multiple futures can borrow self
// simultaneously via futures::future::join_all.
// Execute each tool-use block and collect results.
let mut tool_results: Vec<ContentBlock> = Vec::new();
for block in &blocks {
if let ContentBlock::ToolUse { id, name, input } = block {
let result = self.execute_tool_with_approval(id, name, input).await;
tool_results.push(ContentBlock::ToolResult {
tool_use_id: id.clone(),
content: result.content,
is_error: result.is_error,
});
}
}
// Append tool results as a user message and loop.
self.history.push(ConversationMessage {
role: Role::User,
content: tool_results,
});
}
}
}
warn!("tool-use loop reached max iterations ({MAX_TOOL_ITERATIONS})");
let _ = self
.event_tx
.send(UIEvent::Error(
"tool-use loop reached maximum iterations".to_string(),
))
.await;
}
/// Execute a single tool, handling approval if needed.
///
/// For auto-approve tools, executes immediately. For tools requiring
/// approval, sends a request to the TUI and waits for a response.
async fn execute_tool_with_approval(
&mut self,
tool_use_id: &str,
tool_name: &str,
input: &serde_json::Value,
) -> ToolOutput {
// Extract tool info upfront to avoid holding a borrow on self across
// the mutable wait_for_approval call.
let risk = match self.tool_registry.get(tool_name) {
Some(t) => t.risk_level(),
None => {
return ToolOutput {
content: format!("unknown tool: {tool_name}"),
is_error: true,
};
}
};
let input_summary = serde_json::to_string(input).unwrap_or_default();
// Check approval.
let approved = match risk {
RiskLevel::AutoApprove => {
let _ = self
.event_tx
.send(UIEvent::ToolExecuting {
tool_name: tool_name.to_string(),
input_summary: input_summary.clone(),
})
.await;
true
}
RiskLevel::RequiresApproval => {
let _ = self
.event_tx
.send(UIEvent::ToolApprovalRequest {
tool_use_id: tool_use_id.to_string(),
tool_name: tool_name.to_string(),
input_summary: input_summary.clone(),
})
.await;
// Wait for approval response from TUI.
self.wait_for_approval(tool_use_id).await
}
};
if !approved {
return ToolOutput {
content: "tool execution denied by user".to_string(),
is_error: true,
};
}
// Re-fetch tool for execution (borrow was released above).
let tool = self.tool_registry.get(tool_name).unwrap();
match tool.execute(input, &self.working_dir).await {
Ok(output) => {
let _ = self
.event_tx
.send(UIEvent::ToolResult {
tool_name: tool_name.to_string(),
output_summary: truncate(&output.content, 200),
is_error: output.is_error,
})
.await;
output
}
Err(e) => {
let msg = e.to_string();
let _ = self
.event_tx
.send(UIEvent::ToolResult {
tool_name: tool_name.to_string(),
output_summary: msg.clone(),
is_error: true,
})
.await;
ToolOutput {
content: msg,
is_error: true,
}
}
}
}
/// Wait for a `ToolApprovalResponse` matching `tool_use_id` from the TUI.
///
/// Any `SendMessage` actions that arrive during the wait are pushed onto
/// `self.queued_messages` rather than discarded. They are replayed as new
/// turns once the current tool-use loop completes (see `run()`).
///
/// Returns `true` if approved, `false` if denied or the channel closes.
async fn wait_for_approval(&mut self, tool_use_id: &str) -> bool {
while let Some(action) = self.action_rx.recv().await {
match action {
UserAction::ToolApprovalResponse {
tool_use_id: id,
approved,
} if id == tool_use_id => {
return approved;
}
UserAction::Quit => return false,
UserAction::SendMessage(text) => {
self.queued_messages.push(text);
}
_ => {} // discard stale approvals / ClearHistory during wait
}
}
false
}
/// Run the orchestrator until the user quits or the `action_rx` channel closes.
pub async fn run(mut self) {
while let Some(action) = self.action_rx.recv().await {
@ -69,62 +405,24 @@ impl<P: ModelProvider> Orchestrator<P> {
self.history.clear();
}
// Approval responses are handled inline during tool execution,
// not in the main action loop. If one arrives here it's stale.
UserAction::ToolApprovalResponse { .. } => {}
UserAction::SendMessage(text) => {
// Push the user message before snapshotting, so providers
// see the full conversation including the new message.
self.history.push(ConversationMessage {
role: Role::User,
content: text,
});
self.history
.push(ConversationMessage::text(Role::User, text));
self.run_turn().await;
// Snapshot history into an owned Vec so the stream does not
// borrow from `self.history` -- this lets us mutably update
// `self.history` once the stream loop finishes.
let messages: Vec<ConversationMessage> = self.history.messages().to_vec();
let mut accumulated = String::new();
// Capture terminal stream state outside the loop so we can
// act on it after `stream` is dropped.
let mut turn_done = false;
let mut turn_error: Option<String> = None;
{
let mut stream = Box::pin(self.provider.stream(&messages));
while let Some(event) = stream.next().await {
match event {
StreamEvent::TextDelta(chunk) => {
accumulated.push_str(&chunk);
let _ = self.event_tx.send(UIEvent::StreamDelta(chunk)).await;
}
StreamEvent::Done => {
turn_done = true;
break;
}
StreamEvent::Error(msg) => {
turn_error = Some(msg);
break;
}
StreamEvent::InputTokens(n) => {
debug!(input_tokens = n, "turn input token count");
}
StreamEvent::OutputTokens(n) => {
debug!(output_tokens = n, "turn output token count");
}
}
// Drain any messages queued while an approval prompt was
// open. Each queued message is a full turn in sequence.
while !self.queued_messages.is_empty() {
let queued = std::mem::take(&mut self.queued_messages);
for msg in queued {
self.history
.push(ConversationMessage::text(Role::User, msg));
self.run_turn().await;
}
// `stream` is dropped here, releasing the borrow on
// `self.provider` and `messages`.
}
if turn_done {
self.history.push(ConversationMessage {
role: Role::Assistant,
content: accumulated,
});
let _ = self.event_tx.send(UIEvent::TurnComplete).await;
} else if let Some(msg) = turn_error {
let _ = self.event_tx.send(UIEvent::Error(msg)).await;
}
}
}
@ -135,6 +433,8 @@ impl<P: ModelProvider> Orchestrator<P> {
#[cfg(test)]
mod tests {
use super::*;
use crate::core::types::ToolDefinition;
use crate::tools::ToolRegistry;
use futures::Stream;
use tokio::sync::mpsc;
@ -155,11 +455,27 @@ mod tests {
fn stream<'a>(
&'a self,
_messages: &'a [ConversationMessage],
_tools: &'a [ToolDefinition],
) -> impl Stream<Item = StreamEvent> + Send + 'a {
futures::stream::iter(self.events.clone())
}
}
/// Create an Orchestrator with no tools for testing text-only flows.
fn test_orchestrator<P: ModelProvider>(
provider: P,
action_rx: mpsc::Receiver<UserAction>,
event_tx: mpsc::Sender<UIEvent>,
) -> Orchestrator<P> {
Orchestrator::new(
provider,
ToolRegistry::empty(),
std::path::PathBuf::from("/tmp"),
action_rx,
event_tx,
)
}
/// Collect all UIEvents that arrive within one orchestrator turn, stopping
/// when the channel is drained after a `TurnComplete` or `Error`.
async fn collect_events(rx: &mut mpsc::Receiver<UIEvent>) -> Vec<UIEvent> {
@ -195,7 +511,7 @@ mod tests {
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
let (event_tx, mut event_rx) = mpsc::channel::<UIEvent>(16);
let orch = Orchestrator::new(provider, action_rx, event_tx);
let orch = test_orchestrator(provider, action_rx, event_tx);
let handle = tokio::spawn(orch.run());
action_tx
@ -233,7 +549,7 @@ mod tests {
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
let (event_tx, mut event_rx) = mpsc::channel::<UIEvent>(16);
let orch = Orchestrator::new(provider, action_rx, event_tx);
let orch = test_orchestrator(provider, action_rx, event_tx);
let handle = tokio::spawn(orch.run());
action_tx
@ -264,6 +580,7 @@ mod tests {
fn stream<'a>(
&'a self,
_messages: &'a [ConversationMessage],
_tools: &'a [ToolDefinition],
) -> impl Stream<Item = StreamEvent> + Send + 'a {
panic!("stream() must not be called after Quit");
#[allow(unreachable_code)]
@ -274,7 +591,7 @@ mod tests {
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
let (event_tx, _event_rx) = mpsc::channel::<UIEvent>(8);
let orch = Orchestrator::new(NeverCalledProvider, action_rx, event_tx);
let orch = test_orchestrator(NeverCalledProvider, action_rx, event_tx);
let handle = tokio::spawn(orch.run());
action_tx.send(UserAction::Quit).await.unwrap();
@ -311,6 +628,7 @@ mod tests {
fn stream<'a>(
&'a self,
_messages: &'a [ConversationMessage],
_tools: &'a [ToolDefinition],
) -> impl Stream<Item = StreamEvent> + Send + 'a {
let events = self.turns.lock().unwrap().pop_front().unwrap_or_default();
futures::stream::iter(events)
@ -326,7 +644,7 @@ mod tests {
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
let (event_tx, mut event_rx) = mpsc::channel::<UIEvent>(32);
let orch = Orchestrator::new(provider, action_rx, event_tx);
let orch = test_orchestrator(provider, action_rx, event_tx);
let handle = tokio::spawn(orch.run());
// First turn.
@ -350,4 +668,263 @@ mod tests {
action_tx.send(UserAction::Quit).await.unwrap();
handle.await.unwrap();
}
// -- tool-use accumulation ----------------------------------------------------
/// When the provider emits tool-use events, the orchestrator executes the
/// tool (auto-approve since read_file is AutoApprove), feeds the result back,
/// and the provider's second call returns text.
#[tokio::test]
async fn tool_use_loop_executes_and_feeds_back() {
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};
struct MultiCallMock {
turns: Arc<Mutex<VecDeque<Vec<StreamEvent>>>>,
}
impl ModelProvider for MultiCallMock {
fn stream<'a>(
&'a self,
_messages: &'a [ConversationMessage],
_tools: &'a [ToolDefinition],
) -> impl Stream<Item = StreamEvent> + Send + 'a {
let events = self.turns.lock().unwrap().pop_front().unwrap_or_default();
futures::stream::iter(events)
}
}
let turns = Arc::new(Mutex::new(VecDeque::from([
// First call: tool use
vec![
StreamEvent::TextDelta("Let me read that.".to_string()),
StreamEvent::ToolUseDone, // spurious stop for text block
StreamEvent::ToolUseStart {
id: "toolu_01".to_string(),
name: "read_file".to_string(),
},
StreamEvent::ToolUseInputDelta("{\"path\":\"Cargo.toml\"}".to_string()),
StreamEvent::ToolUseDone,
StreamEvent::Done,
],
// Second call: text response after tool result
vec![
StreamEvent::TextDelta("Here's the file.".to_string()),
StreamEvent::Done,
],
])));
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
let (event_tx, mut event_rx) = mpsc::channel::<UIEvent>(32);
// Use a real ToolRegistry so read_file works.
let dir = tempfile::TempDir::new().unwrap();
std::fs::write(dir.path().join("Cargo.toml"), "[package]\nname = \"test\"").unwrap();
let orch = Orchestrator::new(
MultiCallMock { turns },
ToolRegistry::default_tools(),
dir.path().to_path_buf(),
action_rx,
event_tx,
);
let handle = tokio::spawn(orch.run());
action_tx
.send(UserAction::SendMessage("read Cargo.toml".to_string()))
.await
.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
let events = collect_events(&mut event_rx).await;
// Should see text deltas from both calls, tool events, and TurnComplete
assert!(
events
.iter()
.any(|e| matches!(e, UIEvent::StreamDelta(s) if s == "Let me read that.")),
"expected first text delta"
);
assert!(
events
.iter()
.any(|e| matches!(e, UIEvent::StreamDelta(s) if s == "Here's the file.")),
"expected second text delta"
);
assert!(
matches!(events.last(), Some(UIEvent::TurnComplete)),
"expected TurnComplete, got: {events:?}"
);
action_tx.send(UserAction::Quit).await.unwrap();
handle.await.unwrap();
}
// -- truncate -----------------------------------------------------------------
/// Truncating at a boundary inside a multibyte character must not panic.
/// "hello " is 6 bytes; the emoji is 4 bytes. Requesting max_len=8 would
/// previously slice at byte 8, splitting the emoji. The fixed version uses
/// char_indices so the cut falls on a char boundary.
#[test]
fn truncate_multibyte_does_not_panic() {
let s = "hello \u{1F30D} world"; // globe emoji is 4 bytes
// Should not panic and the result should end with "..."
let result = truncate(s, 8);
assert!(
result.ends_with("..."),
"expected truncation suffix: {result}"
);
// The string before "..." must be valid UTF-8 (guaranteed by the fix).
let prefix = result.trim_end_matches("...");
assert!(std::str::from_utf8(prefix.as_bytes()).is_ok());
}
// -- JSON parse error ---------------------------------------------------------
/// When the provider emits malformed tool-input JSON, the orchestrator must
/// surface a descriptive UIEvent::Error rather than silently using Null.
#[tokio::test]
async fn malformed_tool_json_surfaces_error() {
let provider = MockProvider::new(vec![
StreamEvent::ToolUseStart {
id: "toolu_bad".to_string(),
name: "read_file".to_string(),
},
StreamEvent::ToolUseInputDelta("not valid json{{".to_string()),
StreamEvent::ToolUseDone,
StreamEvent::Done,
]);
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
let (event_tx, mut event_rx) = mpsc::channel::<UIEvent>(16);
let orch = test_orchestrator(provider, action_rx, event_tx);
let handle = tokio::spawn(orch.run());
action_tx
.send(UserAction::SendMessage("go".to_string()))
.await
.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
let events = collect_events(&mut event_rx).await;
assert!(
events
.iter()
.any(|e| matches!(e, UIEvent::Error(msg) if msg.contains("tool input JSON"))),
"expected JSON parse error event, got: {events:?}"
);
action_tx.send(UserAction::Quit).await.unwrap();
handle.await.unwrap();
}
// -- queued messages ----------------------------------------------------------
/// A SendMessage sent while an approval prompt is open must be processed
/// after the current turn completes, not silently dropped.
///
/// This test uses a RequiresApproval tool (write_file) so that the
/// orchestrator blocks in wait_for_approval. While blocked, we send a second
/// message. After approving (denied here for simplicity), the queued message
/// should still be processed.
#[tokio::test]
async fn send_message_during_approval_is_queued_and_processed() {
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};
struct MultiCallMock {
turns: Arc<Mutex<VecDeque<Vec<StreamEvent>>>>,
}
impl ModelProvider for MultiCallMock {
fn stream<'a>(
&'a self,
_messages: &'a [ConversationMessage],
_tools: &'a [ToolDefinition],
) -> impl Stream<Item = StreamEvent> + Send + 'a {
let events = self.turns.lock().unwrap().pop_front().unwrap_or_default();
futures::stream::iter(events)
}
}
let turns = Arc::new(Mutex::new(VecDeque::from([
// Turn 1: requests write_file (RequiresApproval)
vec![
StreamEvent::ToolUseStart {
id: "toolu_w".to_string(),
name: "write_file".to_string(),
},
StreamEvent::ToolUseInputDelta(
"{\"path\":\"x.txt\",\"content\":\"hi\"}".to_string(),
),
StreamEvent::ToolUseDone,
StreamEvent::Done,
],
// Turn 1 second iteration after tool result (denied)
vec![
StreamEvent::TextDelta("ok denied".to_string()),
StreamEvent::Done,
],
// Turn 2: the queued message
vec![
StreamEvent::TextDelta("queued reply".to_string()),
StreamEvent::Done,
],
])));
let (action_tx, action_rx) = mpsc::channel::<UserAction>(16);
let (event_tx, mut event_rx) = mpsc::channel::<UIEvent>(64);
let dir = tempfile::TempDir::new().unwrap();
let orch = Orchestrator::new(
MultiCallMock { turns },
ToolRegistry::default_tools(),
dir.path().to_path_buf(),
action_rx,
event_tx,
);
let handle = tokio::spawn(orch.run());
// Start turn 1 -- orchestrator will block on approval.
action_tx
.send(UserAction::SendMessage("turn1".to_string()))
.await
.unwrap();
// Let the orchestrator reach wait_for_approval.
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
// Send a message while blocked -- should be queued.
action_tx
.send(UserAction::SendMessage("queued".to_string()))
.await
.unwrap();
// Deny the tool -- unblocks wait_for_approval.
action_tx
.send(UserAction::ToolApprovalResponse {
tool_use_id: "toolu_w".to_string(),
approved: false,
})
.await
.unwrap();
// Wait for both turns to complete.
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
// Collect everything.
let mut all_events = Vec::new();
while let Ok(ev) = event_rx.try_recv() {
all_events.push(ev);
}
// The queued message must have produced "queued reply".
assert!(
all_events
.iter()
.any(|e| matches!(e, UIEvent::StreamDelta(s) if s == "queued reply")),
"queued message was not processed; events: {all_events:?}"
);
action_tx.send(UserAction::Quit).await.unwrap();
handle.await.unwrap();
}
}

View file

@ -1,8 +1,16 @@
use serde::{Deserialize, Serialize};
/// A streaming event emitted by the model provider.
#[derive(Debug, Clone)]
pub enum StreamEvent {
/// A text chunk from the assistant's response.
TextDelta(String),
/// A new tool-use content block has started.
ToolUseStart { id: String, name: String },
/// A chunk of the tool-use input JSON (streamed incrementally).
ToolUseInputDelta(String),
/// The current tool-use content block has ended.
ToolUseDone,
/// Number of input tokens used in this request.
InputTokens(u32),
/// Number of output tokens generated so far.
@ -18,6 +26,8 @@ pub enum StreamEvent {
pub enum UserAction {
/// The user has submitted a message.
SendMessage(String),
/// The user has responded to a tool approval request.
ToolApprovalResponse { tool_use_id: String, approved: bool },
/// The user has requested to quit.
Quit,
/// The user has requested to clear conversation history.
@ -29,6 +39,23 @@ pub enum UserAction {
pub enum UIEvent {
/// A text chunk to append to the current assistant message.
StreamDelta(String),
/// A tool requires user approval before execution.
ToolApprovalRequest {
tool_use_id: String,
tool_name: String,
input_summary: String,
},
/// A tool is being executed (informational, after approval or auto-approve).
ToolExecuting {
tool_name: String,
input_summary: String,
},
/// A tool has finished executing.
ToolResult {
tool_name: String,
output_summary: String,
is_error: bool,
},
/// The current assistant turn has completed.
TurnComplete,
/// An error to display to the user.
@ -36,7 +63,7 @@ pub enum UIEvent {
}
/// The role of a participant in a conversation.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
/// A message from the human user.
@ -45,11 +72,128 @@ pub enum Role {
Assistant,
}
/// A tool definition sent to the model so it knows which tools are available.
///
/// This is provider-agnostic -- the provider serializes it into the
/// format required by its API (e.g. the `tools` array for Anthropic).
#[derive(Debug, Clone, Serialize)]
pub struct ToolDefinition {
/// The tool name the model will use in `tool_use` blocks.
pub name: String,
/// Human-readable description of what the tool does.
pub description: String,
/// JSON Schema describing the tool's input parameters.
pub input_schema: serde_json::Value,
}
/// A typed content block within a conversation message.
///
/// The Anthropic Messages API represents message content as an array of typed
/// blocks. A single assistant message can contain interleaved text and tool-use
/// blocks; a user message following tool execution contains tool-result blocks.
///
/// See the [Messages API content blocks reference][content-blocks].
///
/// [content-blocks]: https://docs.anthropic.com/en/api/messages
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
/// Plain text content.
Text { text: String },
/// A tool invocation requested by the assistant.
ToolUse {
id: String,
name: String,
input: serde_json::Value,
},
/// The result of executing a tool, sent back in a user-role message.
ToolResult {
tool_use_id: String,
content: String,
#[serde(default)]
is_error: bool,
},
}
/// A single message in the conversation history.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
///
/// Content is stored as a `Vec<ContentBlock>` to support mixed text and
/// tool-use blocks in a single message. For simple text-only messages, use
/// the [`ConversationMessage::text`] constructor.
///
/// # Serialization
///
/// Serializes `content` as a plain string when the message contains exactly
/// one `Text` block (the common case), and as an array of typed blocks
/// otherwise. This matches what the Anthropic API expects for both simple
/// and tool-use messages.
#[derive(Debug, Clone)]
pub struct ConversationMessage {
/// The role of the message author.
pub role: Role,
/// The text content of the message.
pub content: String,
/// The content blocks of this message.
pub content: Vec<ContentBlock>,
}
impl ConversationMessage {
/// Create a simple text-only message.
pub fn text(role: Role, s: impl Into<String>) -> Self {
Self {
role,
content: vec![ContentBlock::Text { text: s.into() }],
}
}
/// Extract the concatenated text content, ignoring non-text blocks.
#[allow(dead_code)]
pub fn text_content(&self) -> String {
let mut out = String::new();
for block in &self.content {
if let ContentBlock::Text { text } = block {
out.push_str(text);
}
}
out
}
}
impl Serialize for ConversationMessage {
fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
use serde::ser::SerializeMap;
let mut map = serializer.serialize_map(Some(2))?;
map.serialize_entry("role", &self.role)?;
// Single text block -> serialize as plain string (common case).
// Otherwise -> serialize as array of content blocks.
match self.content.as_slice() {
[ContentBlock::Text { text }] => map.serialize_entry("content", text)?,
blocks => map.serialize_entry("content", blocks)?,
}
map.end()
}
}
impl<'de> Deserialize<'de> for ConversationMessage {
fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
#[derive(Deserialize)]
struct Raw {
role: Role,
content: serde_json::Value,
}
let raw = Raw::deserialize(deserializer)?;
let content = match raw.content {
serde_json::Value::String(s) => vec![ContentBlock::Text { text: s }],
serde_json::Value::Array(_) => {
serde_json::from_value(raw.content).map_err(serde::de::Error::custom)?
}
other => {
return Err(serde::de::Error::custom(format!(
"expected string or array for content, got {other}"
)));
}
};
Ok(ConversationMessage {
role: raw.role,
content,
})
}
}

View file

@ -1,6 +1,7 @@
mod app;
mod core;
mod provider;
mod tools;
mod tui;
use std::path::PathBuf;

View file

@ -2,7 +2,7 @@ use futures::{SinkExt, Stream, StreamExt};
use reqwest::Client;
use serde::Deserialize;
use crate::core::types::{ConversationMessage, StreamEvent};
use crate::core::types::{ConversationMessage, StreamEvent, ToolDefinition};
use super::ModelProvider;
@ -64,15 +64,17 @@ impl ModelProvider for ClaudeProvider {
fn stream<'a>(
&'a self,
messages: &'a [ConversationMessage],
tools: &'a [ToolDefinition],
) -> impl Stream<Item = StreamEvent> + Send + 'a {
let (mut tx, rx) = futures::channel::mpsc::channel(32);
let client = self.client.clone();
let api_key = self.api_key.clone();
let model = self.model.clone();
let messages = messages.to_vec();
let tools = tools.to_vec();
tokio::spawn(async move {
run_stream(client, api_key, model, messages, &mut tx).await;
run_stream(client, api_key, model, messages, tools, &mut tx).await;
});
rx
@ -126,14 +128,18 @@ async fn run_stream(
api_key: String,
model: String,
messages: Vec<ConversationMessage>,
tools: Vec<ToolDefinition>,
tx: &mut futures::channel::mpsc::Sender<StreamEvent>,
) {
let body = serde_json::json!({
let mut body = serde_json::json!({
"model": model,
"max_tokens": 8192,
"stream": true,
"messages": messages,
});
if !tools.is_empty() {
body["tools"] = serde_json::to_value(&tools).unwrap_or_default();
}
let response = match client
.post("https://api.anthropic.com/v1/messages")
@ -223,6 +229,8 @@ fn find_double_newline(buf: &[u8]) -> Option<usize> {
struct SseEvent {
#[serde(rename = "type")]
event_type: String,
/// Present on `content_block_start` events; describes the new block.
content_block: Option<SseContentBlock>,
/// Present on `content_block_delta` events.
delta: Option<SseDelta>,
/// Present on `message_start` events; carries initial token usage.
@ -231,16 +239,29 @@ struct SseEvent {
usage: Option<SseUsage>,
}
/// The `content_block` object inside a `content_block_start` event.
#[derive(Deserialize, Debug)]
struct SseContentBlock {
#[serde(rename = "type")]
block_type: String,
/// Tool-use block ID; present when `block_type == "tool_use"`.
id: Option<String>,
/// Tool name; present when `block_type == "tool_use"`.
name: Option<String>,
}
/// The `delta` object inside a `content_block_delta` event.
///
/// `type` is `"text_delta"` for plain text chunks; other delta types
/// (e.g. `"input_json_delta"` for tool-use blocks) are not yet handled.
/// `type` is `"text_delta"` for plain text chunks, or `"input_json_delta"`
/// for streaming tool-use input JSON.
#[derive(Deserialize, Debug)]
struct SseDelta {
#[serde(rename = "type")]
delta_type: Option<String>,
/// The text chunk; present when `delta_type == "text_delta"`.
text: Option<String>,
/// Partial JSON for tool input; present when `delta_type == "input_json_delta"`.
partial_json: Option<String>,
}
/// The `message` object inside a `message_start` event.
@ -273,13 +294,20 @@ struct SseUsage {
///
/// # Mapping to [`StreamEvent`]
///
/// | API event type | JSON path | Emits |
/// |----------------------|------------------------------------|------------------------------|
/// | `message_start` | `.message.usage.input_tokens` | `InputTokens(n)` |
/// | `content_block_delta`| `.delta.type == "text_delta"` | `TextDelta(chunk)` |
/// | `message_delta` | `.usage.output_tokens` | `OutputTokens(n)` |
/// | `message_stop` | n/a | `Done` |
/// | everything else | n/a | `None` (caller skips) |
/// | API event type | JSON path | Emits |
/// |------------------------|------------------------------------|------------------------------|
/// | `message_start` | `.message.usage.input_tokens` | `InputTokens(n)` |
/// | `content_block_start` | `.content_block.type == "tool_use"`| `ToolUseStart { id, name }` |
/// | `content_block_delta` | `.delta.type == "text_delta"` | `TextDelta(chunk)` |
/// | `content_block_delta` | `.delta.type == "input_json_delta"`| `ToolUseInputDelta(json)` |
/// | `content_block_stop` | n/a | `ToolUseDone` |
/// | `message_delta` | `.usage.output_tokens` | `OutputTokens(n)` |
/// | `message_stop` | n/a | `Done` |
/// | everything else | n/a | `None` (caller skips) |
///
/// Note: `content_block_stop` always emits `ToolUseDone`. The caller is
/// responsible for tracking whether a tool-use block is active and ignoring
/// `ToolUseDone` when it follows a text block.
///
/// [sse-fields]: https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream
fn parse_sse_event(event_str: &str) -> Option<StreamEvent> {
@ -297,15 +325,29 @@ fn parse_sse_event(event_str: &str) -> Option<StreamEvent> {
.and_then(|u| u.input_tokens)
.map(StreamEvent::InputTokens),
"content_block_delta" => {
let delta = event.delta?;
if delta.delta_type.as_deref() == Some("text_delta") {
delta.text.map(StreamEvent::TextDelta)
"content_block_start" => {
let block = event.content_block?;
if block.block_type == "tool_use" {
Some(StreamEvent::ToolUseStart {
id: block.id.unwrap_or_default(),
name: block.name.unwrap_or_default(),
})
} else {
None
}
}
"content_block_delta" => {
let delta = event.delta?;
match delta.delta_type.as_deref() {
Some("text_delta") => delta.text.map(StreamEvent::TextDelta),
Some("input_json_delta") => delta.partial_json.map(StreamEvent::ToolUseInputDelta),
_ => None,
}
}
"content_block_stop" => Some(StreamEvent::ToolUseDone),
// usage lives at the top level of message_delta, not inside delta.
"message_delta" => event
.usage
@ -314,8 +356,7 @@ fn parse_sse_event(event_str: &str) -> Option<StreamEvent> {
"message_stop" => Some(StreamEvent::Done),
// error, ping, content_block_start, content_block_stop -- ignored or
// handled by the caller.
// error, ping -- ignored.
_ => None,
}
}
@ -371,13 +412,15 @@ mod tests {
.filter_map(parse_sse_event)
.collect();
// content_block_start, ping, content_block_stop -> None (filtered out)
assert_eq!(events.len(), 5);
// content_block_start (text) and ping -> None (filtered out)
// content_block_stop -> ToolUseDone (always emitted; caller filters)
assert_eq!(events.len(), 6);
assert!(matches!(events[0], StreamEvent::InputTokens(10)));
assert!(matches!(&events[1], StreamEvent::TextDelta(s) if s == "Hello"));
assert!(matches!(&events[2], StreamEvent::TextDelta(s) if s == ", world!"));
assert!(matches!(events[3], StreamEvent::OutputTokens(5)));
assert!(matches!(events[4], StreamEvent::Done));
assert!(matches!(events[3], StreamEvent::ToolUseDone));
assert!(matches!(events[4], StreamEvent::OutputTokens(5)));
assert!(matches!(events[5], StreamEvent::Done));
}
#[test]
@ -408,14 +451,8 @@ mod tests {
#[test]
fn test_messages_serialize_to_anthropic_format() {
let messages = vec![
ConversationMessage {
role: Role::User,
content: "Hello".to_string(),
},
ConversationMessage {
role: Role::Assistant,
content: "Hi there!".to_string(),
},
ConversationMessage::text(Role::User, "Hello"),
ConversationMessage::text(Role::Assistant, "Hi there!"),
];
let json = serde_json::json!({
@ -433,6 +470,65 @@ mod tests {
assert!(json["max_tokens"].as_u64().unwrap() > 0);
}
/// SSE fixture for a response that contains a tool-use block.
///
/// Sequence: message_start -> content_block_start (tool_use) ->
/// content_block_delta (input_json_delta) -> content_block_stop ->
/// message_delta -> message_stop
const TOOL_USE_SSE_FIXTURE: &str = concat!(
"event: message_start\n",
"data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_456\",\"type\":\"message\",",
"\"role\":\"assistant\",\"content\":[],\"model\":\"claude-opus-4-6\",",
"\"stop_reason\":null,\"stop_sequence\":null,",
"\"usage\":{\"input_tokens\":25,\"cache_creation_input_tokens\":0,",
"\"cache_read_input_tokens\":0,\"output_tokens\":1}}}\n",
"\n",
"event: content_block_start\n",
"data: {\"type\":\"content_block_start\",\"index\":0,",
"\"content_block\":{\"type\":\"tool_use\",\"id\":\"toolu_01ABC\",\"name\":\"read_file\",\"input\":{}}}\n",
"\n",
"event: content_block_delta\n",
"data: {\"type\":\"content_block_delta\",\"index\":0,",
"\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"path\\\": \\\"src/\"}}\n",
"\n",
"event: content_block_delta\n",
"data: {\"type\":\"content_block_delta\",\"index\":0,",
"\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"main.rs\\\"}\"}}\n",
"\n",
"event: content_block_stop\n",
"data: {\"type\":\"content_block_stop\",\"index\":0}\n",
"\n",
"event: message_delta\n",
"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\",",
"\"stop_sequence\":null},\"usage\":{\"output_tokens\":15}}\n",
"\n",
"event: message_stop\n",
"data: {\"type\":\"message_stop\"}\n",
"\n",
);
#[test]
fn test_parse_tool_use_sse_fixture() {
let events: Vec<StreamEvent> = TOOL_USE_SSE_FIXTURE
.split("\n\n")
.filter(|s| !s.trim().is_empty())
.filter_map(parse_sse_event)
.collect();
assert_eq!(events.len(), 7);
assert!(matches!(events[0], StreamEvent::InputTokens(25)));
assert!(
matches!(&events[1], StreamEvent::ToolUseStart { id, name } if id == "toolu_01ABC" && name == "read_file")
);
assert!(
matches!(&events[2], StreamEvent::ToolUseInputDelta(s) if s == "{\"path\": \"src/")
);
assert!(matches!(&events[3], StreamEvent::ToolUseInputDelta(s) if s == "main.rs\"}"));
assert!(matches!(events[4], StreamEvent::ToolUseDone));
assert!(matches!(events[5], StreamEvent::OutputTokens(15)));
assert!(matches!(events[6], StreamEvent::Done));
}
#[test]
fn test_find_double_newline() {
assert_eq!(find_double_newline(b"abc\n\ndef"), Some(3));

View file

@ -4,16 +4,18 @@ pub use claude::ClaudeProvider;
use futures::Stream;
use crate::core::types::{ConversationMessage, StreamEvent};
use crate::core::types::{ConversationMessage, StreamEvent, ToolDefinition};
/// Trait for model providers that can stream conversation responses.
///
/// Implementors take a conversation history and return a stream of [`StreamEvent`]s.
/// The trait is provider-agnostic -- no Claude-specific types appear here.
pub trait ModelProvider: Send + Sync {
/// Stream a response from the model given the conversation history.
/// Stream a response from the model given the conversation history and
/// available tool definitions. Pass an empty slice if no tools are available.
fn stream<'a>(
&'a self,
messages: &'a [ConversationMessage],
tools: &'a [ToolDefinition],
) -> impl Stream<Item = StreamEvent> + Send + 'a;
}

View file

@ -0,0 +1,87 @@
//! `list_directory` tool: lists entries in a directory within the working directory.
use std::path::Path;
use async_trait::async_trait;
use super::{RiskLevel, Tool, ToolError, ToolOutput, validate_path};
/// Lists directory contents. Auto-approved (read-only).
pub struct ListDirectory;
#[async_trait]
impl Tool for ListDirectory {
fn name(&self) -> &str {
"list_directory"
}
fn description(&self) -> &str {
"List the files and subdirectories in a directory. The path is relative to the working directory."
}
fn input_schema(&self) -> serde_json::Value {
serde_json::json!({
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "The directory path to list, relative to the working directory. Use '.' for the working directory itself."
}
},
"required": ["path"]
})
}
fn risk_level(&self) -> RiskLevel {
RiskLevel::AutoApprove
}
async fn execute(
&self,
input: &serde_json::Value,
working_dir: &Path,
) -> Result<ToolOutput, ToolError> {
let path_str = input["path"]
.as_str()
.ok_or_else(|| ToolError::InvalidInput("missing 'path' string".to_string()))?;
let canonical = validate_path(working_dir, path_str)?;
let mut entries: Vec<String> = Vec::new();
let mut dir = tokio::fs::read_dir(&canonical).await?;
while let Some(entry) = dir.next_entry().await? {
let name = entry.file_name().to_string_lossy().to_string();
let suffix = if entry.file_type().await?.is_dir() {
"/"
} else {
""
};
entries.push(format!("{name}{suffix}"));
}
entries.sort();
Ok(ToolOutput {
content: entries.join("\n"),
is_error: false,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::fs;
use tempfile::TempDir;
#[tokio::test]
async fn list_directory_contents() {
let dir = TempDir::new().unwrap();
fs::write(dir.path().join("a.txt"), "").unwrap();
fs::create_dir(dir.path().join("subdir")).unwrap();
let tool = ListDirectory;
let input = serde_json::json!({"path": "."});
let out = tool.execute(&input, dir.path()).await.unwrap();
assert!(out.content.contains("a.txt"));
assert!(out.content.contains("subdir/"));
}
}

220
src/tools/mod.rs Normal file
View file

@ -0,0 +1,220 @@
//! Tool system: trait, registry, risk classification, and built-in tools.
//!
//! All tools implement the [`Tool`] trait. The [`ToolRegistry`] collects them
//! and provides lookup by name plus generation of [`ToolDefinition`]s for the
//! model provider.
mod list_directory;
mod read_file;
mod shell_exec;
mod write_file;
use std::path::{Path, PathBuf};
use async_trait::async_trait;
use crate::core::types::ToolDefinition;
/// The output of a tool execution.
#[derive(Debug)]
pub struct ToolOutput {
/// The text content returned to the model.
pub content: String,
/// Whether the tool encountered an error.
pub is_error: bool,
}
/// Risk classification for tool approval gating.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RiskLevel {
/// Safe to execute without user confirmation (e.g. read-only operations).
AutoApprove,
/// Requires explicit user approval before execution (e.g. writes, shell).
RequiresApproval,
}
/// A tool that the model can invoke.
///
/// The `execute` method is async so that tool implementations can use
/// `tokio::fs` and `tokio::process` without blocking a Tokio worker thread.
/// `#[async_trait]` desugars the async fn to a boxed future, which is required
/// for `dyn Tool` to remain object-safe.
#[async_trait]
pub trait Tool: Send + Sync {
/// The name the model uses to invoke this tool.
fn name(&self) -> &str;
/// Human-readable description for the model.
fn description(&self) -> &str;
/// JSON Schema for the tool's input parameters.
fn input_schema(&self) -> serde_json::Value;
/// The risk level of this tool.
fn risk_level(&self) -> RiskLevel;
/// Execute the tool with the given input, confined to `working_dir`.
async fn execute(
&self,
input: &serde_json::Value,
working_dir: &Path,
) -> Result<ToolOutput, ToolError>;
}
/// Errors from tool execution.
#[derive(Debug, thiserror::Error)]
pub enum ToolError {
/// The requested path escapes the working directory.
#[error("path escapes working directory: {0}")]
PathEscape(PathBuf),
/// An I/O error during tool execution.
#[error("I/O error: {0}")]
Io(#[from] std::io::Error),
/// A required input field is missing or has the wrong type.
#[error("invalid input: {0}")]
InvalidInput(String),
}
/// Validate that `requested` resolves to a path inside `working_dir`.
///
/// Joins `working_dir` with `requested`, canonicalizes the result (resolving
/// symlinks and `..` components), and checks that the canonical path starts
/// with the canonical working directory.
///
/// Returns the canonical path on success, or [`ToolError::PathEscape`] if the
/// path would escape the working directory.
pub fn validate_path(working_dir: &Path, requested: &str) -> Result<PathBuf, ToolError> {
let candidate = if Path::new(requested).is_absolute() {
PathBuf::from(requested)
} else {
working_dir.join(requested)
};
// For paths that don't exist yet (e.g. write_file creating a new file),
// canonicalize the parent directory and append the filename.
let canonical = if candidate.exists() {
candidate
.canonicalize()
.map_err(|_| ToolError::PathEscape(candidate.clone()))?
} else {
let parent = candidate
.parent()
.ok_or_else(|| ToolError::PathEscape(candidate.clone()))?;
let file_name = candidate
.file_name()
.ok_or_else(|| ToolError::PathEscape(candidate.clone()))?;
let canonical_parent = parent
.canonicalize()
.map_err(|_| ToolError::PathEscape(candidate.clone()))?;
canonical_parent.join(file_name)
};
let canonical_root = working_dir
.canonicalize()
.map_err(|_| ToolError::PathEscape(candidate.clone()))?;
if canonical.starts_with(&canonical_root) {
Ok(canonical)
} else {
Err(ToolError::PathEscape(candidate))
}
}
/// Collection of available tools with name-based lookup.
pub struct ToolRegistry {
tools: Vec<Box<dyn Tool>>,
}
impl ToolRegistry {
/// Create an empty registry with no tools.
#[allow(dead_code)]
pub fn empty() -> Self {
Self { tools: Vec::new() }
}
/// Create a registry with the default built-in tools.
pub fn default_tools() -> Self {
Self {
tools: vec![
Box::new(read_file::ReadFile),
Box::new(list_directory::ListDirectory),
Box::new(write_file::WriteFile),
Box::new(shell_exec::ShellExec),
],
}
}
/// Look up a tool by name.
pub fn get(&self, name: &str) -> Option<&dyn Tool> {
self.tools.iter().find(|t| t.name() == name).map(|t| &**t)
}
/// Generate [`ToolDefinition`]s for the model provider.
pub fn definitions(&self) -> Vec<ToolDefinition> {
self.tools
.iter()
.map(|t| ToolDefinition {
name: t.name().to_string(),
description: t.description().to_string(),
input_schema: t.input_schema(),
})
.collect()
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::fs;
use tempfile::TempDir;
#[test]
fn validate_path_allows_subpath() {
let dir = TempDir::new().unwrap();
let sub = dir.path().join("sub");
fs::create_dir(&sub).unwrap();
let result = validate_path(dir.path(), "sub");
assert!(result.is_ok());
assert!(
result
.unwrap()
.starts_with(dir.path().canonicalize().unwrap())
);
}
#[test]
fn validate_path_rejects_traversal() {
let dir = TempDir::new().unwrap();
let result = validate_path(dir.path(), "../../../etc/passwd");
assert!(result.is_err());
assert!(matches!(result, Err(ToolError::PathEscape(_))));
}
#[test]
fn validate_path_rejects_absolute_outside() {
let dir = TempDir::new().unwrap();
let result = validate_path(dir.path(), "/etc/passwd");
assert!(result.is_err());
}
#[test]
fn validate_path_allows_new_file_in_working_dir() {
let dir = TempDir::new().unwrap();
let result = validate_path(dir.path(), "new_file.txt");
assert!(result.is_ok());
}
#[test]
fn registry_default_has_all_tools() {
let reg = ToolRegistry::default_tools();
assert!(reg.get("read_file").is_some());
assert!(reg.get("list_directory").is_some());
assert!(reg.get("write_file").is_some());
assert!(reg.get("shell_exec").is_some());
assert!(reg.get("nonexistent").is_none());
}
#[test]
fn registry_definitions_match_tools() {
let reg = ToolRegistry::default_tools();
let defs = reg.definitions();
assert_eq!(defs.len(), 4);
assert!(defs.iter().any(|d| d.name == "read_file"));
}
}

92
src/tools/read_file.rs Normal file
View file

@ -0,0 +1,92 @@
//! `read_file` tool: reads the contents of a file within the working directory.
use std::path::Path;
use async_trait::async_trait;
use super::{RiskLevel, Tool, ToolError, ToolOutput, validate_path};
/// Reads file contents. Auto-approved (read-only).
pub struct ReadFile;
#[async_trait]
impl Tool for ReadFile {
fn name(&self) -> &str {
"read_file"
}
fn description(&self) -> &str {
"Read the contents of a file. The path is relative to the working directory."
}
fn input_schema(&self) -> serde_json::Value {
serde_json::json!({
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "The file path to read, relative to the working directory."
}
},
"required": ["path"]
})
}
fn risk_level(&self) -> RiskLevel {
RiskLevel::AutoApprove
}
async fn execute(
&self,
input: &serde_json::Value,
working_dir: &Path,
) -> Result<ToolOutput, ToolError> {
let path_str = input["path"]
.as_str()
.ok_or_else(|| ToolError::InvalidInput("missing 'path' string".to_string()))?;
let canonical = validate_path(working_dir, path_str)?;
let content = tokio::fs::read_to_string(&canonical).await?;
Ok(ToolOutput {
content,
is_error: false,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::fs;
use tempfile::TempDir;
#[tokio::test]
async fn read_existing_file() {
let dir = TempDir::new().unwrap();
fs::write(dir.path().join("hello.txt"), "world").unwrap();
let tool = ReadFile;
let input = serde_json::json!({"path": "hello.txt"});
let out = tool.execute(&input, dir.path()).await.unwrap();
assert_eq!(out.content, "world");
assert!(!out.is_error);
}
#[tokio::test]
async fn read_nonexistent_file_errors() {
let dir = TempDir::new().unwrap();
let tool = ReadFile;
let input = serde_json::json!({"path": "nope.txt"});
let result = tool.execute(&input, dir.path()).await;
assert!(result.is_err());
}
#[tokio::test]
async fn read_path_traversal_rejected() {
let dir = TempDir::new().unwrap();
let tool = ReadFile;
let input = serde_json::json!({"path": "../../../etc/passwd"});
let result = tool.execute(&input, dir.path()).await;
assert!(matches!(result, Err(ToolError::PathEscape(_))));
}
}

108
src/tools/shell_exec.rs Normal file
View file

@ -0,0 +1,108 @@
//! `shell_exec` tool: runs a shell command within the working directory.
use std::path::Path;
use async_trait::async_trait;
use super::{RiskLevel, Tool, ToolError, ToolOutput};
/// Executes a shell command. Requires user approval.
pub struct ShellExec;
#[async_trait]
impl Tool for ShellExec {
fn name(&self) -> &str {
"shell_exec"
}
fn description(&self) -> &str {
"Execute a shell command in the working directory. Returns stdout and stderr."
}
fn input_schema(&self) -> serde_json::Value {
serde_json::json!({
"type": "object",
"properties": {
"command": {
"type": "string",
"description": "The shell command to execute."
}
},
"required": ["command"]
})
}
fn risk_level(&self) -> RiskLevel {
RiskLevel::RequiresApproval
}
async fn execute(
&self,
input: &serde_json::Value,
working_dir: &Path,
) -> Result<ToolOutput, ToolError> {
let command = input["command"]
.as_str()
.ok_or_else(|| ToolError::InvalidInput("missing 'command' string".to_string()))?;
let output = tokio::process::Command::new("sh")
.arg("-c")
.arg(command)
.current_dir(working_dir)
.output()
.await?;
let stdout = String::from_utf8_lossy(&output.stdout);
let stderr = String::from_utf8_lossy(&output.stderr);
let mut content = String::new();
if !stdout.is_empty() {
content.push_str(&stdout);
}
if !stderr.is_empty() {
if !content.is_empty() {
content.push('\n');
}
content.push_str("[stderr]\n");
content.push_str(&stderr);
}
if content.is_empty() {
content.push_str("(no output)");
}
let is_error = !output.status.success();
if is_error {
content.push_str(&format!(
"\n[exit code: {}]",
output.status.code().unwrap_or(-1)
));
}
Ok(ToolOutput { content, is_error })
}
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
#[tokio::test]
async fn shell_exec_echo() {
let dir = TempDir::new().unwrap();
let tool = ShellExec;
let input = serde_json::json!({"command": "echo hello"});
let out = tool.execute(&input, dir.path()).await.unwrap();
assert!(out.content.contains("hello"));
assert!(!out.is_error);
}
#[tokio::test]
async fn shell_exec_failing_command() {
let dir = TempDir::new().unwrap();
let tool = ShellExec;
let input = serde_json::json!({"command": "false"});
let out = tool.execute(&input, dir.path()).await.unwrap();
assert!(out.is_error);
}
}

98
src/tools/write_file.rs Normal file
View file

@ -0,0 +1,98 @@
//! `write_file` tool: writes content to a file within the working directory.
use std::path::Path;
use async_trait::async_trait;
use super::{RiskLevel, Tool, ToolError, ToolOutput, validate_path};
/// Writes content to a file. Requires user approval.
pub struct WriteFile;
#[async_trait]
impl Tool for WriteFile {
fn name(&self) -> &str {
"write_file"
}
fn description(&self) -> &str {
"Write content to a file. Creates the file if it doesn't exist, overwrites if it does. The path is relative to the working directory."
}
fn input_schema(&self) -> serde_json::Value {
serde_json::json!({
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "The file path to write to, relative to the working directory."
},
"content": {
"type": "string",
"description": "The content to write to the file."
}
},
"required": ["path", "content"]
})
}
fn risk_level(&self) -> RiskLevel {
RiskLevel::RequiresApproval
}
async fn execute(
&self,
input: &serde_json::Value,
working_dir: &Path,
) -> Result<ToolOutput, ToolError> {
let path_str = input["path"]
.as_str()
.ok_or_else(|| ToolError::InvalidInput("missing 'path' string".to_string()))?;
let content = input["content"]
.as_str()
.ok_or_else(|| ToolError::InvalidInput("missing 'content' string".to_string()))?;
let canonical = validate_path(working_dir, path_str)?;
// Create parent directories if needed.
if let Some(parent) = canonical.parent() {
tokio::fs::create_dir_all(parent).await?;
}
tokio::fs::write(&canonical, content).await?;
Ok(ToolOutput {
content: format!("Wrote {} bytes to {path_str}", content.len()),
is_error: false,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::fs;
use tempfile::TempDir;
#[tokio::test]
async fn write_creates_file() {
let dir = TempDir::new().unwrap();
let tool = WriteFile;
let input = serde_json::json!({"path": "out.txt", "content": "hello"});
let out = tool.execute(&input, dir.path()).await.unwrap();
assert!(!out.is_error);
assert_eq!(
fs::read_to_string(dir.path().join("out.txt")).unwrap(),
"hello"
);
}
#[tokio::test]
async fn write_path_traversal_rejected() {
let dir = TempDir::new().unwrap();
let tool = WriteFile;
let input = serde_json::json!({"path": "../../evil.txt", "content": "bad"});
let result = tool.execute(&input, dir.path()).await;
assert!(matches!(result, Err(ToolError::PathEscape(_))));
}
}

View file

@ -11,11 +11,14 @@ use crate::core::types::{Role, UIEvent};
/// This is non-blocking: it processes all currently-available events and returns
/// immediately when the channel is empty.
///
/// | Event | Effect |
/// |--------------------|------------------------------------------------------------|
/// | `StreamDelta(s)` | Append `s` to last message if it's `Assistant`; else push new |
/// | `TurnComplete` | No structural change; logged at debug level |
/// | `Error(msg)` | Push `(Assistant, "[error] {msg}")` |
/// | Event | Effect |
/// |------------------------|------------------------------------------------------------|
/// | `StreamDelta(s)` | Append `s` to last message if it's `Assistant`; else push |
/// | `ToolApprovalRequest` | Set `pending_approval` in state |
/// | `ToolExecuting` | Display tool execution info |
/// | `ToolResult` | Display tool result |
/// | `TurnComplete` | No structural change; logged at debug level |
/// | `Error(msg)` | Push `(Assistant, "[error] {msg}")` |
pub(super) fn drain_ui_events(event_rx: &mut mpsc::Receiver<UIEvent>, state: &mut AppState) {
while let Ok(event) = event_rx.try_recv() {
match event {
@ -26,6 +29,36 @@ pub(super) fn drain_ui_events(event_rx: &mut mpsc::Receiver<UIEvent>, state: &mu
state.messages.push((Role::Assistant, chunk));
}
}
UIEvent::ToolApprovalRequest {
tool_use_id,
tool_name,
input_summary,
} => {
state.pending_approval = Some(PendingApproval {
tool_use_id,
tool_name,
input_summary,
});
}
UIEvent::ToolExecuting {
tool_name,
input_summary,
} => {
state
.messages
.push((Role::Assistant, format!("[{tool_name}] {input_summary}")));
}
UIEvent::ToolResult {
tool_name,
output_summary,
is_error,
} => {
let prefix = if is_error { "error" } else { "result" };
state.messages.push((
Role::Assistant,
format!("[{tool_name} {prefix}] {output_summary}"),
));
}
UIEvent::TurnComplete => {
debug!("turn complete");
}
@ -38,6 +71,14 @@ pub(super) fn drain_ui_events(event_rx: &mut mpsc::Receiver<UIEvent>, state: &mu
}
}
/// A pending tool approval request waiting for user input.
#[derive(Debug, Clone)]
pub struct PendingApproval {
pub tool_use_id: String,
pub tool_name: String,
pub input_summary: String,
}
#[cfg(test)]
mod tests {
use super::*;
@ -69,4 +110,39 @@ mod tests {
assert_eq!(state.messages[1].0, Role::Assistant);
assert_eq!(state.messages[1].1, "hello");
}
#[tokio::test]
async fn drain_tool_approval_sets_pending() {
let (tx, mut rx) = tokio::sync::mpsc::channel(8);
let mut state = AppState::new();
tx.send(UIEvent::ToolApprovalRequest {
tool_use_id: "t1".to_string(),
tool_name: "write_file".to_string(),
input_summary: "path: foo.txt".to_string(),
})
.await
.unwrap();
drop(tx);
drain_ui_events(&mut rx, &mut state);
assert!(state.pending_approval.is_some());
let approval = state.pending_approval.unwrap();
assert_eq!(approval.tool_name, "write_file");
}
#[tokio::test]
async fn drain_tool_result_adds_message() {
let (tx, mut rx) = tokio::sync::mpsc::channel(8);
let mut state = AppState::new();
tx.send(UIEvent::ToolResult {
tool_name: "read_file".to_string(),
output_summary: "file contents...".to_string(),
is_error: false,
})
.await
.unwrap();
drop(tx);
drain_ui_events(&mut rx, &mut state);
assert_eq!(state.messages.len(), 1);
assert!(state.messages[0].1.contains("read_file result"));
}
}

View file

@ -13,6 +13,8 @@ pub(super) enum LoopControl {
Quit,
/// The user ran `:clear`; wipe the conversation.
ClearHistory,
/// The user responded to a tool approval prompt.
ToolApproval { tool_use_id: String, approved: bool },
}
/// Map a key event to a [`LoopControl`] signal, mutating `state` as a side-effect.
@ -23,6 +25,29 @@ pub(super) fn handle_key(key: Option<KeyEvent>, state: &mut AppState) -> Option<
let key = key?;
// Clear any transient status error on the next keypress.
state.status_error = None;
// If a tool approval is pending, intercept y/n before normal key handling.
if let Some(approval) = &state.pending_approval {
let tool_use_id = approval.tool_use_id.clone();
match key.code {
KeyCode::Char('y') | KeyCode::Char('Y') => {
state.pending_approval = None;
return Some(LoopControl::ToolApproval {
tool_use_id,
approved: true,
});
}
KeyCode::Char('n') | KeyCode::Char('N') => {
state.pending_approval = None;
return Some(LoopControl::ToolApproval {
tool_use_id,
approved: false,
});
}
_ => return None, // ignore other keys while approval pending
}
}
// Ctrl+C quits from any mode.
if key.modifiers.contains(KeyModifiers::CONTROL) && key.code == KeyCode::Char('c') {
return Some(LoopControl::Quit);

View file

@ -72,6 +72,8 @@ pub struct AppState {
pub viewport_height: u16,
/// Transient error message shown in the status bar, cleared on next keypress.
pub status_error: Option<String>,
/// A tool approval request waiting for user input (y/n).
pub pending_approval: Option<events::PendingApproval>,
}
impl AppState {
@ -85,6 +87,7 @@ impl AppState {
pending_keys: Vec::new(),
viewport_height: 0,
status_error: None,
pending_approval: None,
}
}
}
@ -185,6 +188,17 @@ pub async fn run(
Some(input::LoopControl::ClearHistory) => {
let _ = action_tx.send(UserAction::ClearHistory).await;
}
Some(input::LoopControl::ToolApproval {
tool_use_id,
approved,
}) => {
let _ = action_tx
.send(UserAction::ToolApprovalResponse {
tool_use_id,
approved,
})
.await;
}
None => {}
}
}

View file

@ -105,11 +105,37 @@ pub(super) fn render(frame: &mut Frame, state: &AppState) {
let output = Paragraph::new(lines)
.wrap(Wrap { trim: false })
.scroll((state.scroll, 0));
frame.render_widget(output, chunks[0]);
let output_area = chunks[0];
frame.render_widget(output, output_area);
// --- Tool approval overlay ---
if let Some(ref approval) = state.pending_approval {
let overlay_w = (output_area.width / 2).max(60).min(output_area.width);
let overlay_h: u16 = 5;
let overlay_x = output_area.x + (output_area.width.saturating_sub(overlay_w)) / 2;
let overlay_y = output_area.y + output_area.height.saturating_sub(overlay_h) / 2;
let overlay_area = Rect {
x: overlay_x,
y: overlay_y,
width: overlay_w,
height: overlay_h.min(output_area.height),
};
frame.render_widget(Clear, overlay_area);
let text = format!(
"{}: {}\n\ny = approve, n = deny",
approval.tool_name, approval.input_summary
);
let overlay = Paragraph::new(text).block(
Block::bordered()
.border_style(Style::default().fg(Color::Yellow))
.title("Tool Approval"),
);
frame.render_widget(overlay, overlay_area);
}
// --- Command overlay (floating box centered on output pane) ---
if state.mode == Mode::Command {
let overlay_area = command_overlay_rect(chunks[0]);
let overlay_area = command_overlay_rect(output_area);
// Clear the area behind the overlay so it appears floating.
frame.render_widget(Clear, overlay_area);
let overlay = Paragraph::new(format!(":{}", state.command_buffer)).block(
@ -146,7 +172,7 @@ pub(super) fn render(frame: &mut Frame, state: &AppState) {
}
Mode::Command => {
// Cursor in the floating overlay
let overlay = command_overlay_rect(chunks[0]);
let overlay = command_overlay_rect(output_area);
// border(1) + ":" (1) + buf len
let cursor_x = overlay.x + 1 + 1 + state.command_buffer.len() as u16;
let cursor_y = overlay.y + 1; // inside the border
@ -386,4 +412,31 @@ mod tests {
"expected error in status bar"
);
}
#[test]
fn render_approval_overlay_visible() {
let backend = TestBackend::new(80, 24);
let mut terminal = Terminal::new(backend).unwrap();
let mut state = AppState::new();
state.pending_approval = Some(super::super::events::PendingApproval {
tool_use_id: "t1".to_string(),
tool_name: "write_file".to_string(),
input_summary: "path: foo.txt".to_string(),
});
terminal.draw(|frame| render(frame, &state)).unwrap();
let buf = terminal.backend().buffer().clone();
let all_text: String = buf
.content()
.iter()
.map(|c| c.symbol().to_string())
.collect();
assert!(
all_text.contains("Tool Approval"),
"expected 'Tool Approval' overlay"
);
assert!(
all_text.contains("write_file"),
"expected tool name in overlay"
);
}
}