https://docs.kernel.org/userspace-api/landlock.html Reviewed-on: #5 Co-authored-by: Drew Galbraith <drew@tiramisu.one> Co-committed-by: Drew Galbraith <drew@tiramisu.one>
987 lines
36 KiB
Rust
987 lines
36 KiB
Rust
use futures::StreamExt;
|
|
use tokio::sync::mpsc;
|
|
use tracing::{debug, warn};
|
|
|
|
use crate::core::history::ConversationHistory;
|
|
use crate::core::types::{
|
|
ContentBlock, ConversationMessage, Role, StreamEvent, ToolDefinition, UIEvent, UserAction,
|
|
};
|
|
use crate::provider::ModelProvider;
|
|
use crate::sandbox::Sandbox;
|
|
use crate::tools::{RiskLevel, ToolOutput, ToolRegistry};
|
|
|
|
/// In-progress state for one tool-use block while it is being streamed.
///
/// Opened by `ToolUseStart`, extended by every `ToolUseInputDelta` fragment,
/// and converted into a `ContentBlock` on `ToolUseDone` through the
/// `TryFrom<ActiveToolUse> for ContentBlock` impl.
struct ActiveToolUse {
    /// Identifier from `ToolUseStart`; echoed back as the `tool_use_id` of the
    /// matching tool result.
    id: String,
    /// Name of the tool the model asked to call.
    name: String,
    /// Raw JSON input text gathered so far; parsed only once the block ends.
    json_buf: String,
}

impl ActiveToolUse {
    /// Open a tool-use block called `name` with identifier `id` and an empty
    /// input buffer.
    fn new(id: String, name: String) -> Self {
        let json_buf = String::new();
        Self { id, name, json_buf }
    }

    /// Add one streamed fragment of the tool's JSON input to the buffer.
    fn append(&mut self, chunk: &str) {
        self.json_buf += chunk;
    }
}
|
|
|
|
impl TryFrom<ActiveToolUse> for ContentBlock {
|
|
type Error = serde_json::Error;
|
|
|
|
/// Parse the accumulated JSON buffer and produce a `ContentBlock::ToolUse`.
|
|
fn try_from(t: ActiveToolUse) -> Result<Self, Self::Error> {
|
|
let input = serde_json::from_str(&t.json_buf)?;
|
|
Ok(ContentBlock::ToolUse {
|
|
id: t.id,
|
|
name: t.name,
|
|
input,
|
|
})
|
|
}
|
|
}
|
|
|
|
/// Drives the conversation loop between the TUI frontend and the model provider.
|
|
///
|
|
/// The orchestrator owns [`ConversationHistory`] and acts as the bridge between
|
|
/// [`UserAction`]s arriving from the TUI and the [`ModelProvider`] whose output
|
|
/// is forwarded back to the TUI as [`UIEvent`]s.
|
|
///
|
|
/// # Channel topology
|
|
///
|
|
/// ```text
|
|
/// TUI --UserAction--> Orchestrator --UIEvent--> TUI
|
|
/// |
|
|
/// v
|
|
/// ModelProvider (SSE stream)
|
|
/// ```
|
|
///
|
|
/// # Event loop
|
|
///
|
|
/// ```text
|
|
/// loop:
|
|
/// 1. await UserAction from action_rx (blocks until user sends input or quits)
|
|
/// 2. SendMessage:
|
|
/// a. Append user message to history
|
|
/// b. Call provider.stream(history) -- starts an SSE request
|
|
/// c. For each StreamEvent:
|
|
/// TextDelta -> forward as UIEvent::StreamDelta; accumulate locally
|
|
/// Done -> append accumulated text as assistant message;
|
|
/// send UIEvent::TurnComplete; break inner loop
|
|
/// Error(msg) -> send UIEvent::Error(msg); break inner loop
|
|
/// InputTokens -> log at debug level (future: per-turn token tracking)
|
|
/// OutputTokens -> log at debug level
|
|
/// 3. Quit -> return
|
|
/// ```
|
|
/// The result of consuming one provider stream (one assistant turn).
|
|
enum StreamResult {
|
|
/// The stream completed successfully with these content blocks.
|
|
Done(Vec<ContentBlock>),
|
|
/// The stream produced an error.
|
|
Error(String),
|
|
}
|
|
|
|
/// Maximum number of tool-use loop iterations per user message.
|
|
const MAX_TOOL_ITERATIONS: usize = 25;
|
|
|
|
/// Truncate a string to `max_len` Unicode scalar values, appending "..." if
|
|
/// truncated. Uses `char_indices` so that multibyte characters are never split
|
|
/// at a byte boundary (which would panic on a bare slice).
|
|
fn truncate(s: &str, max_len: usize) -> String {
|
|
match s.char_indices().nth(max_len) {
|
|
Some((i, _)) => format!("{}...", &s[..i]),
|
|
None => s.to_string(),
|
|
}
|
|
}
|
|
|
|
pub struct Orchestrator<P> {
|
|
history: ConversationHistory,
|
|
provider: P,
|
|
tool_registry: ToolRegistry,
|
|
sandbox: Sandbox,
|
|
action_rx: mpsc::Receiver<UserAction>,
|
|
event_tx: mpsc::Sender<UIEvent>,
|
|
/// Messages typed by the user while an approval prompt is open. They are
|
|
/// queued here and replayed as new turns once the current turn completes.
|
|
queued_messages: Vec<String>,
|
|
}
|
|
|
|
impl<P: ModelProvider> Orchestrator<P> {
|
|
/// Construct an orchestrator using the given provider and channel endpoints.
|
|
pub fn new(
|
|
provider: P,
|
|
tool_registry: ToolRegistry,
|
|
sandbox: Sandbox,
|
|
action_rx: mpsc::Receiver<UserAction>,
|
|
event_tx: mpsc::Sender<UIEvent>,
|
|
) -> Self {
|
|
Self {
|
|
history: ConversationHistory::new(),
|
|
provider,
|
|
tool_registry,
|
|
sandbox,
|
|
action_rx,
|
|
event_tx,
|
|
queued_messages: Vec::new(),
|
|
}
|
|
}
|
|
|
|
/// Consume the provider's stream for one turn, returning the content blocks
|
|
/// produced by the assistant (text and/or tool-use) or an error.
|
|
///
|
|
/// Tool-use input JSON is accumulated from `ToolUseInputDelta` fragments and
|
|
/// parsed into `ContentBlock::ToolUse` on `ToolUseDone`.
|
|
async fn consume_stream(
|
|
&self,
|
|
messages: &[ConversationMessage],
|
|
tools: &[ToolDefinition],
|
|
) -> StreamResult {
|
|
let mut blocks: Vec<ContentBlock> = Vec::new();
|
|
let mut text_buf = String::new();
|
|
let mut active_tool: Option<ActiveToolUse> = None;
|
|
|
|
let mut stream = Box::pin(self.provider.stream(messages, tools));
|
|
|
|
while let Some(event) = stream.next().await {
|
|
match event {
|
|
StreamEvent::TextDelta(chunk) => {
|
|
text_buf.push_str(&chunk);
|
|
let _ = self.event_tx.send(UIEvent::StreamDelta(chunk)).await;
|
|
}
|
|
StreamEvent::ToolUseStart { id, name } => {
|
|
// Flush any accumulated text before starting tool block.
|
|
if !text_buf.is_empty() {
|
|
blocks.push(ContentBlock::Text {
|
|
text: std::mem::take(&mut text_buf),
|
|
});
|
|
}
|
|
active_tool = Some(ActiveToolUse::new(id, name));
|
|
}
|
|
StreamEvent::ToolUseInputDelta(chunk) => {
|
|
if let Some(t) = &mut active_tool {
|
|
t.append(&chunk);
|
|
} else {
|
|
warn!(
|
|
"received ToolUseInputDelta outside of an active tool-use block -- ignoring"
|
|
);
|
|
}
|
|
}
|
|
StreamEvent::ToolUseDone => {
|
|
if let Some(t) = active_tool.take() {
|
|
match ContentBlock::try_from(t) {
|
|
Ok(block) => blocks.push(block),
|
|
Err(e) => {
|
|
return StreamResult::Error(format!(
|
|
"failed to parse tool input JSON: {e}"
|
|
));
|
|
}
|
|
}
|
|
} else {
|
|
warn!(
|
|
"received ToolUseDone outside of an active tool-use block -- ignoring"
|
|
);
|
|
}
|
|
}
|
|
StreamEvent::Done => {
|
|
// Flush trailing text.
|
|
if !text_buf.is_empty() {
|
|
blocks.push(ContentBlock::Text { text: text_buf });
|
|
}
|
|
return StreamResult::Done(blocks);
|
|
}
|
|
StreamEvent::Error(msg) => {
|
|
return StreamResult::Error(msg);
|
|
}
|
|
StreamEvent::InputTokens(n) => {
|
|
debug!(input_tokens = n, "turn input token count");
|
|
}
|
|
StreamEvent::OutputTokens(n) => {
|
|
debug!(output_tokens = n, "turn output token count");
|
|
}
|
|
}
|
|
}
|
|
|
|
// Stream ended without Done -- treat as error.
|
|
StreamResult::Error("stream ended unexpectedly".to_string())
|
|
}
|
|
|
|
/// Execute one complete turn: stream from provider, execute tools if needed,
|
|
/// loop back until the model produces a text-only response or we hit the
|
|
/// iteration limit.
|
|
async fn run_turn(&mut self) {
|
|
let tool_defs = self.tool_registry.definitions();
|
|
|
|
for _ in 0..MAX_TOOL_ITERATIONS {
|
|
let messages = self.history.messages().to_vec();
|
|
let result = self.consume_stream(&messages, &tool_defs).await;
|
|
|
|
match result {
|
|
StreamResult::Error(msg) => {
|
|
let _ = self.event_tx.send(UIEvent::Error(msg)).await;
|
|
return;
|
|
}
|
|
StreamResult::Done(blocks) => {
|
|
let has_tool_use = blocks
|
|
.iter()
|
|
.any(|b| matches!(b, ContentBlock::ToolUse { .. }));
|
|
|
|
// Append the assistant message to history.
|
|
self.history.push(ConversationMessage {
|
|
role: Role::Assistant,
|
|
content: blocks.clone(),
|
|
});
|
|
|
|
if !has_tool_use {
|
|
let _ = self.event_tx.send(UIEvent::TurnComplete).await;
|
|
return;
|
|
}
|
|
|
|
// TODO: execute tool-use blocks in parallel -- requires
|
|
// refactoring Orchestrator to use &self + interior
|
|
// mutability (Arc<Mutex<...>> around event_tx, action_rx,
|
|
// history) so that multiple futures can borrow self
|
|
// simultaneously via futures::future::join_all.
|
|
// Execute each tool-use block and collect results.
|
|
let mut tool_results: Vec<ContentBlock> = Vec::new();
|
|
|
|
for block in &blocks {
|
|
if let ContentBlock::ToolUse { id, name, input } = block {
|
|
let result = self.execute_tool_with_approval(id, name, input).await;
|
|
tool_results.push(ContentBlock::ToolResult {
|
|
tool_use_id: id.clone(),
|
|
content: result.content,
|
|
is_error: result.is_error,
|
|
});
|
|
}
|
|
}
|
|
|
|
// Append tool results as a user message and loop.
|
|
self.history.push(ConversationMessage {
|
|
role: Role::User,
|
|
content: tool_results,
|
|
});
|
|
}
|
|
}
|
|
}
|
|
warn!("tool-use loop reached max iterations ({MAX_TOOL_ITERATIONS})");
|
|
let _ = self
|
|
.event_tx
|
|
.send(UIEvent::Error(
|
|
"tool-use loop reached maximum iterations".to_string(),
|
|
))
|
|
.await;
|
|
}
|
|
|
|
/// Execute a single tool, handling approval if needed.
|
|
///
|
|
/// For auto-approve tools, executes immediately. For tools requiring
|
|
/// approval, sends a request to the TUI and waits for a response.
|
|
async fn execute_tool_with_approval(
|
|
&mut self,
|
|
tool_use_id: &str,
|
|
tool_name: &str,
|
|
input: &serde_json::Value,
|
|
) -> ToolOutput {
|
|
// Extract tool info upfront to avoid holding a borrow on self across
|
|
// the mutable wait_for_approval call.
|
|
let risk = match self.tool_registry.get(tool_name) {
|
|
Some(t) => t.risk_level(),
|
|
None => {
|
|
return ToolOutput {
|
|
content: format!("unknown tool: {tool_name}"),
|
|
is_error: true,
|
|
};
|
|
}
|
|
};
|
|
|
|
let input_summary = serde_json::to_string(input).unwrap_or_default();
|
|
|
|
// Check approval.
|
|
let approved = match risk {
|
|
RiskLevel::AutoApprove => {
|
|
let _ = self
|
|
.event_tx
|
|
.send(UIEvent::ToolExecuting {
|
|
tool_name: tool_name.to_string(),
|
|
input_summary: input_summary.clone(),
|
|
})
|
|
.await;
|
|
true
|
|
}
|
|
RiskLevel::RequiresApproval => {
|
|
let _ = self
|
|
.event_tx
|
|
.send(UIEvent::ToolApprovalRequest {
|
|
tool_use_id: tool_use_id.to_string(),
|
|
tool_name: tool_name.to_string(),
|
|
input_summary: input_summary.clone(),
|
|
})
|
|
.await;
|
|
|
|
// Wait for approval response from TUI.
|
|
self.wait_for_approval(tool_use_id).await
|
|
}
|
|
};
|
|
|
|
if !approved {
|
|
return ToolOutput {
|
|
content: "tool execution denied by user".to_string(),
|
|
is_error: true,
|
|
};
|
|
}
|
|
|
|
// Re-fetch tool for execution (borrow was released above).
|
|
let tool = self.tool_registry.get(tool_name).unwrap();
|
|
match tool.execute(input, &self.sandbox).await {
|
|
Ok(output) => {
|
|
let _ = self
|
|
.event_tx
|
|
.send(UIEvent::ToolResult {
|
|
tool_name: tool_name.to_string(),
|
|
output_summary: truncate(&output.content, 200),
|
|
is_error: output.is_error,
|
|
})
|
|
.await;
|
|
output
|
|
}
|
|
Err(e) => {
|
|
let msg = e.to_string();
|
|
let _ = self
|
|
.event_tx
|
|
.send(UIEvent::ToolResult {
|
|
tool_name: tool_name.to_string(),
|
|
output_summary: msg.clone(),
|
|
is_error: true,
|
|
})
|
|
.await;
|
|
ToolOutput {
|
|
content: msg,
|
|
is_error: true,
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Wait for a `ToolApprovalResponse` matching `tool_use_id` from the TUI.
|
|
///
|
|
/// Any `SendMessage` actions that arrive during the wait are pushed onto
|
|
/// `self.queued_messages` rather than discarded. They are replayed as new
|
|
/// turns once the current tool-use loop completes (see `run()`).
|
|
///
|
|
/// Returns `true` if approved, `false` if denied or the channel closes.
|
|
async fn wait_for_approval(&mut self, tool_use_id: &str) -> bool {
|
|
while let Some(action) = self.action_rx.recv().await {
|
|
match action {
|
|
UserAction::ToolApprovalResponse {
|
|
tool_use_id: id,
|
|
approved,
|
|
} if id == tool_use_id => {
|
|
return approved;
|
|
}
|
|
UserAction::Quit => return false,
|
|
UserAction::SendMessage(text) => {
|
|
self.queued_messages.push(text);
|
|
}
|
|
_ => {} // discard stale approvals / ClearHistory during wait
|
|
}
|
|
}
|
|
false
|
|
}
|
|
|
|
/// Run the orchestrator until the user quits or the `action_rx` channel closes.
|
|
pub async fn run(mut self) {
|
|
while let Some(action) = self.action_rx.recv().await {
|
|
match action {
|
|
UserAction::Quit => break,
|
|
UserAction::ClearHistory => {
|
|
self.history.clear();
|
|
}
|
|
|
|
// Approval responses are handled inline during tool execution,
|
|
// not in the main action loop. If one arrives here it's stale.
|
|
UserAction::ToolApprovalResponse { .. } => {}
|
|
|
|
UserAction::SetNetworkPolicy(allowed) => {
|
|
self.sandbox.set_network_allowed(allowed);
|
|
let _ = self
|
|
.event_tx
|
|
.send(UIEvent::NetworkPolicyChanged(allowed))
|
|
.await;
|
|
}
|
|
|
|
UserAction::SendMessage(text) => {
|
|
self.history
|
|
.push(ConversationMessage::text(Role::User, text));
|
|
self.run_turn().await;
|
|
|
|
// Drain any messages queued while an approval prompt was
|
|
// open. Each queued message is a full turn in sequence.
|
|
while !self.queued_messages.is_empty() {
|
|
let queued = std::mem::take(&mut self.queued_messages);
|
|
for msg in queued {
|
|
self.history
|
|
.push(ConversationMessage::text(Role::User, msg));
|
|
self.run_turn().await;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::types::ToolDefinition;
    use crate::sandbox::policy::SandboxPolicy;
    use crate::sandbox::{EnforcementMode, Sandbox};
    use crate::tools::ToolRegistry;
    use futures::Stream;
    use std::collections::VecDeque;
    use std::path::PathBuf;
    use std::sync::{Arc, Mutex};
    use std::time::Duration;
    use tokio::sync::mpsc;

    /// Upper bound on how long a test waits for the next `UIEvent`. It is
    /// generous so slow CI machines do not flake; passing tests return as soon
    /// as the expected event arrives, well before this.
    const EVENT_TIMEOUT: Duration = Duration::from_secs(5);

    /// A provider that serves one scripted event list per `stream()` call, in
    /// order. Once the script is exhausted it returns an empty stream, which the
    /// orchestrator reports as "stream ended unexpectedly".
    ///
    /// It also records the number of history messages passed to each call, so
    /// tests can check what the orchestrator stored in its history without
    /// reaching into it.
    struct ScriptedProvider {
        turns: Mutex<VecDeque<Vec<StreamEvent>>>,
        history_lens: Arc<Mutex<Vec<usize>>>,
    }

    impl ScriptedProvider {
        fn new(turns: Vec<Vec<StreamEvent>>) -> Self {
            Self {
                turns: Mutex::new(turns.into()),
                history_lens: Arc::default(),
            }
        }

        /// Shared handle to the per-call history lengths; keep it before the
        /// provider is moved into the orchestrator.
        fn history_lens(&self) -> Arc<Mutex<Vec<usize>>> {
            Arc::clone(&self.history_lens)
        }
    }

    impl ModelProvider for ScriptedProvider {
        fn stream<'a>(
            &'a self,
            messages: &'a [ConversationMessage],
            _tools: &'a [ToolDefinition],
        ) -> impl Stream<Item = StreamEvent> + Send + 'a {
            self.history_lens.lock().unwrap().push(messages.len());
            let events = self.turns.lock().unwrap().pop_front().unwrap_or_default();
            futures::stream::iter(events)
        }
    }

    /// One scripted model call that streams `text` and finishes.
    fn text_turn(text: &str) -> Vec<StreamEvent> {
        vec![StreamEvent::TextDelta(text.to_string()), StreamEvent::Done]
    }

    /// Create an Orchestrator with no tools for testing text-only flows.
    fn test_orchestrator<P: ModelProvider>(
        provider: P,
        action_rx: mpsc::Receiver<UserAction>,
        event_tx: mpsc::Sender<UIEvent>,
    ) -> Orchestrator<P> {
        orchestrator_with_registry(provider, ToolRegistry::empty(), action_rx, event_tx)
    }

    /// Create an Orchestrator with the given tools, sandboxed to `/tmp` in
    /// `Yolo` mode.
    fn orchestrator_with_registry<P: ModelProvider>(
        provider: P,
        registry: ToolRegistry,
        action_rx: mpsc::Receiver<UserAction>,
        event_tx: mpsc::Sender<UIEvent>,
    ) -> Orchestrator<P> {
        let root = PathBuf::from("/tmp");
        let sandbox = Sandbox::new(
            SandboxPolicy::for_project(root.clone()),
            root,
            EnforcementMode::Yolo,
        );
        Orchestrator::new(provider, registry, sandbox, action_rx, event_tx)
    }

    /// Receive events until one satisfies `stop` (inclusive), the channel
    /// closes, or no event arrives within `EVENT_TIMEOUT`. Awaiting events,
    /// instead of sleeping and then polling, keeps the tests free of timing races.
    async fn recv_until(
        rx: &mut mpsc::Receiver<UIEvent>,
        mut stop: impl FnMut(&UIEvent) -> bool,
    ) -> Vec<UIEvent> {
        let mut out = Vec::new();
        while let Ok(Some(ev)) = tokio::time::timeout(EVENT_TIMEOUT, rx.recv()).await {
            let done = stop(&ev);
            out.push(ev);
            if done {
                break;
            }
        }
        out
    }

    /// Collect the UIEvents of one orchestrator turn, up to and including its
    /// `TurnComplete` or `Error`.
    async fn collect_events(rx: &mut mpsc::Receiver<UIEvent>) -> Vec<UIEvent> {
        recv_until(rx, |ev| {
            matches!(ev, UIEvent::TurnComplete | UIEvent::Error(_))
        })
        .await
    }

    // -- happy-path turn ----------------------------------------------------------

    /// A full successful turn: text chunks followed by Done.
    ///
    /// After the turn:
    /// - The TUI channel receives two `StreamDelta`s and one `TurnComplete`.
    /// - The provider was called exactly once, with a history holding just the
    ///   user message.
    #[tokio::test]
    async fn happy_path_turn_produces_correct_ui_events_and_history() {
        let provider = ScriptedProvider::new(vec![vec![
            StreamEvent::InputTokens(10),
            StreamEvent::TextDelta("Hello".to_string()),
            StreamEvent::TextDelta(", world!".to_string()),
            StreamEvent::OutputTokens(5),
            StreamEvent::Done,
        ]]);
        let history_lens = provider.history_lens();

        let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
        let (event_tx, mut event_rx) = mpsc::channel::<UIEvent>(16);

        let orch = test_orchestrator(provider, action_rx, event_tx);
        let handle = tokio::spawn(orch.run());

        action_tx
            .send(UserAction::SendMessage("hi".to_string()))
            .await
            .unwrap();

        let events = collect_events(&mut event_rx).await;

        // Token counts are only logged; they produce no UI events.
        assert_eq!(events.len(), 3, "events: {events:?}");
        assert!(matches!(&events[0], UIEvent::StreamDelta(s) if s == "Hello"));
        assert!(matches!(&events[1], UIEvent::StreamDelta(s) if s == ", world!"));
        assert!(matches!(events[2], UIEvent::TurnComplete));

        action_tx.send(UserAction::Quit).await.unwrap();
        handle.await.unwrap();

        assert_eq!(*history_lens.lock().unwrap(), vec![1]);
    }

    // -- error path ---------------------------------------------------------------

    /// When the provider emits `Error`, the orchestrator forwards it to the TUI
    /// and does NOT append an assistant message to history: the next turn's
    /// model call sees only the two user messages.
    #[tokio::test]
    async fn error_event_forwarded_to_tui_and_no_assistant_message_in_history() {
        let provider = ScriptedProvider::new(vec![
            vec![
                StreamEvent::TextDelta("partial".to_string()),
                StreamEvent::Error("network timeout".to_string()),
            ],
            text_turn("recovered"),
        ]);
        let history_lens = provider.history_lens();

        let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
        let (event_tx, mut event_rx) = mpsc::channel::<UIEvent>(16);

        let orch = test_orchestrator(provider, action_rx, event_tx);
        let handle = tokio::spawn(orch.run());

        action_tx
            .send(UserAction::SendMessage("hello".to_string()))
            .await
            .unwrap();

        let events = collect_events(&mut event_rx).await;

        assert_eq!(events.len(), 2, "events: {events:?}");
        assert!(matches!(&events[0], UIEvent::StreamDelta(s) if s == "partial"));
        assert!(matches!(&events[1], UIEvent::Error(msg) if msg == "network timeout"));

        action_tx
            .send(UserAction::SendMessage("again".to_string()))
            .await
            .unwrap();
        let events = collect_events(&mut event_rx).await;
        assert!(
            matches!(events.last(), Some(UIEvent::TurnComplete)),
            "events: {events:?}"
        );

        action_tx.send(UserAction::Quit).await.unwrap();
        handle.await.unwrap();

        // Second call: ["hello", "again"] -- no partial assistant message.
        assert_eq!(*history_lens.lock().unwrap(), vec![1, 2]);
    }

    // -- quit ---------------------------------------------------------------------

    /// Sending `Quit` immediately terminates the orchestrator loop.
    #[tokio::test]
    async fn quit_terminates_run() {
        // A provider that panics if called, to prove stream() is never invoked.
        struct NeverCalledProvider;
        impl ModelProvider for NeverCalledProvider {
            fn stream<'a>(
                &'a self,
                _messages: &'a [ConversationMessage],
                _tools: &'a [ToolDefinition],
            ) -> impl Stream<Item = StreamEvent> + Send + 'a {
                panic!("stream() must not be called after Quit");
                // Unreachable, but gives the opaque return type a concrete stream.
                #[allow(unreachable_code)]
                futures::stream::empty()
            }
        }

        let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
        let (event_tx, _event_rx) = mpsc::channel::<UIEvent>(8);

        let orch = test_orchestrator(NeverCalledProvider, action_rx, event_tx);
        let handle = tokio::spawn(orch.run());

        action_tx.send(UserAction::Quit).await.unwrap();
        handle.await.unwrap(); // completes without panic
    }

    // -- multi-turn history accumulation ------------------------------------------

    /// Two sequential SendMessage turns each append a user message and the
    /// accumulated assistant response.
    ///
    /// The provider records the history length on each call: the first call
    /// sees just the first user message, and the second sees
    /// [user, assistant, user]. This confirms the assistant reply was stored and
    /// that the history is passed to the provider on every turn.
    #[tokio::test]
    async fn two_turns_accumulate_history_correctly() {
        let provider = ScriptedProvider::new(vec![text_turn("reply"), text_turn("reply")]);
        let history_lens = provider.history_lens();

        let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
        let (event_tx, mut event_rx) = mpsc::channel::<UIEvent>(32);

        let orch = test_orchestrator(provider, action_rx, event_tx);
        let handle = tokio::spawn(orch.run());

        // First turn.
        action_tx
            .send(UserAction::SendMessage("turn one".to_string()))
            .await
            .unwrap();
        let ev1 = collect_events(&mut event_rx).await;
        assert!(matches!(ev1.last(), Some(UIEvent::TurnComplete)), "events: {ev1:?}");

        // Second turn.
        action_tx
            .send(UserAction::SendMessage("turn two".to_string()))
            .await
            .unwrap();
        let ev2 = collect_events(&mut event_rx).await;
        assert!(matches!(ev2.last(), Some(UIEvent::TurnComplete)), "events: {ev2:?}");

        action_tx.send(UserAction::Quit).await.unwrap();
        handle.await.unwrap();

        assert_eq!(*history_lens.lock().unwrap(), vec![1, 3]);
    }

    // -- tool-use accumulation ----------------------------------------------------

    /// When the provider emits tool-use events, the orchestrator executes the
    /// tool (no approval prompt, since read_file is AutoApprove), feeds the
    /// result back, and the provider's second call returns text.
    #[tokio::test]
    async fn tool_use_loop_executes_and_feeds_back() {
        let provider = ScriptedProvider::new(vec![
            // First call: tool use
            vec![
                StreamEvent::TextDelta("Let me read that.".to_string()),
                StreamEvent::ToolUseDone, // spurious stop for text block
                StreamEvent::ToolUseStart {
                    id: "toolu_01".to_string(),
                    name: "read_file".to_string(),
                },
                StreamEvent::ToolUseInputDelta("{\"path\":\"Cargo.toml\"}".to_string()),
                StreamEvent::ToolUseDone,
                StreamEvent::Done,
            ],
            // Second call: text response after tool result
            text_turn("Here's the file."),
        ]);
        let history_lens = provider.history_lens();

        let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
        let (event_tx, mut event_rx) = mpsc::channel::<UIEvent>(32);

        // Use a real ToolRegistry so read_file works.
        let dir = tempfile::TempDir::new().unwrap();
        std::fs::write(dir.path().join("Cargo.toml"), "[package]\nname = \"test\"").unwrap();
        let sandbox = Sandbox::new(
            SandboxPolicy::for_project(dir.path().to_path_buf()),
            dir.path().to_path_buf(),
            EnforcementMode::Yolo,
        );
        let orch = Orchestrator::new(
            provider,
            ToolRegistry::default_tools(),
            sandbox,
            action_rx,
            event_tx,
        );
        let handle = tokio::spawn(orch.run());

        action_tx
            .send(UserAction::SendMessage("read Cargo.toml".to_string()))
            .await
            .unwrap();

        let events = collect_events(&mut event_rx).await;
        // Should see text deltas from both calls, tool events, and TurnComplete
        assert!(
            events
                .iter()
                .any(|e| matches!(e, UIEvent::StreamDelta(s) if s == "Let me read that.")),
            "expected first text delta, got: {events:?}"
        );
        assert!(
            events.iter().any(
                |e| matches!(e, UIEvent::ToolExecuting { tool_name, .. } if tool_name == "read_file")
            ),
            "expected ToolExecuting for read_file, got: {events:?}"
        );
        assert!(
            events.iter().any(
                |e| matches!(e, UIEvent::ToolResult { tool_name, .. } if tool_name == "read_file")
            ),
            "expected ToolResult for read_file, got: {events:?}"
        );
        assert!(
            events
                .iter()
                .any(|e| matches!(e, UIEvent::StreamDelta(s) if s == "Here's the file.")),
            "expected second text delta, got: {events:?}"
        );
        assert!(
            matches!(events.last(), Some(UIEvent::TurnComplete)),
            "expected TurnComplete, got: {events:?}"
        );

        action_tx.send(UserAction::Quit).await.unwrap();
        handle.await.unwrap();

        // Second call sees [user, assistant tool_use, user tool_result].
        assert_eq!(*history_lens.lock().unwrap(), vec![1, 3]);
    }

    // -- truncate -----------------------------------------------------------------

    /// Truncating at a boundary inside a multibyte character must not panic.
    /// "hello " is 6 bytes and the emoji is 4 bytes, so slicing at byte 8 would
    /// split the emoji. `truncate` counts chars, so the cut falls after the
    /// 8th char ("hello ", the emoji, and one space).
    #[test]
    fn truncate_multibyte_does_not_panic() {
        let s = "hello \u{1F30D} world"; // globe emoji is 4 bytes
        assert_eq!(truncate(s, 8), "hello \u{1F30D} ...");
    }

    /// Strings no longer than the limit are returned unchanged, without "...".
    #[test]
    fn truncate_short_strings_unchanged() {
        assert_eq!(truncate("", 3), "");
        assert_eq!(truncate("abc", 3), "abc");
        assert_eq!(truncate("abcd", 3), "abc...");
    }

    // -- JSON parse error ---------------------------------------------------------

    /// When the provider emits malformed tool-input JSON, the orchestrator must
    /// surface a descriptive UIEvent::Error rather than silently using Null.
    #[tokio::test]
    async fn malformed_tool_json_surfaces_error() {
        let provider = ScriptedProvider::new(vec![vec![
            StreamEvent::ToolUseStart {
                id: "toolu_bad".to_string(),
                name: "read_file".to_string(),
            },
            StreamEvent::ToolUseInputDelta("not valid json{{".to_string()),
            StreamEvent::ToolUseDone,
            StreamEvent::Done,
        ]]);

        let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
        let (event_tx, mut event_rx) = mpsc::channel::<UIEvent>(16);

        let orch = test_orchestrator(provider, action_rx, event_tx);
        let handle = tokio::spawn(orch.run());

        action_tx
            .send(UserAction::SendMessage("go".to_string()))
            .await
            .unwrap();

        let events = collect_events(&mut event_rx).await;
        assert!(
            events
                .iter()
                .any(|e| matches!(e, UIEvent::Error(msg) if msg.contains("tool input JSON"))),
            "expected JSON parse error event, got: {events:?}"
        );

        action_tx.send(UserAction::Quit).await.unwrap();
        handle.await.unwrap();
    }

    // -- queued messages ----------------------------------------------------------

    /// A SendMessage sent while an approval prompt is open must be processed
    /// after the current turn completes, not silently dropped.
    ///
    /// This test uses the real tool registry so that write_file (a
    /// RequiresApproval tool) makes the orchestrator block in
    /// wait_for_approval. The test waits for the `ToolApprovalRequest` event,
    /// sends a second message, then denies the tool. The queued message must
    /// then run as its own turn.
    #[tokio::test]
    async fn send_message_during_approval_is_queued_and_processed() {
        let provider = ScriptedProvider::new(vec![
            // Turn 1: requests write_file (RequiresApproval)
            vec![
                StreamEvent::ToolUseStart {
                    id: "toolu_w".to_string(),
                    name: "write_file".to_string(),
                },
                StreamEvent::ToolUseInputDelta(
                    "{\"path\":\"x.txt\",\"content\":\"hi\"}".to_string(),
                ),
                StreamEvent::ToolUseDone,
                StreamEvent::Done,
            ],
            // Turn 1 second iteration after tool result (denied)
            text_turn("ok denied"),
            // Turn 2: the queued message
            text_turn("queued reply"),
        ]);
        let history_lens = provider.history_lens();

        let (action_tx, action_rx) = mpsc::channel::<UserAction>(16);
        let (event_tx, mut event_rx) = mpsc::channel::<UIEvent>(64);

        let orch = orchestrator_with_registry(
            provider,
            ToolRegistry::default_tools(),
            action_rx,
            event_tx,
        );
        let handle = tokio::spawn(orch.run());

        // Start turn 1 and wait until the orchestrator asks for approval.
        action_tx
            .send(UserAction::SendMessage("turn1".to_string()))
            .await
            .unwrap();
        let before = recv_until(&mut event_rx, |e| {
            matches!(e, UIEvent::ToolApprovalRequest { .. })
        })
        .await;
        assert!(
            matches!(
                before.last(),
                Some(UIEvent::ToolApprovalRequest { tool_use_id, .. }) if tool_use_id == "toolu_w"
            ),
            "expected an approval request for toolu_w, got: {before:?}"
        );

        // Send a message while blocked -- should be queued.
        action_tx
            .send(UserAction::SendMessage("queued".to_string()))
            .await
            .unwrap();

        // Deny the tool -- unblocks wait_for_approval.
        action_tx
            .send(UserAction::ToolApprovalResponse {
                tool_use_id: "toolu_w".to_string(),
                approved: false,
            })
            .await
            .unwrap();

        let first = collect_events(&mut event_rx).await;
        assert!(
            first
                .iter()
                .any(|e| matches!(e, UIEvent::StreamDelta(s) if s == "ok denied")),
            "turn 1 did not finish after denial; events: {first:?}"
        );

        // The queued message must have produced "queued reply".
        let second = collect_events(&mut event_rx).await;
        assert!(
            second
                .iter()
                .any(|e| matches!(e, UIEvent::StreamDelta(s) if s == "queued reply")),
            "queued message was not processed; events: {second:?}"
        );

        action_tx.send(UserAction::Quit).await.unwrap();
        handle.await.unwrap();

        // [turn1] -> [turn1, tool_use, tool_result]
        //         -> [turn1, tool_use, tool_result, "ok denied", queued]
        assert_eq!(*history_lens.lock().unwrap(), vec![1, 3, 5]);
    }

    // -- network policy toggle ------------------------------------------------

    /// SetNetworkPolicy sends a NetworkPolicyChanged event back to the TUI
    /// without calling the provider.
    #[tokio::test]
    async fn set_network_policy_sends_event() {
        let provider = ScriptedProvider::new(Vec::new());
        let history_lens = provider.history_lens();

        let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
        let (event_tx, mut event_rx) = mpsc::channel::<UIEvent>(8);

        let orch = test_orchestrator(provider, action_rx, event_tx);
        let handle = tokio::spawn(orch.run());

        action_tx
            .send(UserAction::SetNetworkPolicy(true))
            .await
            .unwrap();

        let events = recv_until(&mut event_rx, |e| {
            matches!(e, UIEvent::NetworkPolicyChanged(_))
        })
        .await;
        assert!(
            matches!(events.last(), Some(UIEvent::NetworkPolicyChanged(true))),
            "expected NetworkPolicyChanged(true) event, got: {events:?}"
        );

        action_tx.send(UserAction::Quit).await.unwrap();
        handle.await.unwrap();

        assert!(history_lens.lock().unwrap().is_empty());
    }
}
|