353 lines
13 KiB
Rust
353 lines
13 KiB
Rust
use futures::StreamExt;
|
|
use tokio::sync::mpsc;
|
|
use tracing::debug;
|
|
|
|
use crate::core::history::ConversationHistory;
|
|
use crate::core::types::{ConversationMessage, Role, StreamEvent, UIEvent, UserAction};
|
|
use crate::provider::ModelProvider;
|
|
|
|
/// Drives the conversation loop between the TUI frontend and the model provider.
|
|
///
|
|
/// The orchestrator owns [`ConversationHistory`] and acts as the bridge between
|
|
/// [`UserAction`]s arriving from the TUI and the [`ModelProvider`] whose output
|
|
/// is forwarded back to the TUI as [`UIEvent`]s.
|
|
///
|
|
/// # Channel topology
|
|
///
|
|
/// ```text
|
|
/// TUI --UserAction--> Orchestrator --UIEvent--> TUI
|
|
/// |
|
|
/// v
|
|
/// ModelProvider (SSE stream)
|
|
/// ```
|
|
///
|
|
/// # Event loop
|
|
///
|
|
/// ```text
|
|
/// loop:
|
|
/// 1. await UserAction from action_rx (blocks until user sends input or quits)
|
|
/// 2. SendMessage:
|
|
/// a. Append user message to history
|
|
/// b. Call provider.stream(history) -- starts an SSE request
|
|
/// c. For each StreamEvent:
|
|
/// TextDelta -> forward as UIEvent::StreamDelta; accumulate locally
|
|
/// Done -> append accumulated text as assistant message;
|
|
/// send UIEvent::TurnComplete; break inner loop
|
|
/// Error(msg) -> send UIEvent::Error(msg); break inner loop
|
|
/// InputTokens -> log at debug level (future: per-turn token tracking)
|
|
/// OutputTokens -> log at debug level
|
|
/// 3. Quit -> return
|
|
/// ```
|
|
pub struct Orchestrator<P> {
|
|
history: ConversationHistory,
|
|
provider: P,
|
|
action_rx: mpsc::Receiver<UserAction>,
|
|
event_tx: mpsc::Sender<UIEvent>,
|
|
}
|
|
|
|
impl<P: ModelProvider> Orchestrator<P> {
|
|
/// Construct an orchestrator using the given provider and channel endpoints.
|
|
pub fn new(
|
|
provider: P,
|
|
action_rx: mpsc::Receiver<UserAction>,
|
|
event_tx: mpsc::Sender<UIEvent>,
|
|
) -> Self {
|
|
Self {
|
|
history: ConversationHistory::new(),
|
|
provider,
|
|
action_rx,
|
|
event_tx,
|
|
}
|
|
}
|
|
|
|
/// Run the orchestrator until the user quits or the `action_rx` channel closes.
|
|
pub async fn run(mut self) {
|
|
while let Some(action) = self.action_rx.recv().await {
|
|
match action {
|
|
UserAction::Quit => break,
|
|
UserAction::ClearHistory => {
|
|
self.history.clear();
|
|
}
|
|
|
|
UserAction::SendMessage(text) => {
|
|
// Push the user message before snapshotting, so providers
|
|
// see the full conversation including the new message.
|
|
self.history.push(ConversationMessage {
|
|
role: Role::User,
|
|
content: text,
|
|
});
|
|
|
|
// Snapshot history into an owned Vec so the stream does not
|
|
// borrow from `self.history` -- this lets us mutably update
|
|
// `self.history` once the stream loop finishes.
|
|
let messages: Vec<ConversationMessage> = self.history.messages().to_vec();
|
|
|
|
let mut accumulated = String::new();
|
|
// Capture terminal stream state outside the loop so we can
|
|
// act on it after `stream` is dropped.
|
|
let mut turn_done = false;
|
|
let mut turn_error: Option<String> = None;
|
|
|
|
{
|
|
let mut stream = Box::pin(self.provider.stream(&messages));
|
|
|
|
while let Some(event) = stream.next().await {
|
|
match event {
|
|
StreamEvent::TextDelta(chunk) => {
|
|
accumulated.push_str(&chunk);
|
|
let _ = self.event_tx.send(UIEvent::StreamDelta(chunk)).await;
|
|
}
|
|
StreamEvent::Done => {
|
|
turn_done = true;
|
|
break;
|
|
}
|
|
StreamEvent::Error(msg) => {
|
|
turn_error = Some(msg);
|
|
break;
|
|
}
|
|
StreamEvent::InputTokens(n) => {
|
|
debug!(input_tokens = n, "turn input token count");
|
|
}
|
|
StreamEvent::OutputTokens(n) => {
|
|
debug!(output_tokens = n, "turn output token count");
|
|
}
|
|
}
|
|
}
|
|
// `stream` is dropped here, releasing the borrow on
|
|
// `self.provider` and `messages`.
|
|
}
|
|
|
|
if turn_done {
|
|
self.history.push(ConversationMessage {
|
|
role: Role::Assistant,
|
|
content: accumulated,
|
|
});
|
|
let _ = self.event_tx.send(UIEvent::TurnComplete).await;
|
|
} else if let Some(msg) = turn_error {
|
|
let _ = self.event_tx.send(UIEvent::Error(msg)).await;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use futures::Stream;
    use std::collections::VecDeque;
    use std::sync::{Arc, Mutex};
    use std::time::Duration;
    use tokio::sync::mpsc;
    use tokio::task::JoinHandle;

    /// Upper bound on how long a test waits for the orchestrator to finish a
    /// turn; generous so that slow CI machines do not cause spurious failures.
    const TURN_TIMEOUT: Duration = Duration::from_secs(5);

    /// Shared log of the history slice passed to each `stream` call, in order.
    type CallLog = Arc<Mutex<Vec<Vec<ConversationMessage>>>>;

    /// A provider that replays one scripted list of [`StreamEvent`]s per call
    /// to `stream` and records the history it was given on every call.
    ///
    /// Recording the history lets tests observe exactly what the orchestrator
    /// would send to a real model, without making any network calls. Once the
    /// scripts run out, `stream` yields an empty stream.
    struct ScriptedProvider {
        turns: Mutex<VecDeque<Vec<StreamEvent>>>,
        seen: CallLog,
    }

    impl ScriptedProvider {
        /// Build a provider that answers call `i` with `turns[i]`, together
        /// with a handle to its call log (the provider itself is moved into
        /// the orchestrator).
        fn new(turns: Vec<Vec<StreamEvent>>) -> (Self, CallLog) {
            let seen = CallLog::default();
            let provider = Self {
                turns: Mutex::new(turns.into()),
                seen: Arc::clone(&seen),
            };
            (provider, seen)
        }
    }

    impl ModelProvider for ScriptedProvider {
        fn stream<'a>(
            &'a self,
            messages: &'a [ConversationMessage],
        ) -> impl Stream<Item = StreamEvent> + Send + 'a {
            self.seen.lock().unwrap().push(messages.to_vec());
            let events = self.turns.lock().unwrap().pop_front().unwrap_or_default();
            futures::stream::iter(events)
        }
    }

    /// Start an orchestrator around `provider` on a background task and return
    /// the TUI-side channel endpoints plus the task handle.
    fn spawn_orchestrator(
        provider: ScriptedProvider,
    ) -> (mpsc::Sender<UserAction>, mpsc::Receiver<UIEvent>, JoinHandle<()>) {
        let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
        let (event_tx, event_rx) = mpsc::channel::<UIEvent>(32);
        let handle = tokio::spawn(Orchestrator::new(provider, action_rx, event_tx).run());
        (action_tx, event_rx, handle)
    }

    /// Send `text` to the orchestrator as a new user message.
    async fn send_message(action_tx: &mpsc::Sender<UserAction>, text: &str) {
        action_tx
            .send(UserAction::SendMessage(text.to_string()))
            .await
            .expect("orchestrator stopped before receiving the message");
    }

    /// Wait for the UIEvents of one turn, up to and including the first
    /// `TurnComplete` or `Error`.
    ///
    /// Each event is awaited (bounded by [`TURN_TIMEOUT`]) rather than polled
    /// after a fixed sleep, so the result does not depend on scheduler timing.
    async fn collect_turn(event_rx: &mut mpsc::Receiver<UIEvent>) -> Vec<UIEvent> {
        let mut events = Vec::new();
        loop {
            let event = tokio::time::timeout(TURN_TIMEOUT, event_rx.recv())
                .await
                .expect("timed out waiting for the turn to finish")
                .expect("orchestrator closed the event channel mid-turn");
            let turn_over = matches!(event, UIEvent::TurnComplete | UIEvent::Error(_));
            events.push(event);
            if turn_over {
                return events;
            }
        }
    }

    /// Tell the orchestrator to quit and wait for its task to finish cleanly.
    async fn shut_down(action_tx: mpsc::Sender<UserAction>, handle: JoinHandle<()>) {
        action_tx
            .send(UserAction::Quit)
            .await
            .expect("orchestrator stopped before receiving Quit");
        handle.await.expect("orchestrator task panicked");
    }

    /// Render a history slice as `(role, content)` pairs for easy comparison.
    fn transcript(messages: &[ConversationMessage]) -> Vec<(&'static str, String)> {
        messages
            .iter()
            .map(|m| {
                let role = if matches!(m.role, Role::User) {
                    "user"
                } else if matches!(m.role, Role::Assistant) {
                    "assistant"
                } else {
                    "other"
                };
                (role, m.content.clone())
            })
            .collect()
    }

    /// Shorthand for one expected `transcript` entry.
    fn entry(role: &'static str, content: &str) -> (&'static str, String) {
        (role, content.to_string())
    }

    // -- happy-path turn ----------------------------------------------------------

    /// A full successful turn: text chunks followed by Done.
    ///
    /// The TUI receives two `StreamDelta`s and one `TurnComplete`. A follow-up
    /// turn then shows that the history holds the user message followed by the
    /// accumulated assistant reply.
    #[tokio::test]
    async fn happy_path_turn_produces_correct_ui_events_and_history() {
        let (provider, seen) = ScriptedProvider::new(vec![
            vec![
                StreamEvent::InputTokens(10),
                StreamEvent::TextDelta("Hello".to_string()),
                StreamEvent::TextDelta(", world!".to_string()),
                StreamEvent::OutputTokens(5),
                StreamEvent::Done,
            ],
            // Second turn exists only to observe the history.
            vec![StreamEvent::Done],
        ]);
        let (action_tx, mut event_rx, handle) = spawn_orchestrator(provider);

        send_message(&action_tx, "hi").await;
        let events = collect_turn(&mut event_rx).await;

        // Verify the UIEvent sequence.
        assert_eq!(events.len(), 3);
        assert!(matches!(&events[0], UIEvent::StreamDelta(s) if s == "Hello"));
        assert!(matches!(&events[1], UIEvent::StreamDelta(s) if s == ", world!"));
        assert!(matches!(events[2], UIEvent::TurnComplete));

        send_message(&action_tx, "again").await;
        let events = collect_turn(&mut event_rx).await;
        assert!(matches!(events.as_slice(), [UIEvent::TurnComplete]));

        shut_down(action_tx, handle).await;

        // Verify the history the provider saw on each turn.
        let seen = seen.lock().unwrap();
        assert_eq!(seen.len(), 2);
        assert_eq!(transcript(&seen[0]), vec![entry("user", "hi")]);
        assert_eq!(
            transcript(&seen[1]),
            vec![
                entry("user", "hi"),
                entry("assistant", "Hello, world!"),
                entry("user", "again"),
            ]
        );
    }

    // -- error path ---------------------------------------------------------------

    /// When the provider emits `Error`, the orchestrator forwards it to the TUI
    /// and does NOT append an assistant message (partial or otherwise) to history.
    #[tokio::test]
    async fn error_event_forwarded_to_tui_and_no_assistant_message_in_history() {
        let (provider, seen) = ScriptedProvider::new(vec![
            vec![
                StreamEvent::TextDelta("partial".to_string()),
                StreamEvent::Error("network timeout".to_string()),
            ],
            // Second turn exists only to observe the history.
            vec![StreamEvent::Done],
        ]);
        let (action_tx, mut event_rx, handle) = spawn_orchestrator(provider);

        send_message(&action_tx, "hello").await;
        let events = collect_turn(&mut event_rx).await;

        assert_eq!(events.len(), 2);
        assert!(matches!(&events[0], UIEvent::StreamDelta(s) if s == "partial"));
        assert!(matches!(&events[1], UIEvent::Error(msg) if msg == "network timeout"));

        send_message(&action_tx, "retry").await;
        let events = collect_turn(&mut event_rx).await;
        assert!(matches!(events.last(), Some(UIEvent::TurnComplete)));

        shut_down(action_tx, handle).await;

        let seen = seen.lock().unwrap();
        assert_eq!(seen.len(), 2);
        assert!(
            transcript(&seen[1])
                .iter()
                .all(|(role, _)| *role != "assistant"),
            "a failed turn must not leave an assistant message in history"
        );
    }

    // -- quit ---------------------------------------------------------------------

    /// Sending `Quit` immediately terminates the orchestrator loop without ever
    /// calling the provider or emitting a UIEvent.
    #[tokio::test]
    async fn quit_terminates_run() {
        let (provider, seen) = ScriptedProvider::new(Vec::new());
        let (action_tx, mut event_rx, handle) = spawn_orchestrator(provider);

        shut_down(action_tx, handle).await;

        assert!(
            seen.lock().unwrap().is_empty(),
            "stream() must not be called after Quit"
        );
        // The orchestrator has exited and dropped its sender without sending anything.
        assert!(event_rx.recv().await.is_none());
    }

    // -- multi-turn history accumulation ------------------------------------------

    /// Two sequential SendMessage turns each append a user message and the
    /// accumulated assistant response, leaving four messages in history order.
    ///
    /// A third "probe" turn exposes that history. This validates that history
    /// is passed to the provider on every turn and that delta accumulation
    /// resets between turns (each reply is "reply", not "replyreply").
    #[tokio::test]
    async fn two_turns_accumulate_history_correctly() {
        let reply = || {
            vec![
                StreamEvent::TextDelta("reply".to_string()),
                StreamEvent::Done,
            ]
        };
        let (provider, seen) =
            ScriptedProvider::new(vec![reply(), reply(), vec![StreamEvent::Done]]);
        let (action_tx, mut event_rx, handle) = spawn_orchestrator(provider);

        for text in ["turn one", "turn two", "probe"] {
            send_message(&action_tx, text).await;
            let events = collect_turn(&mut event_rx).await;
            assert!(matches!(events.last(), Some(UIEvent::TurnComplete)));
        }

        shut_down(action_tx, handle).await;

        let seen = seen.lock().unwrap();
        assert_eq!(seen.len(), 3);
        assert_eq!(transcript(&seen[0]), vec![entry("user", "turn one")]);
        assert_eq!(
            transcript(&seen[1]),
            vec![
                entry("user", "turn one"),
                entry("assistant", "reply"),
                entry("user", "turn two"),
            ]
        );
        assert_eq!(
            transcript(&seen[2]),
            vec![
                entry("user", "turn one"),
                entry("assistant", "reply"),
                entry("user", "turn two"),
                entry("assistant", "reply"),
                entry("user", "probe"),
            ]
        );
    }
}
|