Fix some issues with UI getting out of sync. (#7)
Reviewed-on: #7 Co-authored-by: Drew Galbraith <drew@tiramisu.one> Co-committed-by: Drew Galbraith <drew@tiramisu.one>
This commit is contained in:
parent
0fcdf4ed0d
commit
af080710cc
8 changed files with 199 additions and 84 deletions
|
|
@ -4,7 +4,8 @@ use tracing::{debug, warn};
|
|||
|
||||
use crate::core::history::ConversationHistory;
|
||||
use crate::core::types::{
|
||||
ContentBlock, ConversationMessage, Role, StreamEvent, ToolDefinition, UIEvent, UserAction,
|
||||
ContentBlock, ConversationMessage, Role, StampedEvent, StreamEvent, ToolDefinition, UIEvent,
|
||||
UserAction,
|
||||
};
|
||||
use crate::provider::ModelProvider;
|
||||
use crate::sandbox::Sandbox;
|
||||
|
|
@ -109,10 +110,13 @@ pub struct Orchestrator<P> {
|
|||
tool_registry: ToolRegistry,
|
||||
sandbox: Sandbox,
|
||||
action_rx: mpsc::Receiver<UserAction>,
|
||||
event_tx: mpsc::Sender<UIEvent>,
|
||||
event_tx: mpsc::Sender<StampedEvent>,
|
||||
/// Messages typed by the user while an approval prompt is open. They are
|
||||
/// queued here and replayed as new turns once the current turn completes.
|
||||
queued_messages: Vec<String>,
|
||||
/// Monotonic epoch incremented on each `:clear`. Events are tagged with
|
||||
/// this value so the TUI can discard stale in-flight messages.
|
||||
epoch: u64,
|
||||
}
|
||||
|
||||
impl<P: ModelProvider> Orchestrator<P> {
|
||||
|
|
@ -122,7 +126,7 @@ impl<P: ModelProvider> Orchestrator<P> {
|
|||
tool_registry: ToolRegistry,
|
||||
sandbox: Sandbox,
|
||||
action_rx: mpsc::Receiver<UserAction>,
|
||||
event_tx: mpsc::Sender<UIEvent>,
|
||||
event_tx: mpsc::Sender<StampedEvent>,
|
||||
) -> Self {
|
||||
Self {
|
||||
history: ConversationHistory::new(),
|
||||
|
|
@ -132,6 +136,25 @@ impl<P: ModelProvider> Orchestrator<P> {
|
|||
action_rx,
|
||||
event_tx,
|
||||
queued_messages: Vec::new(),
|
||||
epoch: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Send a [`UIEvent`] stamped with the current epoch.
|
||||
///
|
||||
/// Drops the event silently if the channel is full or closed -- this is
|
||||
/// intentional (TUI death should not stall the orchestrator), but the
|
||||
/// debug log below makes backpressure visible during development.
|
||||
async fn send(&self, event: UIEvent) {
|
||||
let result = self
|
||||
.event_tx
|
||||
.send(StampedEvent {
|
||||
epoch: self.epoch,
|
||||
event,
|
||||
})
|
||||
.await;
|
||||
if result.is_err() {
|
||||
debug!("ui event dropped -- channel closed or full");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -155,7 +178,7 @@ impl<P: ModelProvider> Orchestrator<P> {
|
|||
match event {
|
||||
StreamEvent::TextDelta(chunk) => {
|
||||
text_buf.push_str(&chunk);
|
||||
let _ = self.event_tx.send(UIEvent::StreamDelta(chunk)).await;
|
||||
self.send(UIEvent::StreamDelta(chunk)).await;
|
||||
}
|
||||
StreamEvent::ToolUseStart { id, name } => {
|
||||
// Flush any accumulated text before starting tool block.
|
||||
|
|
@ -226,7 +249,7 @@ impl<P: ModelProvider> Orchestrator<P> {
|
|||
|
||||
match result {
|
||||
StreamResult::Error(msg) => {
|
||||
let _ = self.event_tx.send(UIEvent::Error(msg)).await;
|
||||
self.send(UIEvent::Error(msg)).await;
|
||||
return;
|
||||
}
|
||||
StreamResult::Done(blocks) => {
|
||||
|
|
@ -241,7 +264,7 @@ impl<P: ModelProvider> Orchestrator<P> {
|
|||
});
|
||||
|
||||
if !has_tool_use {
|
||||
let _ = self.event_tx.send(UIEvent::TurnComplete).await;
|
||||
self.send(UIEvent::TurnComplete).await;
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -269,16 +292,15 @@ impl<P: ModelProvider> Orchestrator<P> {
|
|||
role: Role::User,
|
||||
content: tool_results,
|
||||
});
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
warn!("tool-use loop reached max iterations ({MAX_TOOL_ITERATIONS})");
|
||||
let _ = self
|
||||
.event_tx
|
||||
.send(UIEvent::Error(
|
||||
"tool-use loop reached maximum iterations".to_string(),
|
||||
))
|
||||
.await;
|
||||
self.send(UIEvent::Error(
|
||||
"tool-use loop reached maximum iterations".to_string(),
|
||||
))
|
||||
.await;
|
||||
}
|
||||
|
||||
/// Execute a single tool, handling approval if needed.
|
||||
|
|
@ -308,24 +330,20 @@ impl<P: ModelProvider> Orchestrator<P> {
|
|||
// Check approval.
|
||||
let approved = match risk {
|
||||
RiskLevel::AutoApprove => {
|
||||
let _ = self
|
||||
.event_tx
|
||||
.send(UIEvent::ToolExecuting {
|
||||
tool_name: tool_name.to_string(),
|
||||
input_summary: input_summary.clone(),
|
||||
})
|
||||
.await;
|
||||
self.send(UIEvent::ToolExecuting {
|
||||
tool_name: tool_name.to_string(),
|
||||
input_summary: input_summary.clone(),
|
||||
})
|
||||
.await;
|
||||
true
|
||||
}
|
||||
RiskLevel::RequiresApproval => {
|
||||
let _ = self
|
||||
.event_tx
|
||||
.send(UIEvent::ToolApprovalRequest {
|
||||
tool_use_id: tool_use_id.to_string(),
|
||||
tool_name: tool_name.to_string(),
|
||||
input_summary: input_summary.clone(),
|
||||
})
|
||||
.await;
|
||||
self.send(UIEvent::ToolApprovalRequest {
|
||||
tool_use_id: tool_use_id.to_string(),
|
||||
tool_name: tool_name.to_string(),
|
||||
input_summary: input_summary.clone(),
|
||||
})
|
||||
.await;
|
||||
|
||||
// Wait for approval response from TUI.
|
||||
self.wait_for_approval(tool_use_id).await
|
||||
|
|
@ -343,26 +361,22 @@ impl<P: ModelProvider> Orchestrator<P> {
|
|||
let tool = self.tool_registry.get(tool_name).unwrap();
|
||||
match tool.execute(input, &self.sandbox).await {
|
||||
Ok(output) => {
|
||||
let _ = self
|
||||
.event_tx
|
||||
.send(UIEvent::ToolResult {
|
||||
tool_name: tool_name.to_string(),
|
||||
output_summary: truncate(&output.content, 200),
|
||||
is_error: output.is_error,
|
||||
})
|
||||
.await;
|
||||
self.send(UIEvent::ToolResult {
|
||||
tool_name: tool_name.to_string(),
|
||||
output_summary: truncate(&output.content, 200),
|
||||
is_error: output.is_error,
|
||||
})
|
||||
.await;
|
||||
output
|
||||
}
|
||||
Err(e) => {
|
||||
let msg = e.to_string();
|
||||
let _ = self
|
||||
.event_tx
|
||||
.send(UIEvent::ToolResult {
|
||||
tool_name: tool_name.to_string(),
|
||||
output_summary: msg.clone(),
|
||||
is_error: true,
|
||||
})
|
||||
.await;
|
||||
self.send(UIEvent::ToolResult {
|
||||
tool_name: tool_name.to_string(),
|
||||
output_summary: msg.clone(),
|
||||
is_error: true,
|
||||
})
|
||||
.await;
|
||||
ToolOutput {
|
||||
content: msg,
|
||||
is_error: true,
|
||||
|
|
@ -391,7 +405,15 @@ impl<P: ModelProvider> Orchestrator<P> {
|
|||
UserAction::SendMessage(text) => {
|
||||
self.queued_messages.push(text);
|
||||
}
|
||||
_ => {} // discard stale approvals / ClearHistory during wait
|
||||
UserAction::ClearHistory { epoch } => {
|
||||
// Keep epoch in sync even while blocked on an approval
|
||||
// prompt. Without this, events emitted after the approval
|
||||
// resolves would carry the pre-clear epoch and be silently
|
||||
// discarded by the TUI.
|
||||
self.epoch = epoch;
|
||||
self.history.clear();
|
||||
}
|
||||
_ => {} // discard stale approvals
|
||||
}
|
||||
}
|
||||
false
|
||||
|
|
@ -402,7 +424,8 @@ impl<P: ModelProvider> Orchestrator<P> {
|
|||
while let Some(action) = self.action_rx.recv().await {
|
||||
match action {
|
||||
UserAction::Quit => break,
|
||||
UserAction::ClearHistory => {
|
||||
UserAction::ClearHistory { epoch } => {
|
||||
self.epoch = epoch;
|
||||
self.history.clear();
|
||||
}
|
||||
|
||||
|
|
@ -412,10 +435,7 @@ impl<P: ModelProvider> Orchestrator<P> {
|
|||
|
||||
UserAction::SetNetworkPolicy(allowed) => {
|
||||
self.sandbox.set_network_allowed(allowed);
|
||||
let _ = self
|
||||
.event_tx
|
||||
.send(UIEvent::NetworkPolicyChanged(allowed))
|
||||
.await;
|
||||
self.send(UIEvent::NetworkPolicyChanged(allowed)).await;
|
||||
}
|
||||
|
||||
UserAction::SendMessage(text) => {
|
||||
|
|
@ -474,7 +494,7 @@ mod tests {
|
|||
fn test_orchestrator<P: ModelProvider>(
|
||||
provider: P,
|
||||
action_rx: mpsc::Receiver<UserAction>,
|
||||
event_tx: mpsc::Sender<UIEvent>,
|
||||
event_tx: mpsc::Sender<StampedEvent>,
|
||||
) -> Orchestrator<P> {
|
||||
use crate::sandbox::policy::SandboxPolicy;
|
||||
use crate::sandbox::{EnforcementMode, Sandbox};
|
||||
|
|
@ -494,11 +514,11 @@ mod tests {
|
|||
|
||||
/// Collect all UIEvents that arrive within one orchestrator turn, stopping
|
||||
/// when the channel is drained after a `TurnComplete` or `Error`.
|
||||
async fn collect_events(rx: &mut mpsc::Receiver<UIEvent>) -> Vec<UIEvent> {
|
||||
async fn collect_events(rx: &mut mpsc::Receiver<StampedEvent>) -> Vec<UIEvent> {
|
||||
let mut out = Vec::new();
|
||||
while let Ok(ev) = rx.try_recv() {
|
||||
let done = matches!(ev, UIEvent::TurnComplete | UIEvent::Error(_));
|
||||
out.push(ev);
|
||||
while let Ok(stamped) = rx.try_recv() {
|
||||
let done = matches!(stamped.event, UIEvent::TurnComplete | UIEvent::Error(_));
|
||||
out.push(stamped.event);
|
||||
if done {
|
||||
break;
|
||||
}
|
||||
|
|
@ -525,7 +545,7 @@ mod tests {
|
|||
]);
|
||||
|
||||
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
||||
let (event_tx, mut event_rx) = mpsc::channel::<UIEvent>(16);
|
||||
let (event_tx, mut event_rx) = mpsc::channel::<StampedEvent>(16);
|
||||
|
||||
let orch = test_orchestrator(provider, action_rx, event_tx);
|
||||
let handle = tokio::spawn(orch.run());
|
||||
|
|
@ -563,7 +583,7 @@ mod tests {
|
|||
]);
|
||||
|
||||
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
||||
let (event_tx, mut event_rx) = mpsc::channel::<UIEvent>(16);
|
||||
let (event_tx, mut event_rx) = mpsc::channel::<StampedEvent>(16);
|
||||
|
||||
let orch = test_orchestrator(provider, action_rx, event_tx);
|
||||
let handle = tokio::spawn(orch.run());
|
||||
|
|
@ -605,7 +625,7 @@ mod tests {
|
|||
}
|
||||
|
||||
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
||||
let (event_tx, _event_rx) = mpsc::channel::<UIEvent>(8);
|
||||
let (event_tx, _event_rx) = mpsc::channel::<StampedEvent>(8);
|
||||
|
||||
let orch = test_orchestrator(NeverCalledProvider, action_rx, event_tx);
|
||||
let handle = tokio::spawn(orch.run());
|
||||
|
|
@ -658,7 +678,7 @@ mod tests {
|
|||
let provider = MultiTurnMock { turns };
|
||||
|
||||
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
||||
let (event_tx, mut event_rx) = mpsc::channel::<UIEvent>(32);
|
||||
let (event_tx, mut event_rx) = mpsc::channel::<StampedEvent>(32);
|
||||
|
||||
let orch = test_orchestrator(provider, action_rx, event_tx);
|
||||
let handle = tokio::spawn(orch.run());
|
||||
|
|
@ -730,7 +750,7 @@ mod tests {
|
|||
])));
|
||||
|
||||
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
||||
let (event_tx, mut event_rx) = mpsc::channel::<UIEvent>(32);
|
||||
let (event_tx, mut event_rx) = mpsc::channel::<StampedEvent>(32);
|
||||
|
||||
// Use a real ToolRegistry so read_file works.
|
||||
use crate::sandbox::policy::SandboxPolicy;
|
||||
|
|
@ -817,7 +837,7 @@ mod tests {
|
|||
]);
|
||||
|
||||
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
||||
let (event_tx, mut event_rx) = mpsc::channel::<UIEvent>(16);
|
||||
let (event_tx, mut event_rx) = mpsc::channel::<StampedEvent>(16);
|
||||
|
||||
let orch = test_orchestrator(provider, action_rx, event_tx);
|
||||
let handle = tokio::spawn(orch.run());
|
||||
|
|
@ -894,7 +914,7 @@ mod tests {
|
|||
])));
|
||||
|
||||
let (action_tx, action_rx) = mpsc::channel::<UserAction>(16);
|
||||
let (event_tx, mut event_rx) = mpsc::channel::<UIEvent>(64);
|
||||
let (event_tx, mut event_rx) = mpsc::channel::<StampedEvent>(64);
|
||||
|
||||
let orch = test_orchestrator(MultiCallMock { turns }, action_rx, event_tx);
|
||||
let handle = tokio::spawn(orch.run());
|
||||
|
|
@ -929,7 +949,7 @@ mod tests {
|
|||
// Collect everything.
|
||||
let mut all_events = Vec::new();
|
||||
while let Ok(ev) = event_rx.try_recv() {
|
||||
all_events.push(ev);
|
||||
all_events.push(ev.event);
|
||||
}
|
||||
|
||||
// The queued message must have produced "queued reply".
|
||||
|
|
@ -962,7 +982,7 @@ mod tests {
|
|||
}
|
||||
|
||||
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
||||
let (event_tx, mut event_rx) = mpsc::channel::<UIEvent>(8);
|
||||
let (event_tx, mut event_rx) = mpsc::channel::<StampedEvent>(8);
|
||||
|
||||
let orch = test_orchestrator(NeverCalledProvider, action_rx, event_tx);
|
||||
let handle = tokio::spawn(orch.run());
|
||||
|
|
@ -974,8 +994,8 @@ mod tests {
|
|||
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||
|
||||
let mut found = false;
|
||||
while let Ok(ev) = event_rx.try_recv() {
|
||||
if matches!(ev, UIEvent::NetworkPolicyChanged(true)) {
|
||||
while let Ok(stamped) = event_rx.try_recv() {
|
||||
if matches!(stamped.event, UIEvent::NetworkPolicyChanged(true)) {
|
||||
found = true;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue