Add tool use to the orchestrator (#4)
Add tool use without sandboxing. Currently available tools are list dir, read file, write file and exec bash. Reviewed-on: #4 Co-authored-by: Drew Galbraith <drew@tiramisu.one> Co-committed-by: Drew Galbraith <drew@tiramisu.one>
This commit is contained in:
parent
6b85ff3cb8
commit
797d7564b7
20 changed files with 1822 additions and 129 deletions
|
|
@ -1,10 +1,53 @@
|
|||
use futures::StreamExt;
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::debug;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
use crate::core::history::ConversationHistory;
|
||||
use crate::core::types::{ConversationMessage, Role, StreamEvent, UIEvent, UserAction};
|
||||
use crate::core::types::{
|
||||
ContentBlock, ConversationMessage, Role, StreamEvent, ToolDefinition, UIEvent, UserAction,
|
||||
};
|
||||
use crate::provider::ModelProvider;
|
||||
use crate::tools::{RiskLevel, ToolOutput, ToolRegistry};
|
||||
|
||||
/// Accumulates data for a single tool-use block while it is being streamed.
|
||||
///
|
||||
/// Created on `ToolUseStart`, populated by `ToolUseInputDelta` fragments, and
|
||||
/// consumed on `ToolUseDone` via `TryFrom<ActiveToolUse> for ContentBlock`.
|
||||
struct ActiveToolUse {
    /// Identifier of the tool-use block, taken from `ToolUseStart`.
    id: String,
    /// Name of the requested tool, taken from `ToolUseStart`.
    name: String,
    /// Raw input JSON text, concatenated from the fragments seen so far.
    json_buf: String,
}

impl ActiveToolUse {
    /// Begin a fresh accumulator for the tool-use block `id` / `name`; the
    /// input buffer starts out empty.
    fn new(id: String, name: String) -> Self {
        let json_buf = String::new();
        Self { id, name, json_buf }
    }

    /// Extend the buffered input JSON with one more streamed fragment.
    fn append(&mut self, chunk: &str) {
        self.json_buf += chunk;
    }
}
|
||||
|
||||
impl TryFrom<ActiveToolUse> for ContentBlock {
|
||||
type Error = serde_json::Error;
|
||||
|
||||
/// Parse the accumulated JSON buffer and produce a `ContentBlock::ToolUse`.
|
||||
fn try_from(t: ActiveToolUse) -> Result<Self, Self::Error> {
|
||||
let input = serde_json::from_str(&t.json_buf)?;
|
||||
Ok(ContentBlock::ToolUse {
|
||||
id: t.id,
|
||||
name: t.name,
|
||||
input,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Drives the conversation loop between the TUI frontend and the model provider.
|
||||
///
|
||||
|
|
@ -38,28 +81,321 @@ use crate::provider::ModelProvider;
|
|||
/// OutputTokens -> log at debug level
|
||||
/// 3. Quit -> return
|
||||
/// ```
|
||||
/// The result of consuming one provider stream (one assistant turn).
enum StreamResult {
    /// The stream reached its `Done` event; holds the text and tool-use
    /// content blocks collected along the way, in arrival order.
    Done(Vec<ContentBlock>),
    /// The turn failed; holds a human-readable message (a provider error, a
    /// tool-input JSON parse failure, or a stream that ended without `Done`).
    Error(String),
}
|
||||
|
||||
/// Maximum number of tool-use loop iterations per user message.
///
/// One iteration is one provider stream plus the execution of any tool calls
/// it requested. When the limit is reached the turn is abandoned and an error
/// event is sent to the UI.
const MAX_TOOL_ITERATIONS: usize = 25;
|
||||
|
||||
/// Shorten `s` to at most `max_len` Unicode scalar values.
///
/// When `s` is longer than `max_len` characters, the first `max_len`
/// characters are kept and `"..."` is appended; otherwise `s` is returned
/// unchanged. The cut position comes from `char_indices`, so it always falls
/// on a character boundary and slicing can never panic in the middle of a
/// multibyte character.
fn truncate(s: &str, max_len: usize) -> String {
    // The byte offset of character number `max_len` (zero-based) is exactly
    // where the kept prefix ends. No such character means nothing to cut.
    let Some((cut, _)) = s.char_indices().nth(max_len) else {
        return s.to_string();
    };

    let mut shortened = String::with_capacity(cut + 3);
    shortened.push_str(&s[..cut]);
    shortened.push_str("...");
    shortened
}
|
||||
|
||||
/// Drives the conversation loop between the TUI frontend and the model
/// provider, executing tools requested by the model along the way.
pub struct Orchestrator<P> {
    /// Conversation so far: user messages, assistant turns and tool results.
    history: ConversationHistory,
    /// Model provider that is streamed from once per tool-loop iteration.
    provider: P,
    /// Registry used to list tool definitions and to look tools up by name.
    tool_registry: ToolRegistry,
    /// Directory handed to every tool when it executes.
    working_dir: std::path::PathBuf,
    /// Incoming user actions (messages, approval responses, quit, ...).
    action_rx: mpsc::Receiver<UserAction>,
    /// Outgoing events for the UI (stream deltas, tool status, errors, ...).
    event_tx: mpsc::Sender<UIEvent>,
    /// Messages typed by the user while an approval prompt is open. They are
    /// queued here and replayed as new turns once the current turn completes.
    queued_messages: Vec<String>,
}
|
||||
|
||||
impl<P: ModelProvider> Orchestrator<P> {
|
||||
/// Construct an orchestrator using the given provider and channel endpoints.
|
||||
pub fn new(
|
||||
provider: P,
|
||||
tool_registry: ToolRegistry,
|
||||
working_dir: std::path::PathBuf,
|
||||
action_rx: mpsc::Receiver<UserAction>,
|
||||
event_tx: mpsc::Sender<UIEvent>,
|
||||
) -> Self {
|
||||
Self {
|
||||
history: ConversationHistory::new(),
|
||||
provider,
|
||||
tool_registry,
|
||||
working_dir,
|
||||
action_rx,
|
||||
event_tx,
|
||||
queued_messages: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Consume the provider's stream for one turn, returning the content blocks
|
||||
/// produced by the assistant (text and/or tool-use) or an error.
|
||||
///
|
||||
/// Tool-use input JSON is accumulated from `ToolUseInputDelta` fragments and
|
||||
/// parsed into `ContentBlock::ToolUse` on `ToolUseDone`.
|
||||
async fn consume_stream(
|
||||
&self,
|
||||
messages: &[ConversationMessage],
|
||||
tools: &[ToolDefinition],
|
||||
) -> StreamResult {
|
||||
let mut blocks: Vec<ContentBlock> = Vec::new();
|
||||
let mut text_buf = String::new();
|
||||
let mut active_tool: Option<ActiveToolUse> = None;
|
||||
|
||||
let mut stream = Box::pin(self.provider.stream(messages, tools));
|
||||
|
||||
while let Some(event) = stream.next().await {
|
||||
match event {
|
||||
StreamEvent::TextDelta(chunk) => {
|
||||
text_buf.push_str(&chunk);
|
||||
let _ = self.event_tx.send(UIEvent::StreamDelta(chunk)).await;
|
||||
}
|
||||
StreamEvent::ToolUseStart { id, name } => {
|
||||
// Flush any accumulated text before starting tool block.
|
||||
if !text_buf.is_empty() {
|
||||
blocks.push(ContentBlock::Text {
|
||||
text: std::mem::take(&mut text_buf),
|
||||
});
|
||||
}
|
||||
active_tool = Some(ActiveToolUse::new(id, name));
|
||||
}
|
||||
StreamEvent::ToolUseInputDelta(chunk) => {
|
||||
if let Some(t) = &mut active_tool {
|
||||
t.append(&chunk);
|
||||
} else {
|
||||
warn!(
|
||||
"received ToolUseInputDelta outside of an active tool-use block -- ignoring"
|
||||
);
|
||||
}
|
||||
}
|
||||
StreamEvent::ToolUseDone => {
|
||||
if let Some(t) = active_tool.take() {
|
||||
match ContentBlock::try_from(t) {
|
||||
Ok(block) => blocks.push(block),
|
||||
Err(e) => {
|
||||
return StreamResult::Error(format!(
|
||||
"failed to parse tool input JSON: {e}"
|
||||
));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
warn!(
|
||||
"received ToolUseDone outside of an active tool-use block -- ignoring"
|
||||
);
|
||||
}
|
||||
}
|
||||
StreamEvent::Done => {
|
||||
// Flush trailing text.
|
||||
if !text_buf.is_empty() {
|
||||
blocks.push(ContentBlock::Text { text: text_buf });
|
||||
}
|
||||
return StreamResult::Done(blocks);
|
||||
}
|
||||
StreamEvent::Error(msg) => {
|
||||
return StreamResult::Error(msg);
|
||||
}
|
||||
StreamEvent::InputTokens(n) => {
|
||||
debug!(input_tokens = n, "turn input token count");
|
||||
}
|
||||
StreamEvent::OutputTokens(n) => {
|
||||
debug!(output_tokens = n, "turn output token count");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stream ended without Done -- treat as error.
|
||||
StreamResult::Error("stream ended unexpectedly".to_string())
|
||||
}
|
||||
|
||||
/// Execute one complete turn: stream from provider, execute tools if needed,
|
||||
/// loop back until the model produces a text-only response or we hit the
|
||||
/// iteration limit.
|
||||
async fn run_turn(&mut self) {
|
||||
let tool_defs = self.tool_registry.definitions();
|
||||
|
||||
for _ in 0..MAX_TOOL_ITERATIONS {
|
||||
let messages = self.history.messages().to_vec();
|
||||
let result = self.consume_stream(&messages, &tool_defs).await;
|
||||
|
||||
match result {
|
||||
StreamResult::Error(msg) => {
|
||||
let _ = self.event_tx.send(UIEvent::Error(msg)).await;
|
||||
return;
|
||||
}
|
||||
StreamResult::Done(blocks) => {
|
||||
let has_tool_use = blocks
|
||||
.iter()
|
||||
.any(|b| matches!(b, ContentBlock::ToolUse { .. }));
|
||||
|
||||
// Append the assistant message to history.
|
||||
self.history.push(ConversationMessage {
|
||||
role: Role::Assistant,
|
||||
content: blocks.clone(),
|
||||
});
|
||||
|
||||
if !has_tool_use {
|
||||
let _ = self.event_tx.send(UIEvent::TurnComplete).await;
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO: execute tool-use blocks in parallel -- requires
|
||||
// refactoring Orchestrator to use &self + interior
|
||||
// mutability (Arc<Mutex<...>> around event_tx, action_rx,
|
||||
// history) so that multiple futures can borrow self
|
||||
// simultaneously via futures::future::join_all.
|
||||
// Execute each tool-use block and collect results.
|
||||
let mut tool_results: Vec<ContentBlock> = Vec::new();
|
||||
|
||||
for block in &blocks {
|
||||
if let ContentBlock::ToolUse { id, name, input } = block {
|
||||
let result = self.execute_tool_with_approval(id, name, input).await;
|
||||
tool_results.push(ContentBlock::ToolResult {
|
||||
tool_use_id: id.clone(),
|
||||
content: result.content,
|
||||
is_error: result.is_error,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Append tool results as a user message and loop.
|
||||
self.history.push(ConversationMessage {
|
||||
role: Role::User,
|
||||
content: tool_results,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
warn!("tool-use loop reached max iterations ({MAX_TOOL_ITERATIONS})");
|
||||
let _ = self
|
||||
.event_tx
|
||||
.send(UIEvent::Error(
|
||||
"tool-use loop reached maximum iterations".to_string(),
|
||||
))
|
||||
.await;
|
||||
}
|
||||
|
||||
/// Execute a single tool, handling approval if needed.
|
||||
///
|
||||
/// For auto-approve tools, executes immediately. For tools requiring
|
||||
/// approval, sends a request to the TUI and waits for a response.
|
||||
async fn execute_tool_with_approval(
|
||||
&mut self,
|
||||
tool_use_id: &str,
|
||||
tool_name: &str,
|
||||
input: &serde_json::Value,
|
||||
) -> ToolOutput {
|
||||
// Extract tool info upfront to avoid holding a borrow on self across
|
||||
// the mutable wait_for_approval call.
|
||||
let risk = match self.tool_registry.get(tool_name) {
|
||||
Some(t) => t.risk_level(),
|
||||
None => {
|
||||
return ToolOutput {
|
||||
content: format!("unknown tool: {tool_name}"),
|
||||
is_error: true,
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
let input_summary = serde_json::to_string(input).unwrap_or_default();
|
||||
|
||||
// Check approval.
|
||||
let approved = match risk {
|
||||
RiskLevel::AutoApprove => {
|
||||
let _ = self
|
||||
.event_tx
|
||||
.send(UIEvent::ToolExecuting {
|
||||
tool_name: tool_name.to_string(),
|
||||
input_summary: input_summary.clone(),
|
||||
})
|
||||
.await;
|
||||
true
|
||||
}
|
||||
RiskLevel::RequiresApproval => {
|
||||
let _ = self
|
||||
.event_tx
|
||||
.send(UIEvent::ToolApprovalRequest {
|
||||
tool_use_id: tool_use_id.to_string(),
|
||||
tool_name: tool_name.to_string(),
|
||||
input_summary: input_summary.clone(),
|
||||
})
|
||||
.await;
|
||||
|
||||
// Wait for approval response from TUI.
|
||||
self.wait_for_approval(tool_use_id).await
|
||||
}
|
||||
};
|
||||
|
||||
if !approved {
|
||||
return ToolOutput {
|
||||
content: "tool execution denied by user".to_string(),
|
||||
is_error: true,
|
||||
};
|
||||
}
|
||||
|
||||
// Re-fetch tool for execution (borrow was released above).
|
||||
let tool = self.tool_registry.get(tool_name).unwrap();
|
||||
match tool.execute(input, &self.working_dir).await {
|
||||
Ok(output) => {
|
||||
let _ = self
|
||||
.event_tx
|
||||
.send(UIEvent::ToolResult {
|
||||
tool_name: tool_name.to_string(),
|
||||
output_summary: truncate(&output.content, 200),
|
||||
is_error: output.is_error,
|
||||
})
|
||||
.await;
|
||||
output
|
||||
}
|
||||
Err(e) => {
|
||||
let msg = e.to_string();
|
||||
let _ = self
|
||||
.event_tx
|
||||
.send(UIEvent::ToolResult {
|
||||
tool_name: tool_name.to_string(),
|
||||
output_summary: msg.clone(),
|
||||
is_error: true,
|
||||
})
|
||||
.await;
|
||||
ToolOutput {
|
||||
content: msg,
|
||||
is_error: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Wait for a `ToolApprovalResponse` matching `tool_use_id` from the TUI.
|
||||
///
|
||||
/// Any `SendMessage` actions that arrive during the wait are pushed onto
|
||||
/// `self.queued_messages` rather than discarded. They are replayed as new
|
||||
/// turns once the current tool-use loop completes (see `run()`).
|
||||
///
|
||||
/// Returns `true` if approved, `false` if denied or the channel closes.
|
||||
async fn wait_for_approval(&mut self, tool_use_id: &str) -> bool {
|
||||
while let Some(action) = self.action_rx.recv().await {
|
||||
match action {
|
||||
UserAction::ToolApprovalResponse {
|
||||
tool_use_id: id,
|
||||
approved,
|
||||
} if id == tool_use_id => {
|
||||
return approved;
|
||||
}
|
||||
UserAction::Quit => return false,
|
||||
UserAction::SendMessage(text) => {
|
||||
self.queued_messages.push(text);
|
||||
}
|
||||
_ => {} // discard stale approvals / ClearHistory during wait
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Run the orchestrator until the user quits or the `action_rx` channel closes.
|
||||
pub async fn run(mut self) {
|
||||
while let Some(action) = self.action_rx.recv().await {
|
||||
|
|
@ -69,62 +405,24 @@ impl<P: ModelProvider> Orchestrator<P> {
|
|||
self.history.clear();
|
||||
}
|
||||
|
||||
// Approval responses are handled inline during tool execution,
|
||||
// not in the main action loop. If one arrives here it's stale.
|
||||
UserAction::ToolApprovalResponse { .. } => {}
|
||||
|
||||
UserAction::SendMessage(text) => {
|
||||
// Push the user message before snapshotting, so providers
|
||||
// see the full conversation including the new message.
|
||||
self.history.push(ConversationMessage {
|
||||
role: Role::User,
|
||||
content: text,
|
||||
});
|
||||
self.history
|
||||
.push(ConversationMessage::text(Role::User, text));
|
||||
self.run_turn().await;
|
||||
|
||||
// Snapshot history into an owned Vec so the stream does not
|
||||
// borrow from `self.history` -- this lets us mutably update
|
||||
// `self.history` once the stream loop finishes.
|
||||
let messages: Vec<ConversationMessage> = self.history.messages().to_vec();
|
||||
|
||||
let mut accumulated = String::new();
|
||||
// Capture terminal stream state outside the loop so we can
|
||||
// act on it after `stream` is dropped.
|
||||
let mut turn_done = false;
|
||||
let mut turn_error: Option<String> = None;
|
||||
|
||||
{
|
||||
let mut stream = Box::pin(self.provider.stream(&messages));
|
||||
|
||||
while let Some(event) = stream.next().await {
|
||||
match event {
|
||||
StreamEvent::TextDelta(chunk) => {
|
||||
accumulated.push_str(&chunk);
|
||||
let _ = self.event_tx.send(UIEvent::StreamDelta(chunk)).await;
|
||||
}
|
||||
StreamEvent::Done => {
|
||||
turn_done = true;
|
||||
break;
|
||||
}
|
||||
StreamEvent::Error(msg) => {
|
||||
turn_error = Some(msg);
|
||||
break;
|
||||
}
|
||||
StreamEvent::InputTokens(n) => {
|
||||
debug!(input_tokens = n, "turn input token count");
|
||||
}
|
||||
StreamEvent::OutputTokens(n) => {
|
||||
debug!(output_tokens = n, "turn output token count");
|
||||
}
|
||||
}
|
||||
// Drain any messages queued while an approval prompt was
|
||||
// open. Each queued message is a full turn in sequence.
|
||||
while !self.queued_messages.is_empty() {
|
||||
let queued = std::mem::take(&mut self.queued_messages);
|
||||
for msg in queued {
|
||||
self.history
|
||||
.push(ConversationMessage::text(Role::User, msg));
|
||||
self.run_turn().await;
|
||||
}
|
||||
// `stream` is dropped here, releasing the borrow on
|
||||
// `self.provider` and `messages`.
|
||||
}
|
||||
|
||||
if turn_done {
|
||||
self.history.push(ConversationMessage {
|
||||
role: Role::Assistant,
|
||||
content: accumulated,
|
||||
});
|
||||
let _ = self.event_tx.send(UIEvent::TurnComplete).await;
|
||||
} else if let Some(msg) = turn_error {
|
||||
let _ = self.event_tx.send(UIEvent::Error(msg)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -135,6 +433,8 @@ impl<P: ModelProvider> Orchestrator<P> {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::core::types::ToolDefinition;
|
||||
use crate::tools::ToolRegistry;
|
||||
use futures::Stream;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
|
|
@ -155,11 +455,27 @@ mod tests {
|
|||
/// Yield a clone of `self.events` as a stream; the conversation messages and
/// tool definitions are ignored.
fn stream<'a>(
    &'a self,
    _messages: &'a [ConversationMessage],
    _tools: &'a [ToolDefinition],
) -> impl Stream<Item = StreamEvent> + Send + 'a {
    futures::stream::iter(self.events.clone())
}
|
||||
}
|
||||
|
||||
/// Create an Orchestrator with no tools for testing text-only flows.
|
||||
fn test_orchestrator<P: ModelProvider>(
|
||||
provider: P,
|
||||
action_rx: mpsc::Receiver<UserAction>,
|
||||
event_tx: mpsc::Sender<UIEvent>,
|
||||
) -> Orchestrator<P> {
|
||||
Orchestrator::new(
|
||||
provider,
|
||||
ToolRegistry::empty(),
|
||||
std::path::PathBuf::from("/tmp"),
|
||||
action_rx,
|
||||
event_tx,
|
||||
)
|
||||
}
|
||||
|
||||
/// Collect all UIEvents that arrive within one orchestrator turn, stopping
|
||||
/// when the channel is drained after a `TurnComplete` or `Error`.
|
||||
async fn collect_events(rx: &mut mpsc::Receiver<UIEvent>) -> Vec<UIEvent> {
|
||||
|
|
@ -195,7 +511,7 @@ mod tests {
|
|||
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
||||
let (event_tx, mut event_rx) = mpsc::channel::<UIEvent>(16);
|
||||
|
||||
let orch = Orchestrator::new(provider, action_rx, event_tx);
|
||||
let orch = test_orchestrator(provider, action_rx, event_tx);
|
||||
let handle = tokio::spawn(orch.run());
|
||||
|
||||
action_tx
|
||||
|
|
@ -233,7 +549,7 @@ mod tests {
|
|||
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
||||
let (event_tx, mut event_rx) = mpsc::channel::<UIEvent>(16);
|
||||
|
||||
let orch = Orchestrator::new(provider, action_rx, event_tx);
|
||||
let orch = test_orchestrator(provider, action_rx, event_tx);
|
||||
let handle = tokio::spawn(orch.run());
|
||||
|
||||
action_tx
|
||||
|
|
@ -264,6 +580,7 @@ mod tests {
|
|||
fn stream<'a>(
|
||||
&'a self,
|
||||
_messages: &'a [ConversationMessage],
|
||||
_tools: &'a [ToolDefinition],
|
||||
) -> impl Stream<Item = StreamEvent> + Send + 'a {
|
||||
panic!("stream() must not be called after Quit");
|
||||
#[allow(unreachable_code)]
|
||||
|
|
@ -274,7 +591,7 @@ mod tests {
|
|||
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
||||
let (event_tx, _event_rx) = mpsc::channel::<UIEvent>(8);
|
||||
|
||||
let orch = Orchestrator::new(NeverCalledProvider, action_rx, event_tx);
|
||||
let orch = test_orchestrator(NeverCalledProvider, action_rx, event_tx);
|
||||
let handle = tokio::spawn(orch.run());
|
||||
|
||||
action_tx.send(UserAction::Quit).await.unwrap();
|
||||
|
|
@ -311,6 +628,7 @@ mod tests {
|
|||
fn stream<'a>(
|
||||
&'a self,
|
||||
_messages: &'a [ConversationMessage],
|
||||
_tools: &'a [ToolDefinition],
|
||||
) -> impl Stream<Item = StreamEvent> + Send + 'a {
|
||||
let events = self.turns.lock().unwrap().pop_front().unwrap_or_default();
|
||||
futures::stream::iter(events)
|
||||
|
|
@ -326,7 +644,7 @@ mod tests {
|
|||
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
||||
let (event_tx, mut event_rx) = mpsc::channel::<UIEvent>(32);
|
||||
|
||||
let orch = Orchestrator::new(provider, action_rx, event_tx);
|
||||
let orch = test_orchestrator(provider, action_rx, event_tx);
|
||||
let handle = tokio::spawn(orch.run());
|
||||
|
||||
// First turn.
|
||||
|
|
@ -350,4 +668,263 @@ mod tests {
|
|||
action_tx.send(UserAction::Quit).await.unwrap();
|
||||
handle.await.unwrap();
|
||||
}
|
||||
|
||||
// -- tool-use accumulation ----------------------------------------------------
|
||||
|
||||
/// When the provider emits tool-use events, the orchestrator executes the
|
||||
/// tool (auto-approve since read_file is AutoApprove), feeds the result back,
|
||||
/// and the provider's second call returns text.
|
||||
#[tokio::test]
|
||||
async fn tool_use_loop_executes_and_feeds_back() {
|
||||
use std::collections::VecDeque;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
struct MultiCallMock {
|
||||
turns: Arc<Mutex<VecDeque<Vec<StreamEvent>>>>,
|
||||
}
|
||||
impl ModelProvider for MultiCallMock {
|
||||
fn stream<'a>(
|
||||
&'a self,
|
||||
_messages: &'a [ConversationMessage],
|
||||
_tools: &'a [ToolDefinition],
|
||||
) -> impl Stream<Item = StreamEvent> + Send + 'a {
|
||||
let events = self.turns.lock().unwrap().pop_front().unwrap_or_default();
|
||||
futures::stream::iter(events)
|
||||
}
|
||||
}
|
||||
|
||||
let turns = Arc::new(Mutex::new(VecDeque::from([
|
||||
// First call: tool use
|
||||
vec![
|
||||
StreamEvent::TextDelta("Let me read that.".to_string()),
|
||||
StreamEvent::ToolUseDone, // spurious stop for text block
|
||||
StreamEvent::ToolUseStart {
|
||||
id: "toolu_01".to_string(),
|
||||
name: "read_file".to_string(),
|
||||
},
|
||||
StreamEvent::ToolUseInputDelta("{\"path\":\"Cargo.toml\"}".to_string()),
|
||||
StreamEvent::ToolUseDone,
|
||||
StreamEvent::Done,
|
||||
],
|
||||
// Second call: text response after tool result
|
||||
vec![
|
||||
StreamEvent::TextDelta("Here's the file.".to_string()),
|
||||
StreamEvent::Done,
|
||||
],
|
||||
])));
|
||||
|
||||
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
||||
let (event_tx, mut event_rx) = mpsc::channel::<UIEvent>(32);
|
||||
|
||||
// Use a real ToolRegistry so read_file works.
|
||||
let dir = tempfile::TempDir::new().unwrap();
|
||||
std::fs::write(dir.path().join("Cargo.toml"), "[package]\nname = \"test\"").unwrap();
|
||||
let orch = Orchestrator::new(
|
||||
MultiCallMock { turns },
|
||||
ToolRegistry::default_tools(),
|
||||
dir.path().to_path_buf(),
|
||||
action_rx,
|
||||
event_tx,
|
||||
);
|
||||
let handle = tokio::spawn(orch.run());
|
||||
|
||||
action_tx
|
||||
.send(UserAction::SendMessage("read Cargo.toml".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
|
||||
|
||||
let events = collect_events(&mut event_rx).await;
|
||||
// Should see text deltas from both calls, tool events, and TurnComplete
|
||||
assert!(
|
||||
events
|
||||
.iter()
|
||||
.any(|e| matches!(e, UIEvent::StreamDelta(s) if s == "Let me read that.")),
|
||||
"expected first text delta"
|
||||
);
|
||||
assert!(
|
||||
events
|
||||
.iter()
|
||||
.any(|e| matches!(e, UIEvent::StreamDelta(s) if s == "Here's the file.")),
|
||||
"expected second text delta"
|
||||
);
|
||||
assert!(
|
||||
matches!(events.last(), Some(UIEvent::TurnComplete)),
|
||||
"expected TurnComplete, got: {events:?}"
|
||||
);
|
||||
|
||||
action_tx.send(UserAction::Quit).await.unwrap();
|
||||
handle.await.unwrap();
|
||||
}
|
||||
|
||||
// -- truncate -----------------------------------------------------------------
|
||||
|
||||
/// Truncating at a boundary inside a multibyte character must not panic.
/// "hello " is 6 bytes; the emoji is 4 bytes. Requesting max_len=8 would
/// previously slice at byte 8, splitting the emoji. The fixed version uses
/// char_indices so the cut falls on a char boundary.
///
/// The limit counts characters, not bytes: the first 8 characters are
/// "hello ", the emoji and one space, so the emoji must survive intact.
#[test]
fn truncate_multibyte_does_not_panic() {
    let s = "hello \u{1F30D} world"; // globe emoji is 4 bytes
    // Should not panic, and must keep exactly the first 8 characters.
    let result = truncate(s, 8);
    assert_eq!(result, "hello \u{1F30D} ...");
    // Exactly at, or above, the character count nothing is cut.
    assert_eq!(truncate(s, s.chars().count()), s);
}
|
||||
|
||||
// -- JSON parse error ---------------------------------------------------------
|
||||
|
||||
/// When the provider emits malformed tool-input JSON, the orchestrator must
/// surface a descriptive UIEvent::Error rather than silently using Null.
#[tokio::test]
async fn malformed_tool_json_surfaces_error() {
    // A complete tool-use block whose accumulated input is not valid JSON.
    let provider = MockProvider::new(vec![
        StreamEvent::ToolUseStart {
            id: "toolu_bad".to_string(),
            name: "read_file".to_string(),
        },
        StreamEvent::ToolUseInputDelta("not valid json{{".to_string()),
        StreamEvent::ToolUseDone,
        StreamEvent::Done,
    ]);

    let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
    let (event_tx, mut event_rx) = mpsc::channel::<UIEvent>(16);

    let orch = test_orchestrator(provider, action_rx, event_tx);
    let handle = tokio::spawn(orch.run());

    action_tx
        .send(UserAction::SendMessage("go".to_string()))
        .await
        .unwrap();
    // Give the spawned orchestrator time to process the message.
    tokio::time::sleep(std::time::Duration::from_millis(10)).await;

    let events = collect_events(&mut event_rx).await;
    // The error text is matched loosely, on the "tool input JSON" fragment.
    assert!(
        events
            .iter()
            .any(|e| matches!(e, UIEvent::Error(msg) if msg.contains("tool input JSON"))),
        "expected JSON parse error event, got: {events:?}"
    );

    action_tx.send(UserAction::Quit).await.unwrap();
    handle.await.unwrap();
}
|
||||
|
||||
// -- queued messages ----------------------------------------------------------
|
||||
|
||||
/// A SendMessage sent while an approval prompt is open must be processed
/// after the current turn completes, not silently dropped.
///
/// This test uses a RequiresApproval tool (write_file) so that the
/// orchestrator blocks in wait_for_approval. While blocked, we send a second
/// message. After the approval request is answered (denied here for
/// simplicity), the queued message should still be processed.
#[tokio::test]
async fn send_message_during_approval_is_queued_and_processed() {
    use std::collections::VecDeque;
    use std::sync::{Arc, Mutex};

    /// Provider that replays one scripted event list per `stream` call.
    struct MultiCallMock {
        turns: Arc<Mutex<VecDeque<Vec<StreamEvent>>>>,
    }
    impl ModelProvider for MultiCallMock {
        fn stream<'a>(
            &'a self,
            _messages: &'a [ConversationMessage],
            _tools: &'a [ToolDefinition],
        ) -> impl Stream<Item = StreamEvent> + Send + 'a {
            let events = self.turns.lock().unwrap().pop_front().unwrap_or_default();
            futures::stream::iter(events)
        }
    }

    let turns = Arc::new(Mutex::new(VecDeque::from([
        // Turn 1: requests write_file (RequiresApproval)
        vec![
            StreamEvent::ToolUseStart {
                id: "toolu_w".to_string(),
                name: "write_file".to_string(),
            },
            StreamEvent::ToolUseInputDelta(
                "{\"path\":\"x.txt\",\"content\":\"hi\"}".to_string(),
            ),
            StreamEvent::ToolUseDone,
            StreamEvent::Done,
        ],
        // Turn 1 second iteration after tool result (denied)
        vec![
            StreamEvent::TextDelta("ok denied".to_string()),
            StreamEvent::Done,
        ],
        // Turn 2: the queued message
        vec![
            StreamEvent::TextDelta("queued reply".to_string()),
            StreamEvent::Done,
        ],
    ])));

    let (action_tx, action_rx) = mpsc::channel::<UserAction>(16);
    let (event_tx, mut event_rx) = mpsc::channel::<UIEvent>(64);

    let dir = tempfile::TempDir::new().unwrap();
    let orch = Orchestrator::new(
        MultiCallMock { turns },
        ToolRegistry::default_tools(),
        dir.path().to_path_buf(),
        action_rx,
        event_tx,
    );
    let handle = tokio::spawn(orch.run());

    // Start turn 1 -- orchestrator will block on approval.
    action_tx
        .send(UserAction::SendMessage("turn1".to_string()))
        .await
        .unwrap();

    // Let the orchestrator reach wait_for_approval.
    tokio::time::sleep(std::time::Duration::from_millis(10)).await;

    // Send a message while blocked -- should be queued.
    action_tx
        .send(UserAction::SendMessage("queued".to_string()))
        .await
        .unwrap();

    // Deny the tool -- unblocks wait_for_approval.
    action_tx
        .send(UserAction::ToolApprovalResponse {
            tool_use_id: "toolu_w".to_string(),
            approved: false,
        })
        .await
        .unwrap();

    // Wait for both turns to complete.
    tokio::time::sleep(std::time::Duration::from_millis(50)).await;

    // Collect everything that is already buffered, without waiting.
    let mut all_events = Vec::new();
    while let Ok(ev) = event_rx.try_recv() {
        all_events.push(ev);
    }

    // The queued message must have produced "queued reply".
    assert!(
        all_events
            .iter()
            .any(|e| matches!(e, UIEvent::StreamDelta(s) if s == "queued reply")),
        "queued message was not processed; events: {all_events:?}"
    );

    action_tx.send(UserAction::Quit).await.unwrap();
    handle.await.unwrap();
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue