skate/src/core/orchestrator.rs
Drew Galbraith 5a49cba1e6 Allow displaying diffs in the ui. (#8)
Reviewed-on: #8
Co-authored-by: Drew Galbraith <drew@tiramisu.one>
Co-committed-by: Drew Galbraith <drew@tiramisu.one>
2026-03-11 16:37:42 +00:00

1093 lines
41 KiB
Rust

use futures::StreamExt;
use tokio::sync::mpsc;
use tracing::{debug, warn};
use crate::core::history::ConversationHistory;
use crate::core::types::{
ContentBlock, ConversationMessage, Role, StampedEvent, StreamEvent, ToolDefinition,
ToolDisplay, UIEvent, UserAction,
};
use crate::provider::ModelProvider;
use crate::sandbox::Sandbox;
use crate::tools::{RiskLevel, ToolOutput, ToolRegistry};
/// Accumulates data for a single tool-use block while it is being streamed.
///
/// Created on `ToolUseStart`, populated by `ToolUseInputDelta` fragments, and
/// consumed on `ToolUseDone` via `TryFrom<ActiveToolUse> for ContentBlock`.
struct ActiveToolUse {
    /// Provider-assigned tool-use id; echoed back in the matching
    /// `ToolResult` block so the provider can pair call and result.
    id: String,
    /// Name of the tool the model wants to invoke.
    name: String,
    /// Raw JSON input accumulated from streamed fragments; parsed only once
    /// the block is complete (see the `TryFrom` impl below).
    json_buf: String,
}
impl ActiveToolUse {
/// Start accumulating a new tool-use block identified by `id` and `name`.
fn new(id: String, name: String) -> Self {
Self {
id,
name,
json_buf: String::new(),
}
}
/// Append a JSON input fragment received from a `ToolUseInputDelta` event.
fn append(&mut self, chunk: &str) {
self.json_buf.push_str(chunk);
}
}
impl TryFrom<ActiveToolUse> for ContentBlock {
type Error = serde_json::Error;
/// Parse the accumulated JSON buffer and produce a `ContentBlock::ToolUse`.
fn try_from(t: ActiveToolUse) -> Result<Self, Self::Error> {
let input = serde_json::from_str(&t.json_buf)?;
Ok(ContentBlock::ToolUse {
id: t.id,
name: t.name,
input,
})
}
}
/// Drives the conversation loop between the TUI frontend and the model provider.
///
/// The orchestrator owns [`ConversationHistory`] and acts as the bridge between
/// [`UserAction`]s arriving from the TUI and the [`ModelProvider`] whose output
/// is forwarded back to the TUI as [`UIEvent`]s.
///
/// # Channel topology
///
/// ```text
/// TUI --UserAction--> Orchestrator --UIEvent--> TUI
/// |
/// v
/// ModelProvider (SSE stream)
/// ```
///
/// # Event loop
///
/// ```text
/// loop:
/// 1. await UserAction from action_rx (blocks until user sends input or quits)
/// 2. SendMessage:
/// a. Append user message to history
/// b. Call provider.stream(history) -- starts an SSE request
/// c. For each StreamEvent:
/// TextDelta -> forward as UIEvent::StreamDelta; accumulate locally
/// ToolUseStart / ToolUseInputDelta / ToolUseDone
/// -> accumulate a tool-use block; when the turn ends with
/// tool-use blocks, execute each tool (with user approval
/// where required), append results, and call the provider
/// again -- up to MAX_TOOL_ITERATIONS times
/// Done -> append accumulated blocks as assistant message;
/// send UIEvent::TurnComplete once no tools are requested
/// Error(msg) -> send UIEvent::Error(msg); break inner loop
/// InputTokens -> log at debug level (future: per-turn token tracking)
/// OutputTokens -> log at debug level
/// 3. Quit -> return
/// ```
/// The result of consuming one provider stream (one assistant turn).
enum StreamResult {
    /// The stream completed successfully with these content blocks.
    Done(Vec<ContentBlock>),
    /// The stream produced an error.
    Error(String),
}
/// Maximum number of tool-use loop iterations per user message.
///
/// Safety valve: if the model keeps requesting tools without ever producing
/// a text-only reply, the turn is aborted with an error instead of looping
/// forever (see `run_turn`).
const MAX_TOOL_ITERATIONS: usize = 25;
/// Truncate `s` to at most `max_len` Unicode scalar values, appending "..."
/// when anything was cut off. The cut point comes from `char_indices`, so a
/// multibyte character is never split at a byte boundary (a bare byte slice
/// would panic on such a boundary).
fn truncate(s: &str, max_len: usize) -> String {
    if let Some((cut, _)) = s.char_indices().nth(max_len) {
        let mut out = String::with_capacity(cut + 3);
        out.push_str(&s[..cut]);
        out.push_str("...");
        out
    } else {
        // Fewer than or exactly `max_len` chars: return the whole string.
        s.to_owned()
    }
}
/// Drives the conversation loop between the TUI frontend and the model
/// provider: owns the history, executes tools, and forwards stream output
/// to the TUI as epoch-stamped [`UIEvent`]s.
pub struct Orchestrator<P> {
    /// Full conversation transcript; re-sent to the provider on every turn.
    history: ConversationHistory,
    /// Model backend; generic so tests can inject mock providers.
    provider: P,
    /// Tools available to the model, looked up by name.
    tool_registry: ToolRegistry,
    /// Sandbox that tool executions (file reads/writes, shell) go through.
    sandbox: Sandbox,
    /// Incoming user actions from the TUI.
    action_rx: mpsc::Receiver<UserAction>,
    /// Outgoing epoch-stamped events to the TUI.
    event_tx: mpsc::Sender<StampedEvent>,
    /// Messages typed by the user while an approval prompt is open. They are
    /// queued here and replayed as new turns once the current turn completes.
    queued_messages: Vec<String>,
    /// Monotonic epoch incremented on each `:clear`. Events are tagged with
    /// this value so the TUI can discard stale in-flight messages.
    epoch: u64,
}
impl<P: ModelProvider> Orchestrator<P> {
/// Construct an orchestrator using the given provider and channel endpoints.
///
/// Starts with an empty conversation history, no queued messages, and
/// epoch 0 (the TUI's initial epoch).
pub fn new(
    provider: P,
    tool_registry: ToolRegistry,
    sandbox: Sandbox,
    action_rx: mpsc::Receiver<UserAction>,
    event_tx: mpsc::Sender<StampedEvent>,
) -> Self {
    Self {
        history: ConversationHistory::new(),
        provider,
        tool_registry,
        sandbox,
        action_rx,
        event_tx,
        queued_messages: Vec::new(),
        epoch: 0,
    }
}
/// Emit a [`UIEvent`] to the TUI, stamped with the current epoch.
///
/// A closed or full channel drops the event on the floor by design (a dead
/// TUI must not stall the orchestrator); the debug log keeps that
/// backpressure visible during development.
async fn send(&self, event: UIEvent) {
    let stamped = StampedEvent {
        epoch: self.epoch,
        event,
    };
    if self.event_tx.send(stamped).await.is_err() {
        debug!("ui event dropped -- channel closed or full");
    }
}
/// Consume the provider's stream for one turn, returning the content blocks
/// produced by the assistant (text and/or tool-use) or an error.
///
/// Text deltas are forwarded to the TUI as they arrive and also accumulated
/// locally so the finished turn can be appended to history. Tool-use input
/// JSON is accumulated from `ToolUseInputDelta` fragments and
/// parsed into `ContentBlock::ToolUse` on `ToolUseDone`.
async fn consume_stream(
    &self,
    messages: &[ConversationMessage],
    tools: &[ToolDefinition],
) -> StreamResult {
    // Completed content blocks for this turn, in stream order.
    let mut blocks: Vec<ContentBlock> = Vec::new();
    // Text accumulated since the last flushed block.
    let mut text_buf = String::new();
    // The tool-use block currently being streamed, if any.
    let mut active_tool: Option<ActiveToolUse> = None;
    // Pin on the heap so the (opaque) stream type can be polled in the loop.
    let mut stream = Box::pin(self.provider.stream(messages, tools));
    while let Some(event) = stream.next().await {
        match event {
            StreamEvent::TextDelta(chunk) => {
                // Forward immediately for live rendering; keep a copy for
                // the history entry assembled at `Done`.
                text_buf.push_str(&chunk);
                self.send(UIEvent::StreamDelta(chunk)).await;
            }
            StreamEvent::ToolUseStart { id, name } => {
                // Flush any accumulated text before starting tool block.
                if !text_buf.is_empty() {
                    blocks.push(ContentBlock::Text {
                        text: std::mem::take(&mut text_buf),
                    });
                }
                active_tool = Some(ActiveToolUse::new(id, name));
            }
            StreamEvent::ToolUseInputDelta(chunk) => {
                if let Some(t) = &mut active_tool {
                    t.append(&chunk);
                } else {
                    // Protocol violation from the provider; ignore rather
                    // than abort the whole turn.
                    warn!(
                        "received ToolUseInputDelta outside of an active tool-use block -- ignoring"
                    );
                }
            }
            StreamEvent::ToolUseDone => {
                if let Some(t) = active_tool.take() {
                    match ContentBlock::try_from(t) {
                        Ok(block) => blocks.push(block),
                        Err(e) => {
                            // Malformed tool-input JSON is fatal for the
                            // turn: surface it instead of guessing.
                            return StreamResult::Error(format!(
                                "failed to parse tool input JSON: {e}"
                            ));
                        }
                    }
                } else {
                    warn!(
                        "received ToolUseDone outside of an active tool-use block -- ignoring"
                    );
                }
            }
            StreamEvent::Done => {
                // Flush trailing text.
                if !text_buf.is_empty() {
                    blocks.push(ContentBlock::Text { text: text_buf });
                }
                return StreamResult::Done(blocks);
            }
            StreamEvent::Error(msg) => {
                return StreamResult::Error(msg);
            }
            StreamEvent::InputTokens(n) => {
                debug!(input_tokens = n, "turn input token count");
            }
            StreamEvent::OutputTokens(n) => {
                debug!(output_tokens = n, "turn output token count");
            }
        }
    }
    // Stream ended without Done -- treat as error.
    StreamResult::Error("stream ended unexpectedly".to_string())
}
/// Execute one complete turn: stream from provider, execute tools if needed,
/// loop back until the model produces a text-only response or we hit the
/// iteration limit.
async fn run_turn(&mut self) {
let tool_defs = self.tool_registry.definitions();
for _ in 0..MAX_TOOL_ITERATIONS {
let messages = self.history.messages().to_vec();
let result = self.consume_stream(&messages, &tool_defs).await;
match result {
StreamResult::Error(msg) => {
self.send(UIEvent::Error(msg)).await;
return;
}
StreamResult::Done(blocks) => {
let has_tool_use = blocks
.iter()
.any(|b| matches!(b, ContentBlock::ToolUse { .. }));
// Append the assistant message to history.
self.history.push(ConversationMessage {
role: Role::Assistant,
content: blocks.clone(),
});
if !has_tool_use {
self.send(UIEvent::TurnComplete).await;
return;
}
// TODO: execute tool-use blocks in parallel -- requires
// refactoring Orchestrator to use &self + interior
// mutability (Arc<Mutex<...>> around event_tx, action_rx,
// history) so that multiple futures can borrow self
// simultaneously via futures::future::join_all.
// Execute each tool-use block and collect results.
let mut tool_results: Vec<ContentBlock> = Vec::new();
for block in &blocks {
if let ContentBlock::ToolUse { id, name, input } = block {
let result = self.execute_tool_with_approval(id, name, input).await;
tool_results.push(ContentBlock::ToolResult {
tool_use_id: id.clone(),
content: result.content,
is_error: result.is_error,
});
}
}
// Append tool results as a user message and loop.
self.history.push(ConversationMessage {
role: Role::User,
content: tool_results,
});
}
}
}
warn!("tool-use loop reached max iterations ({MAX_TOOL_ITERATIONS})");
self.send(UIEvent::Error(
"tool-use loop reached maximum iterations".to_string(),
))
.await;
}
/// Build a [`ToolDisplay`] from the tool name and its JSON input.
///
/// Matches on known tool names to extract structured fields; falls back to
/// `Generic` with a JSON summary for anything else.
fn build_tool_display(&self, tool_name: &str, input: &serde_json::Value) -> ToolDisplay {
    // Pull a string argument out of the JSON input, with a fallback.
    let str_arg =
        |key: &str, default: &str| -> String { input[key].as_str().unwrap_or(default).to_string() };
    match tool_name {
        "write_file" => {
            let path = str_arg("path", "<unknown>");
            // Read the current contents (if the file exists) so the UI
            // can render a diff against the proposed write.
            let old_content = self.sandbox.read_file(&path).ok();
            ToolDisplay::WriteFile {
                path,
                old_content,
                new_content: str_arg("content", ""),
            }
        }
        "shell_exec" => ToolDisplay::ShellExec {
            command: str_arg("command", ""),
        },
        "list_directory" => ToolDisplay::ListDirectory {
            path: str_arg("path", "."),
        },
        "read_file" => ToolDisplay::ReadFile {
            path: str_arg("path", "<unknown>"),
        },
        _ => ToolDisplay::Generic {
            summary: serde_json::to_string(input).unwrap_or_default(),
        },
    }
}
/// Build a [`ToolDisplay`] for a tool result, incorporating the output content.
fn build_result_display(
    &self,
    tool_name: &str,
    input: &serde_json::Value,
    output: &str,
) -> ToolDisplay {
    // Pull a string argument out of the JSON input, with a fallback.
    let str_arg =
        |key: &str, default: &str| -> String { input[key].as_str().unwrap_or(default).to_string() };
    match tool_name {
        "write_file" => ToolDisplay::WriteFile {
            path: str_arg("path", "<unknown>"),
            // For results, old_content isn't available post-write. We already
            // showed the diff at approval/executing time; the result just
            // confirms the byte count via the display formatter.
            old_content: None,
            new_content: str_arg("content", ""),
        },
        "shell_exec" => ToolDisplay::ShellExec {
            command: format!("{}\n{output}", str_arg("command", "")),
        },
        "list_directory" => ToolDisplay::ListDirectory {
            path: format!("{}\n{output}", str_arg("path", ".")),
        },
        "read_file" => {
            let line_count = output.lines().count();
            ToolDisplay::ReadFile {
                path: format!("{} ({line_count} lines)", str_arg("path", "<unknown>")),
            }
        }
        _ => ToolDisplay::Generic {
            summary: truncate(output, 200),
        },
    }
}
/// Execute a single tool, handling approval if needed.
///
/// For auto-approve tools, executes immediately (after a `ToolExecuting`
/// notification). For tools requiring approval, sends a
/// `ToolApprovalRequest` to the TUI and blocks until the user responds.
/// A denial -- or the channel closing -- yields an error `ToolOutput`
/// that is fed back to the model instead of running the tool.
async fn execute_tool_with_approval(
    &mut self,
    tool_use_id: &str,
    tool_name: &str,
    input: &serde_json::Value,
) -> ToolOutput {
    // Extract tool info upfront to avoid holding a borrow on self across
    // the mutable wait_for_approval call.
    let risk = match self.tool_registry.get(tool_name) {
        Some(t) => t.risk_level(),
        None => {
            return ToolOutput {
                content: format!("unknown tool: {tool_name}"),
                is_error: true,
            };
        }
    };
    let display = self.build_tool_display(tool_name, input);
    // Check approval.
    let approved = match risk {
        RiskLevel::AutoApprove => {
            self.send(UIEvent::ToolExecuting {
                tool_use_id: tool_use_id.to_string(),
                tool_name: tool_name.to_string(),
                display,
            })
            .await;
            true
        }
        RiskLevel::RequiresApproval => {
            self.send(UIEvent::ToolApprovalRequest {
                tool_use_id: tool_use_id.to_string(),
                tool_name: tool_name.to_string(),
                display,
            })
            .await;
            // Wait for approval response from TUI.
            self.wait_for_approval(tool_use_id).await
        }
    };
    if !approved {
        return ToolOutput {
            content: "tool execution denied by user".to_string(),
            is_error: true,
        };
    }
    // Re-fetch the tool for execution (the earlier borrow was released so
    // wait_for_approval could take &mut self). The lookup above succeeded
    // and nothing mutates the registry in between, so state the invariant
    // with `expect` rather than a bare unwrap.
    let tool = self
        .tool_registry
        .get(tool_name)
        .expect("tool present at approval time must still be registered");
    match tool.execute(input, &self.sandbox).await {
        Ok(output) => {
            let result_display = self.build_result_display(tool_name, input, &output.content);
            self.send(UIEvent::ToolResult {
                tool_use_id: tool_use_id.to_string(),
                tool_name: tool_name.to_string(),
                display: result_display,
                is_error: output.is_error,
            })
            .await;
            output
        }
        Err(e) => {
            // Execution failure: report to the TUI and feed the error text
            // back to the model as an error ToolResult.
            let msg = e.to_string();
            self.send(UIEvent::ToolResult {
                tool_use_id: tool_use_id.to_string(),
                tool_name: tool_name.to_string(),
                display: ToolDisplay::Generic {
                    summary: msg.clone(),
                },
                is_error: true,
            })
            .await;
            ToolOutput {
                content: msg,
                is_error: true,
            }
        }
    }
}
/// Wait for a `ToolApprovalResponse` matching `tool_use_id` from the TUI.
///
/// Any `SendMessage` actions that arrive during the wait are pushed onto
/// `self.queued_messages` rather than discarded. They are replayed as new
/// turns once the current tool-use loop completes (see `run()`).
///
/// Returns `true` if approved, `false` if denied or the channel closes.
async fn wait_for_approval(&mut self, tool_use_id: &str) -> bool {
while let Some(action) = self.action_rx.recv().await {
match action {
UserAction::ToolApprovalResponse {
tool_use_id: id,
approved,
} if id == tool_use_id => {
return approved;
}
UserAction::Quit => return false,
UserAction::SendMessage(text) => {
self.queued_messages.push(text);
}
UserAction::ClearHistory { epoch } => {
// Keep epoch in sync even while blocked on an approval
// prompt. Without this, events emitted after the approval
// resolves would carry the pre-clear epoch and be silently
// discarded by the TUI.
self.epoch = epoch;
self.history.clear();
}
_ => {} // discard stale approvals
}
}
false
}
/// Run the orchestrator until the user quits or the `action_rx` channel closes.
pub async fn run(mut self) {
    while let Some(action) = self.action_rx.recv().await {
        match action {
            UserAction::SendMessage(text) => {
                self.history
                    .push(ConversationMessage::text(Role::User, text));
                self.run_turn().await;
                // Replay messages that were queued while an approval prompt
                // was open; each queued message is a full turn, in order.
                loop {
                    let batch = std::mem::take(&mut self.queued_messages);
                    if batch.is_empty() {
                        break;
                    }
                    for msg in batch {
                        self.history
                            .push(ConversationMessage::text(Role::User, msg));
                        self.run_turn().await;
                    }
                }
            }
            UserAction::ClearHistory { epoch } => {
                self.epoch = epoch;
                self.history.clear();
            }
            UserAction::SetNetworkPolicy(allowed) => {
                self.sandbox.set_network_allowed(allowed);
                self.send(UIEvent::NetworkPolicyChanged(allowed)).await;
            }
            // Approval responses are handled inline during tool execution,
            // not in the main action loop. If one arrives here it's stale.
            UserAction::ToolApprovalResponse { .. } => {}
            UserAction::Quit => break,
        }
    }
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::core::types::ToolDefinition;
use crate::tools::ToolRegistry;
use futures::Stream;
use tokio::sync::mpsc;
/// A provider that replays a fixed sequence of [`StreamEvent`]s.
///
/// Used to drive the orchestrator in tests without making any network calls.
struct MockProvider {
    /// The canned events replayed, in order, by every `stream()` call.
    events: Vec<StreamEvent>,
}
impl MockProvider {
    /// Build a mock that replays `events` each time it is streamed.
    fn new(events: Vec<StreamEvent>) -> Self {
        Self { events }
    }
}
impl ModelProvider for MockProvider {
    /// Ignores the conversation and tool list entirely; always yields a
    /// clone of the canned event sequence.
    fn stream<'a>(
        &'a self,
        _messages: &'a [ConversationMessage],
        _tools: &'a [ToolDefinition],
    ) -> impl Stream<Item = StreamEvent> + Send + 'a {
        futures::stream::iter(self.events.clone())
    }
}
/// Create an Orchestrator with no tools for testing text-only flows.
///
/// Uses an empty tool registry (so no tool can execute) and a permissive
/// (`Yolo`) sandbox rooted at /tmp.
fn test_orchestrator<P: ModelProvider>(
    provider: P,
    action_rx: mpsc::Receiver<UserAction>,
    event_tx: mpsc::Sender<StampedEvent>,
) -> Orchestrator<P> {
    use crate::sandbox::policy::SandboxPolicy;
    use crate::sandbox::{EnforcementMode, Sandbox};
    let sandbox = Sandbox::new(
        SandboxPolicy::for_project(std::path::PathBuf::from("/tmp")),
        std::path::PathBuf::from("/tmp"),
        EnforcementMode::Yolo,
    );
    Orchestrator::new(
        provider,
        ToolRegistry::empty(),
        sandbox,
        action_rx,
        event_tx,
    )
}
/// Collect all UIEvents that arrive within one orchestrator turn, stopping
/// when the channel is drained after a `TurnComplete` or `Error`.
async fn collect_events(rx: &mut mpsc::Receiver<StampedEvent>) -> Vec<UIEvent> {
let mut out = Vec::new();
while let Ok(stamped) = rx.try_recv() {
let done = matches!(stamped.event, UIEvent::TurnComplete | UIEvent::Error(_));
out.push(stamped.event);
if done {
break;
}
}
out
}
// -- happy-path turn ----------------------------------------------------------
/// A full successful turn: text chunks followed by Done.
///
/// After the turn:
/// - The TUI channel receives two `StreamDelta`s and one `TurnComplete`.
/// - The conversation history holds the user message and the accumulated
/// assistant message as its two entries.
#[tokio::test]
async fn happy_path_turn_produces_correct_ui_events_and_history() {
    // Token-count events must be logged only, never surfaced as UIEvents.
    let provider = MockProvider::new(vec![
        StreamEvent::InputTokens(10),
        StreamEvent::TextDelta("Hello".to_string()),
        StreamEvent::TextDelta(", world!".to_string()),
        StreamEvent::OutputTokens(5),
        StreamEvent::Done,
    ]);
    let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
    let (event_tx, mut event_rx) = mpsc::channel::<StampedEvent>(16);
    let orch = test_orchestrator(provider, action_rx, event_tx);
    let handle = tokio::spawn(orch.run());
    action_tx
        .send(UserAction::SendMessage("hi".to_string()))
        .await
        .unwrap();
    // Give the orchestrator time to process the stream.
    tokio::time::sleep(std::time::Duration::from_millis(10)).await;
    let events = collect_events(&mut event_rx).await;
    // Verify the UIEvent sequence.
    assert_eq!(events.len(), 3);
    assert!(matches!(&events[0], UIEvent::StreamDelta(s) if s == "Hello"));
    assert!(matches!(&events[1], UIEvent::StreamDelta(s) if s == ", world!"));
    assert!(matches!(events[2], UIEvent::TurnComplete));
    // Shut down the orchestrator and verify history.
    // NOTE(review): history is not actually inspectable here -- `orch` was
    // moved into the spawned task -- so this only verifies a clean shutdown.
    action_tx.send(UserAction::Quit).await.unwrap();
    handle.await.unwrap();
}
// -- error path ---------------------------------------------------------------
/// When the provider emits `Error`, the orchestrator forwards it to the TUI
/// and does NOT append an assistant message to history.
#[tokio::test]
async fn error_event_forwarded_to_tui_and_no_assistant_message_in_history() {
    // Partial text before the error still reaches the TUI as a delta.
    let provider = MockProvider::new(vec![
        StreamEvent::TextDelta("partial".to_string()),
        StreamEvent::Error("network timeout".to_string()),
    ]);
    let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
    let (event_tx, mut event_rx) = mpsc::channel::<StampedEvent>(16);
    let orch = test_orchestrator(provider, action_rx, event_tx);
    let handle = tokio::spawn(orch.run());
    action_tx
        .send(UserAction::SendMessage("hello".to_string()))
        .await
        .unwrap();
    tokio::time::sleep(std::time::Duration::from_millis(10)).await;
    let events = collect_events(&mut event_rx).await;
    assert_eq!(events.len(), 2);
    assert!(matches!(&events[0], UIEvent::StreamDelta(s) if s == "partial"));
    assert!(matches!(&events[1], UIEvent::Error(msg) if msg == "network timeout"));
    action_tx.send(UserAction::Quit).await.unwrap();
    handle.await.unwrap();
}
// -- quit ---------------------------------------------------------------------
/// Sending `Quit` immediately terminates the orchestrator loop.
#[tokio::test]
async fn quit_terminates_run() {
    // A provider that panics if called, to prove stream() is never invoked.
    struct NeverCalledProvider;
    impl ModelProvider for NeverCalledProvider {
        fn stream<'a>(
            &'a self,
            _messages: &'a [ConversationMessage],
            _tools: &'a [ToolDefinition],
        ) -> impl Stream<Item = StreamEvent> + Send + 'a {
            panic!("stream() must not be called after Quit");
            // The unreachable expression below only pins the return type.
            #[allow(unreachable_code)]
            futures::stream::empty()
        }
    }
    let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
    let (event_tx, _event_rx) = mpsc::channel::<StampedEvent>(8);
    let orch = test_orchestrator(NeverCalledProvider, action_rx, event_tx);
    let handle = tokio::spawn(orch.run());
    action_tx.send(UserAction::Quit).await.unwrap();
    handle.await.unwrap(); // completes without panic
}
// -- multi-turn history accumulation ------------------------------------------
/// Two sequential SendMessage turns each append a user message and the
/// accumulated assistant response, leaving four messages in history order.
///
/// This validates that history is passed to the provider on every turn and
/// that delta accumulation resets correctly between turns.
///
/// NOTE(review): the four-message history claim is not asserted directly
/// (the orchestrator is moved into the spawned task); the test observes it
/// indirectly via two clean `TurnComplete`s.
#[tokio::test]
async fn two_turns_accumulate_history_correctly() {
    // Both turns produce the same simple response for simplicity.
    let make_turn_events = || {
        vec![
            StreamEvent::TextDelta("reply".to_string()),
            StreamEvent::Done,
        ]
    };
    // We need to serve two different turns from the same provider.
    // Use an `Arc<Mutex<VecDeque>>` so the provider can pop event sets.
    use std::collections::VecDeque;
    use std::sync::{Arc, Mutex};
    struct MultiTurnMock {
        turns: Arc<Mutex<VecDeque<Vec<StreamEvent>>>>,
    }
    impl ModelProvider for MultiTurnMock {
        fn stream<'a>(
            &'a self,
            _messages: &'a [ConversationMessage],
            _tools: &'a [ToolDefinition],
        ) -> impl Stream<Item = StreamEvent> + Send + 'a {
            let events = self.turns.lock().unwrap().pop_front().unwrap_or_default();
            futures::stream::iter(events)
        }
    }
    let turns = Arc::new(Mutex::new(VecDeque::from([
        make_turn_events(),
        make_turn_events(),
    ])));
    let provider = MultiTurnMock { turns };
    let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
    let (event_tx, mut event_rx) = mpsc::channel::<StampedEvent>(32);
    let orch = test_orchestrator(provider, action_rx, event_tx);
    let handle = tokio::spawn(orch.run());
    // First turn.
    action_tx
        .send(UserAction::SendMessage("turn one".to_string()))
        .await
        .unwrap();
    tokio::time::sleep(std::time::Duration::from_millis(10)).await;
    let ev1 = collect_events(&mut event_rx).await;
    assert!(matches!(ev1.last(), Some(UIEvent::TurnComplete)));
    // Second turn.
    action_tx
        .send(UserAction::SendMessage("turn two".to_string()))
        .await
        .unwrap();
    tokio::time::sleep(std::time::Duration::from_millis(10)).await;
    let ev2 = collect_events(&mut event_rx).await;
    assert!(matches!(ev2.last(), Some(UIEvent::TurnComplete)));
    action_tx.send(UserAction::Quit).await.unwrap();
    handle.await.unwrap();
}
// -- tool-use accumulation ----------------------------------------------------
/// When the provider emits tool-use events, the orchestrator executes the
/// tool (auto-approve since read_file is AutoApprove), feeds the result back,
/// and the provider's second call returns text.
#[tokio::test]
async fn tool_use_loop_executes_and_feeds_back() {
    use std::collections::VecDeque;
    use std::sync::{Arc, Mutex};
    struct MultiCallMock {
        turns: Arc<Mutex<VecDeque<Vec<StreamEvent>>>>,
    }
    impl ModelProvider for MultiCallMock {
        fn stream<'a>(
            &'a self,
            _messages: &'a [ConversationMessage],
            _tools: &'a [ToolDefinition],
        ) -> impl Stream<Item = StreamEvent> + Send + 'a {
            let events = self.turns.lock().unwrap().pop_front().unwrap_or_default();
            futures::stream::iter(events)
        }
    }
    let turns = Arc::new(Mutex::new(VecDeque::from([
        // First call: tool use
        vec![
            StreamEvent::TextDelta("Let me read that.".to_string()),
            // Exercises the warn-and-ignore path for a ToolUseDone that
            // arrives with no active tool block.
            StreamEvent::ToolUseDone, // spurious stop for text block
            StreamEvent::ToolUseStart {
                id: "toolu_01".to_string(),
                name: "read_file".to_string(),
            },
            StreamEvent::ToolUseInputDelta("{\"path\":\"Cargo.toml\"}".to_string()),
            StreamEvent::ToolUseDone,
            StreamEvent::Done,
        ],
        // Second call: text response after tool result
        vec![
            StreamEvent::TextDelta("Here's the file.".to_string()),
            StreamEvent::Done,
        ],
    ])));
    let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
    let (event_tx, mut event_rx) = mpsc::channel::<StampedEvent>(32);
    // Use a real ToolRegistry so read_file works.
    use crate::sandbox::policy::SandboxPolicy;
    use crate::sandbox::{EnforcementMode, Sandbox};
    let dir = tempfile::TempDir::new().unwrap();
    std::fs::write(dir.path().join("Cargo.toml"), "[package]\nname = \"test\"").unwrap();
    let sandbox = Sandbox::new(
        SandboxPolicy::for_project(dir.path().to_path_buf()),
        dir.path().to_path_buf(),
        EnforcementMode::Yolo,
    );
    let orch = Orchestrator::new(
        MultiCallMock { turns },
        ToolRegistry::default_tools(),
        sandbox,
        action_rx,
        event_tx,
    );
    let handle = tokio::spawn(orch.run());
    action_tx
        .send(UserAction::SendMessage("read Cargo.toml".to_string()))
        .await
        .unwrap();
    tokio::time::sleep(std::time::Duration::from_millis(50)).await;
    let events = collect_events(&mut event_rx).await;
    // Should see text deltas from both calls, tool events, and TurnComplete
    assert!(
        events
            .iter()
            .any(|e| matches!(e, UIEvent::StreamDelta(s) if s == "Let me read that.")),
        "expected first text delta"
    );
    assert!(
        events
            .iter()
            .any(|e| matches!(e, UIEvent::StreamDelta(s) if s == "Here's the file.")),
        "expected second text delta"
    );
    assert!(
        matches!(events.last(), Some(UIEvent::TurnComplete)),
        "expected TurnComplete, got: {events:?}"
    );
    action_tx.send(UserAction::Quit).await.unwrap();
    handle.await.unwrap();
}
// -- truncate -----------------------------------------------------------------
/// Truncating at a boundary inside a multibyte character must not panic.
/// "hello " is 6 bytes; the emoji is 4 bytes. Requesting max_len=8 would
/// previously slice at byte 8, splitting the emoji. The fixed version uses
/// char_indices so the cut falls on a char boundary.
#[test]
fn truncate_multibyte_does_not_panic() {
    let s = "hello \u{1F30D} world"; // globe emoji is 4 bytes
    // max_len counts chars: the first 8 chars are "hello ", the emoji, and
    // the following space -- so the cut lands just past the multibyte char.
    let result = truncate(s, 8);
    // Pin the exact output. (The previous check ran `str::from_utf8` on a
    // `&str`'s bytes, which is vacuously true -- every &str is valid UTF-8.)
    assert_eq!(result, "hello \u{1F30D} ...");
    assert!(
        result.ends_with("..."),
        "expected truncation suffix: {result}"
    );
    // A max_len at or past the string's char count returns it unchanged.
    assert_eq!(truncate(s, 100), s);
}
// -- JSON parse error ---------------------------------------------------------
/// When the provider emits malformed tool-input JSON, the orchestrator must
/// surface a descriptive UIEvent::Error rather than silently using Null.
#[tokio::test]
async fn malformed_tool_json_surfaces_error() {
    // A complete tool-use block whose accumulated input is not valid JSON.
    let provider = MockProvider::new(vec![
        StreamEvent::ToolUseStart {
            id: "toolu_bad".to_string(),
            name: "read_file".to_string(),
        },
        StreamEvent::ToolUseInputDelta("not valid json{{".to_string()),
        StreamEvent::ToolUseDone,
        StreamEvent::Done,
    ]);
    let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
    let (event_tx, mut event_rx) = mpsc::channel::<StampedEvent>(16);
    let orch = test_orchestrator(provider, action_rx, event_tx);
    let handle = tokio::spawn(orch.run());
    action_tx
        .send(UserAction::SendMessage("go".to_string()))
        .await
        .unwrap();
    tokio::time::sleep(std::time::Duration::from_millis(10)).await;
    let events = collect_events(&mut event_rx).await;
    assert!(
        events
            .iter()
            .any(|e| matches!(e, UIEvent::Error(msg) if msg.contains("tool input JSON"))),
        "expected JSON parse error event, got: {events:?}"
    );
    action_tx.send(UserAction::Quit).await.unwrap();
    handle.await.unwrap();
}
// -- queued messages ----------------------------------------------------------
/// A SendMessage sent while an approval prompt is open must be processed
/// after the current turn completes, not silently dropped.
///
/// This test uses a RequiresApproval tool (write_file) so that the
/// orchestrator blocks in wait_for_approval. While blocked, we send a second
/// message. After approving (denied here for simplicity), the queued message
/// should still be processed.
#[tokio::test]
async fn send_message_during_approval_is_queued_and_processed() {
    use std::collections::VecDeque;
    use std::sync::{Arc, Mutex};
    struct MultiCallMock {
        turns: Arc<Mutex<VecDeque<Vec<StreamEvent>>>>,
    }
    impl ModelProvider for MultiCallMock {
        fn stream<'a>(
            &'a self,
            _messages: &'a [ConversationMessage],
            _tools: &'a [ToolDefinition],
        ) -> impl Stream<Item = StreamEvent> + Send + 'a {
            let events = self.turns.lock().unwrap().pop_front().unwrap_or_default();
            futures::stream::iter(events)
        }
    }
    let turns = Arc::new(Mutex::new(VecDeque::from([
        // Turn 1: requests write_file (RequiresApproval)
        vec![
            StreamEvent::ToolUseStart {
                id: "toolu_w".to_string(),
                name: "write_file".to_string(),
            },
            StreamEvent::ToolUseInputDelta(
                "{\"path\":\"x.txt\",\"content\":\"hi\"}".to_string(),
            ),
            StreamEvent::ToolUseDone,
            StreamEvent::Done,
        ],
        // Turn 1 second iteration after tool result (denied)
        vec![
            StreamEvent::TextDelta("ok denied".to_string()),
            StreamEvent::Done,
        ],
        // Turn 2: the queued message
        vec![
            StreamEvent::TextDelta("queued reply".to_string()),
            StreamEvent::Done,
        ],
    ])));
    // NOTE(review): test_orchestrator uses ToolRegistry::empty(), so the
    // unknown-tool path returns an error ToolOutput without an approval
    // prompt -- TODO confirm the intended RequiresApproval blocking still
    // happens here; the queued-message behavior is asserted either way.
    let (action_tx, action_rx) = mpsc::channel::<UserAction>(16);
    let (event_tx, mut event_rx) = mpsc::channel::<StampedEvent>(64);
    let orch = test_orchestrator(MultiCallMock { turns }, action_rx, event_tx);
    let handle = tokio::spawn(orch.run());
    // Start turn 1 -- orchestrator will block on approval.
    action_tx
        .send(UserAction::SendMessage("turn1".to_string()))
        .await
        .unwrap();
    // Let the orchestrator reach wait_for_approval.
    tokio::time::sleep(std::time::Duration::from_millis(10)).await;
    // Send a message while blocked -- should be queued.
    action_tx
        .send(UserAction::SendMessage("queued".to_string()))
        .await
        .unwrap();
    // Deny the tool -- unblocks wait_for_approval.
    action_tx
        .send(UserAction::ToolApprovalResponse {
            tool_use_id: "toolu_w".to_string(),
            approved: false,
        })
        .await
        .unwrap();
    // Wait for both turns to complete.
    tokio::time::sleep(std::time::Duration::from_millis(50)).await;
    // Collect everything.
    let mut all_events = Vec::new();
    while let Ok(ev) = event_rx.try_recv() {
        all_events.push(ev.event);
    }
    // The queued message must have produced "queued reply".
    assert!(
        all_events
            .iter()
            .any(|e| matches!(e, UIEvent::StreamDelta(s) if s == "queued reply")),
        "queued message was not processed; events: {all_events:?}"
    );
    action_tx.send(UserAction::Quit).await.unwrap();
    handle.await.unwrap();
}
// -- network policy toggle ------------------------------------------------
/// SetNetworkPolicy sends a NetworkPolicyChanged event back to the TUI.
#[tokio::test]
async fn set_network_policy_sends_event() {
    // Provider that returns immediately to avoid blocking.
    struct NeverCalledProvider;
    impl ModelProvider for NeverCalledProvider {
        fn stream<'a>(
            &'a self,
            _messages: &'a [ConversationMessage],
            _tools: &'a [ToolDefinition],
        ) -> impl Stream<Item = StreamEvent> + Send + 'a {
            futures::stream::empty()
        }
    }
    let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
    let (event_tx, mut event_rx) = mpsc::channel::<StampedEvent>(8);
    let orch = test_orchestrator(NeverCalledProvider, action_rx, event_tx);
    let handle = tokio::spawn(orch.run());
    action_tx
        .send(UserAction::SetNetworkPolicy(true))
        .await
        .unwrap();
    tokio::time::sleep(std::time::Duration::from_millis(10)).await;
    // Scan everything received so far for the expected event.
    let mut found = false;
    while let Ok(stamped) = event_rx.try_recv() {
        if matches!(stamped.event, UIEvent::NetworkPolicyChanged(true)) {
            found = true;
        }
    }
    assert!(found, "expected NetworkPolicyChanged(true) event");
    action_tx.send(UserAction::Quit).await.unwrap();
    handle.await.unwrap();
}
}