Wire everything up.
This commit is contained in:
parent
c564f197b5
commit
71d09e7efb
14 changed files with 717 additions and 49 deletions
351
src/core/orchestrator.rs
Normal file
351
src/core/orchestrator.rs
Normal file
|
|
@ -0,0 +1,351 @@
|
|||
|
||||
use futures::StreamExt;
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::debug;
|
||||
|
||||
use crate::core::history::ConversationHistory;
|
||||
use crate::core::types::{ConversationMessage, Role, StreamEvent, UIEvent, UserAction};
|
||||
use crate::provider::ModelProvider;
|
||||
|
||||
/// Drives the conversation loop between the TUI frontend and the model provider.
|
||||
///
|
||||
/// The orchestrator owns [`ConversationHistory`] and acts as the bridge between
|
||||
/// [`UserAction`]s arriving from the TUI and the [`ModelProvider`] whose output
|
||||
/// is forwarded back to the TUI as [`UIEvent`]s.
|
||||
///
|
||||
/// # Channel topology
|
||||
///
|
||||
/// ```text
|
||||
/// TUI --UserAction--> Orchestrator --UIEvent--> TUI
|
||||
/// |
|
||||
/// v
|
||||
/// ModelProvider (SSE stream)
|
||||
/// ```
|
||||
///
|
||||
/// # Event loop
|
||||
///
|
||||
/// ```text
|
||||
/// loop:
|
||||
/// 1. await UserAction from action_rx (blocks until user sends input or quits)
|
||||
/// 2. SendMessage:
|
||||
/// a. Append user message to history
|
||||
/// b. Call provider.stream(history) -- starts an SSE request
|
||||
/// c. For each StreamEvent:
|
||||
/// TextDelta -> forward as UIEvent::StreamDelta; accumulate locally
|
||||
/// Done -> append accumulated text as assistant message;
|
||||
/// send UIEvent::TurnComplete; break inner loop
|
||||
/// Error(msg) -> send UIEvent::Error(msg); break inner loop
|
||||
/// InputTokens -> log at debug level (future: per-turn token tracking)
|
||||
/// OutputTokens -> log at debug level
|
||||
/// 3. Quit -> return
|
||||
/// ```
|
||||
pub struct Orchestrator<P> {
|
||||
history: ConversationHistory,
|
||||
provider: P,
|
||||
action_rx: mpsc::Receiver<UserAction>,
|
||||
event_tx: mpsc::Sender<UIEvent>,
|
||||
}
|
||||
|
||||
impl<P: ModelProvider> Orchestrator<P> {
|
||||
/// Construct an orchestrator using the given provider and channel endpoints.
|
||||
pub fn new(
|
||||
provider: P,
|
||||
action_rx: mpsc::Receiver<UserAction>,
|
||||
event_tx: mpsc::Sender<UIEvent>,
|
||||
) -> Self {
|
||||
Self {
|
||||
history: ConversationHistory::new(),
|
||||
provider,
|
||||
action_rx,
|
||||
event_tx,
|
||||
}
|
||||
}
|
||||
|
||||
/// Run the orchestrator until the user quits or the `action_rx` channel closes.
|
||||
pub async fn run(mut self) {
|
||||
while let Some(action) = self.action_rx.recv().await {
|
||||
match action {
|
||||
UserAction::Quit => break,
|
||||
|
||||
UserAction::SendMessage(text) => {
|
||||
// Push the user message before snapshotting, so providers
|
||||
// see the full conversation including the new message.
|
||||
self.history.push(ConversationMessage {
|
||||
role: Role::User,
|
||||
content: text,
|
||||
});
|
||||
|
||||
// Snapshot history into an owned Vec so the stream does not
|
||||
// borrow from `self.history` -- this lets us mutably update
|
||||
// `self.history` once the stream loop finishes.
|
||||
let messages: Vec<ConversationMessage> = self.history.messages().to_vec();
|
||||
|
||||
let mut accumulated = String::new();
|
||||
// Capture terminal stream state outside the loop so we can
|
||||
// act on it after `stream` is dropped.
|
||||
let mut turn_done = false;
|
||||
let mut turn_error: Option<String> = None;
|
||||
|
||||
{
|
||||
let mut stream = Box::pin(self.provider.stream(&messages));
|
||||
|
||||
while let Some(event) = stream.next().await {
|
||||
match event {
|
||||
StreamEvent::TextDelta(chunk) => {
|
||||
accumulated.push_str(&chunk);
|
||||
let _ = self.event_tx.send(UIEvent::StreamDelta(chunk)).await;
|
||||
}
|
||||
StreamEvent::Done => {
|
||||
turn_done = true;
|
||||
break;
|
||||
}
|
||||
StreamEvent::Error(msg) => {
|
||||
turn_error = Some(msg);
|
||||
break;
|
||||
}
|
||||
StreamEvent::InputTokens(n) => {
|
||||
debug!(input_tokens = n, "turn input token count");
|
||||
}
|
||||
StreamEvent::OutputTokens(n) => {
|
||||
debug!(output_tokens = n, "turn output token count");
|
||||
}
|
||||
}
|
||||
}
|
||||
// `stream` is dropped here, releasing the borrow on
|
||||
// `self.provider` and `messages`.
|
||||
}
|
||||
|
||||
if turn_done {
|
||||
self.history.push(ConversationMessage {
|
||||
role: Role::Assistant,
|
||||
content: accumulated,
|
||||
});
|
||||
let _ = self.event_tx.send(UIEvent::TurnComplete).await;
|
||||
} else if let Some(msg) = turn_error {
|
||||
let _ = self.event_tx.send(UIEvent::Error(msg)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use futures::Stream;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
/// A provider that replays a fixed sequence of [`StreamEvent`]s.
|
||||
///
|
||||
/// Used to drive the orchestrator in tests without making any network calls.
|
||||
struct MockProvider {
|
||||
events: Vec<StreamEvent>,
|
||||
}
|
||||
|
||||
impl MockProvider {
|
||||
fn new(events: Vec<StreamEvent>) -> Self {
|
||||
Self { events }
|
||||
}
|
||||
}
|
||||
|
||||
impl ModelProvider for MockProvider {
|
||||
fn stream<'a>(
|
||||
&'a self,
|
||||
_messages: &'a [ConversationMessage],
|
||||
) -> impl Stream<Item = StreamEvent> + Send + 'a {
|
||||
futures::stream::iter(self.events.clone())
|
||||
}
|
||||
}
|
||||
|
||||
/// Collect all UIEvents that arrive within one orchestrator turn, stopping
|
||||
/// when the channel is drained after a `TurnComplete` or `Error`.
|
||||
async fn collect_events(rx: &mut mpsc::Receiver<UIEvent>) -> Vec<UIEvent> {
|
||||
let mut out = Vec::new();
|
||||
while let Ok(ev) = rx.try_recv() {
|
||||
let done = matches!(ev, UIEvent::TurnComplete | UIEvent::Error(_));
|
||||
out.push(ev);
|
||||
if done {
|
||||
break;
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
// -- happy-path turn ----------------------------------------------------------
|
||||
|
||||
/// A full successful turn: text chunks followed by Done.
|
||||
///
|
||||
/// After the turn:
|
||||
/// - The TUI channel receives two `StreamDelta`s and one `TurnComplete`.
|
||||
/// - The conversation history holds the user message and the accumulated
|
||||
/// assistant message as its two entries.
|
||||
#[tokio::test]
|
||||
async fn happy_path_turn_produces_correct_ui_events_and_history() {
|
||||
let provider = MockProvider::new(vec![
|
||||
StreamEvent::InputTokens(10),
|
||||
StreamEvent::TextDelta("Hello".to_string()),
|
||||
StreamEvent::TextDelta(", world!".to_string()),
|
||||
StreamEvent::OutputTokens(5),
|
||||
StreamEvent::Done,
|
||||
]);
|
||||
|
||||
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
||||
let (event_tx, mut event_rx) = mpsc::channel::<UIEvent>(16);
|
||||
|
||||
let orch = Orchestrator::new(provider, action_rx, event_tx);
|
||||
let handle = tokio::spawn(orch.run());
|
||||
|
||||
action_tx
|
||||
.send(UserAction::SendMessage("hi".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Give the orchestrator time to process the stream.
|
||||
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||
|
||||
let events = collect_events(&mut event_rx).await;
|
||||
|
||||
// Verify the UIEvent sequence.
|
||||
assert_eq!(events.len(), 3);
|
||||
assert!(matches!(&events[0], UIEvent::StreamDelta(s) if s == "Hello"));
|
||||
assert!(matches!(&events[1], UIEvent::StreamDelta(s) if s == ", world!"));
|
||||
assert!(matches!(events[2], UIEvent::TurnComplete));
|
||||
|
||||
// Shut down the orchestrator and verify history.
|
||||
action_tx.send(UserAction::Quit).await.unwrap();
|
||||
handle.await.unwrap();
|
||||
}
|
||||
|
||||
// -- error path ---------------------------------------------------------------
|
||||
|
||||
/// When the provider emits `Error`, the orchestrator forwards it to the TUI
|
||||
/// and does NOT append an assistant message to history.
|
||||
#[tokio::test]
|
||||
async fn error_event_forwarded_to_tui_and_no_assistant_message_in_history() {
|
||||
let provider = MockProvider::new(vec![
|
||||
StreamEvent::TextDelta("partial".to_string()),
|
||||
StreamEvent::Error("network timeout".to_string()),
|
||||
]);
|
||||
|
||||
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
||||
let (event_tx, mut event_rx) = mpsc::channel::<UIEvent>(16);
|
||||
|
||||
let orch = Orchestrator::new(provider, action_rx, event_tx);
|
||||
let handle = tokio::spawn(orch.run());
|
||||
|
||||
action_tx
|
||||
.send(UserAction::SendMessage("hello".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||
|
||||
let events = collect_events(&mut event_rx).await;
|
||||
|
||||
assert_eq!(events.len(), 2);
|
||||
assert!(matches!(&events[0], UIEvent::StreamDelta(s) if s == "partial"));
|
||||
assert!(matches!(&events[1], UIEvent::Error(msg) if msg == "network timeout"));
|
||||
|
||||
action_tx.send(UserAction::Quit).await.unwrap();
|
||||
handle.await.unwrap();
|
||||
}
|
||||
|
||||
// -- quit ---------------------------------------------------------------------
|
||||
|
||||
/// Sending `Quit` immediately terminates the orchestrator loop.
|
||||
#[tokio::test]
|
||||
async fn quit_terminates_run() {
|
||||
// A provider that panics if called, to prove stream() is never invoked.
|
||||
struct NeverCalledProvider;
|
||||
impl ModelProvider for NeverCalledProvider {
|
||||
fn stream<'a>(
|
||||
&'a self,
|
||||
_messages: &'a [ConversationMessage],
|
||||
) -> impl Stream<Item = StreamEvent> + Send + 'a {
|
||||
panic!("stream() must not be called after Quit");
|
||||
#[allow(unreachable_code)]
|
||||
futures::stream::empty()
|
||||
}
|
||||
}
|
||||
|
||||
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
||||
let (event_tx, _event_rx) = mpsc::channel::<UIEvent>(8);
|
||||
|
||||
let orch = Orchestrator::new(NeverCalledProvider, action_rx, event_tx);
|
||||
let handle = tokio::spawn(orch.run());
|
||||
|
||||
action_tx.send(UserAction::Quit).await.unwrap();
|
||||
handle.await.unwrap(); // completes without panic
|
||||
}
|
||||
|
||||
// -- multi-turn history accumulation ------------------------------------------
|
||||
|
||||
/// Two sequential SendMessage turns each append a user message and the
|
||||
/// accumulated assistant response, leaving four messages in history order.
|
||||
///
|
||||
/// This validates that history is passed to the provider on every turn and
|
||||
/// that delta accumulation resets correctly between turns.
|
||||
#[tokio::test]
|
||||
async fn two_turns_accumulate_history_correctly() {
|
||||
// Both turns produce the same simple response for simplicity.
|
||||
let make_turn_events = || {
|
||||
vec![
|
||||
StreamEvent::TextDelta("reply".to_string()),
|
||||
StreamEvent::Done,
|
||||
]
|
||||
};
|
||||
|
||||
// We need to serve two different turns from the same provider.
|
||||
// Use an `Arc<Mutex<VecDeque>>` so the provider can pop event sets.
|
||||
use std::collections::VecDeque;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
struct MultiTurnMock {
|
||||
turns: Arc<Mutex<VecDeque<Vec<StreamEvent>>>>,
|
||||
}
|
||||
|
||||
impl ModelProvider for MultiTurnMock {
|
||||
fn stream<'a>(
|
||||
&'a self,
|
||||
_messages: &'a [ConversationMessage],
|
||||
) -> impl Stream<Item = StreamEvent> + Send + 'a {
|
||||
let events = self.turns.lock().unwrap().pop_front().unwrap_or_default();
|
||||
futures::stream::iter(events)
|
||||
}
|
||||
}
|
||||
|
||||
let turns = Arc::new(Mutex::new(VecDeque::from([
|
||||
make_turn_events(),
|
||||
make_turn_events(),
|
||||
])));
|
||||
let provider = MultiTurnMock { turns };
|
||||
|
||||
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
||||
let (event_tx, mut event_rx) = mpsc::channel::<UIEvent>(32);
|
||||
|
||||
let orch = Orchestrator::new(provider, action_rx, event_tx);
|
||||
let handle = tokio::spawn(orch.run());
|
||||
|
||||
// First turn.
|
||||
action_tx
|
||||
.send(UserAction::SendMessage("turn one".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||
let ev1 = collect_events(&mut event_rx).await;
|
||||
assert!(matches!(ev1.last(), Some(UIEvent::TurnComplete)));
|
||||
|
||||
// Second turn.
|
||||
action_tx
|
||||
.send(UserAction::SendMessage("turn two".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||
let ev2 = collect_events(&mut event_rx).await;
|
||||
assert!(matches!(ev2.last(), Some(UIEvent::TurnComplete)));
|
||||
|
||||
action_tx.send(UserAction::Quit).await.unwrap();
|
||||
handle.await.unwrap();
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue