Skeleton for the Coding Agent. #1
15 changed files with 5071 additions and 12 deletions
37
.forgejo/workflows/ci.yml
Normal file
37
.forgejo/workflows/ci.yml
Normal file
|
|
@ -0,0 +1,37 @@
|
||||||
|
name: CI
|
||||||
|
|
||||||
|
on:
|
||||||
|
# TODO: re-enable once runners are fixed
|
||||||
|
push:
|
||||||
|
branches: [does_not_exist] # [ main ]
|
||||||
|
# pull_request:
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
ci:
|
||||||
|
runs-on: docker
|
||||||
|
container:
|
||||||
|
image: rust:latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Install rustfmt and clippy
|
||||||
|
run: rustup component add rustfmt clippy
|
||||||
|
|
||||||
|
- name: Cache cargo registry
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
~/.cargo/registry
|
||||||
|
~/.cargo/git
|
||||||
|
target
|
||||||
|
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
|
||||||
|
restore-keys: ${{ runner.os }}-cargo-
|
||||||
|
|
||||||
|
- name: Build
|
||||||
|
run: cargo build
|
||||||
|
|
||||||
|
- name: Format
|
||||||
|
run: cargo fmt --check
|
||||||
|
|
||||||
|
- name: Clippy
|
||||||
|
run: cargo clippy -- -D warnings
|
||||||
42
CLAUDE.md
42
CLAUDE.md
|
|
@ -16,28 +16,46 @@ Rust TUI coding agent. Ratatui + Crossterm + Tokio. See DESIGN.md for architectu
|
||||||
|
|
||||||
Six modules with strict boundaries:
|
Six modules with strict boundaries:
|
||||||
|
|
||||||
- `src/app/` — Wiring, lifecycle, tokio runtime setup
|
- `src/app/` -- Wiring, lifecycle, tokio runtime setup
|
||||||
- `src/tui/` — Ratatui rendering, input handling, vim modes. Communicates with core ONLY via channels (`UserAction` → core, `UIEvent` ← core). Never touches conversation state directly.
|
- `src/tui/` -- Ratatui rendering, input handling, vim modes. Communicates with core ONLY via channels (`UserAction` -> core, `UIEvent` <- core). Never touches conversation state directly.
|
||||||
- `src/core/` — Conversation tree, orchestrator loop, sub-agent lifecycle
|
- `src/core/` -- Conversation tree, orchestrator loop, sub-agent lifecycle
|
||||||
- `src/provider/` — `ModelProvider` trait + Claude implementation. Leaf module, no internal dependencies.
|
- `src/provider/` -- `ModelProvider` trait + Claude implementation. Leaf module, no internal dependencies.
|
||||||
- `src/tools/` — `Tool` trait, registry, built-in tools. Depends only on `sandbox`.
|
- `src/tools/` -- `Tool` trait, registry, built-in tools. Depends only on `sandbox`.
|
||||||
- `src/sandbox/` — Landlock policy, path validation, command execution. Leaf module.
|
- `src/sandbox/` -- Landlock policy, path validation, command execution. Leaf module.
|
||||||
- `src/session/` — JSONL logging, session read/write. Leaf module.
|
- `src/session/` -- JSONL logging, session read/write. Leaf module.
|
||||||
|
|
||||||
The channel boundary between `tui` and `core` is critical — never bypass it. The TUI is a frontend; core is the engine. This separation enables headless mode for benchmarking.
|
The channel boundary between `tui` and `core` is critical -- never bypass it. The TUI is a frontend; core is the engine. This separation enables headless mode for benchmarking.
|
||||||
|
|
||||||
## Code Style
|
## Code Style
|
||||||
|
|
||||||
- Use `thiserror` for error types, not `anyhow` in library code (`anyhow` only in `main.rs`/`app`)
|
- Use `thiserror` for error types, not `anyhow` in library code (`anyhow` only in `main.rs`/`app`)
|
||||||
- Prefer `impl Trait` return types over boxing when possible
|
- Prefer `impl Trait` return types over boxing when possible
|
||||||
- All public types need doc comments
|
- All public types need doc comments
|
||||||
- No `unwrap()` in non-test code — use `?` or explicit error handling
|
- No `unwrap()` in non-test code -- use `?` or explicit error handling
|
||||||
- Async functions should be cancel-safe where possible
|
- Async functions should be cancel-safe where possible
|
||||||
- Use `tracing` for structured logging, not `println!` or `log`
|
- Use `tracing` for structured logging, not `println!` or `log`
|
||||||
|
|
||||||
|
## Documentation
|
||||||
|
|
||||||
|
Prefer a literate style: doc comments should explain *why* and *how*, not just restate the signature.
|
||||||
|
|
||||||
|
Use only characters available on a standard US QWERTY keyboard in all doc comments and inline comments. Specifically:
|
||||||
|
- Use `->` and `<-` instead of Unicode arrow glyphs
|
||||||
|
- Use `--` instead of em dashes or en dashes
|
||||||
|
- Use `+`, `-`, `|` for ASCII box diagrams instead of Unicode box-drawing characters
|
||||||
|
- Use `...` instead of the ellipsis character
|
||||||
|
- Spell out "Section N.N" instead of the section-sign glyph
|
||||||
|
|
||||||
|
When a function or type implements an external protocol or spec:
|
||||||
|
- Document the relevant portion of the protocol inline (packet shapes, event sequences, state machines)
|
||||||
|
- Link to the authoritative external source -- API reference, RFC, WHATWG spec, etc.
|
||||||
|
- Include a mapping table or lifecycle diagram when there are multiple cases to distinguish
|
||||||
|
|
||||||
|
For example, `run_stream` in `src/provider/claude.rs` documents the full SSE event sequence in a text diagram and links to both the Anthropic streaming reference and the WHATWG SSE spec. Aim for that level of context in any code that speaks a wire format or external API.
|
||||||
|
|
||||||
## Conversation Data Model
|
## Conversation Data Model
|
||||||
|
|
||||||
Events use parent IDs forming a tree (not a flat list). This enables future branching. Every event has: id, parent_id, timestamp, event_type, token_usage. A "turn" is all events between two user messages — this is the unit for token tracking.
|
Events use parent IDs forming a tree (not a flat list). This enables future branching. Every event has: id, parent_id, timestamp, event_type, token_usage. A "turn" is all events between two user messages -- this is the unit for token tracking.
|
||||||
|
|
||||||
## Testing
|
## Testing
|
||||||
|
|
||||||
|
|
@ -50,8 +68,8 @@ Events use parent IDs forming a tree (not a flat list). This enables future bran
|
||||||
|
|
||||||
## Key Constraints
|
## Key Constraints
|
||||||
|
|
||||||
- All file I/O and process spawning in tools MUST go through `Sandbox` — never use `std::fs` or `std::process::Command` directly in tool implementations
|
- All file I/O and process spawning in tools MUST go through `Sandbox` -- never use `std::fs` or `std::process::Command` directly in tool implementations
|
||||||
- The `ModelProvider` trait must remain provider-agnostic — no Claude-specific types in the trait interface
|
- The `ModelProvider` trait must remain provider-agnostic -- no Claude-specific types in the trait interface
|
||||||
- Session JSONL is append-only. Never rewrite history. Branching works by writing new events with different parent IDs.
|
- Session JSONL is append-only. Never rewrite history. Branching works by writing new events with different parent IDs.
|
||||||
- Token usage must be tracked per-event and aggregatable per-turn
|
- Token usage must be tracked per-event and aggregatable per-turn
|
||||||
|
|
||||||
|
|
|
||||||
3339
Cargo.lock
generated
Normal file
3339
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load diff
17
Cargo.toml
Normal file
17
Cargo.toml
Normal file
|
|
@ -0,0 +1,17 @@
|
||||||
|
[package]
|
||||||
|
name = "skate"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2024"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
anyhow = "1"
|
||||||
|
ratatui = "0.30"
|
||||||
|
crossterm = { version = "0.29", features = ["event-stream"] }
|
||||||
|
tokio = { version = "1", features = ["full"] }
|
||||||
|
serde = { version = "1", features = ["derive"] }
|
||||||
|
serde_json = "1"
|
||||||
|
thiserror = "2"
|
||||||
|
tracing = "0.1"
|
||||||
|
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||||
|
reqwest = { version = "0.13", features = ["stream", "json"] }
|
||||||
|
futures = "0.3"
|
||||||
4
TODO.md
Normal file
4
TODO.md
Normal file
|
|
@ -0,0 +1,4 @@
|
||||||
|
# Cleanups
|
||||||
|
|
||||||
|
- Move keyboard/event reads in the TUI to a separate thread or async/io loop
|
||||||
|
- Keep UI and orchestrator in sync (i.e. messages display out of order if you queue up many.)
|
||||||
85
src/app/mod.rs
Normal file
85
src/app/mod.rs
Normal file
|
|
@ -0,0 +1,85 @@
|
||||||
|
//! Application wiring: tracing initialisation, channel setup, and task
|
||||||
|
//! orchestration.
|
||||||
|
//!
|
||||||
|
//! This module is the only place that knows about all subsystems simultaneously.
|
||||||
|
//! It creates the two channels that connect the TUI to the core orchestrator,
|
||||||
|
//! spawns the orchestrator as a background tokio task, and then hands control to
|
||||||
|
//! the TUI event loop on the calling task.
|
||||||
|
//!
|
||||||
|
//! # Shutdown sequence
|
||||||
|
//!
|
||||||
|
//! ```text
|
||||||
|
//! User presses Ctrl-C / Ctrl-D
|
||||||
|
//! -> tui::run sends UserAction::Quit, breaks loop, drops action_tx
|
||||||
|
//! -> restore_terminal(), tui::run returns Ok(())
|
||||||
|
//! -> app::run returns Ok(())
|
||||||
|
//! -> tokio runtime drops the spawned orchestrator task
|
||||||
|
//! (action_rx channel closed -> orchestrator recv() returns None -> run() returns)
|
||||||
|
//! ```
|
||||||
|
|
||||||
|
mod workspace;
|
||||||
|
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
use anyhow::Context;
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
|
||||||
|
use crate::core::orchestrator::Orchestrator;
|
||||||
|
use crate::core::types::{UIEvent, UserAction};
|
||||||
|
use crate::provider::ClaudeProvider;
|
||||||
|
|
||||||
|
/// Model ID sent on every request.
|
||||||
|
///
|
||||||
|
/// See the [models overview] for the full list of available model IDs.
|
||||||
|
///
|
||||||
|
/// [models overview]: https://docs.anthropic.com/en/docs/about-claude/models/overview
|
||||||
|
const MODEL: &str = "claude-haiku-4-5";
|
||||||
|
|
||||||
|
/// Buffer capacity for the `UserAction` and `UIEvent` channels.
|
||||||
|
///
|
||||||
|
/// 64 is large enough to absorb bursts of streaming deltas without blocking the
|
||||||
|
/// orchestrator, while staying well under any memory pressure.
|
||||||
|
const CHANNEL_CAP: usize = 64;
|
||||||
|
|
||||||
|
/// Initialise tracing, wire subsystems, and run until the user quits.
|
||||||
|
///
|
||||||
|
/// Steps:
|
||||||
|
/// 1. Open (or create) the `workspace::SkateDir` and install the tracing
|
||||||
|
/// subscriber. All structured log output goes to `.skate/skate.log` --
|
||||||
|
/// writing to stdout would corrupt the TUI.
|
||||||
|
/// 2. Construct a [`ClaudeProvider`], failing fast if `ANTHROPIC_API_KEY` is
|
||||||
|
/// absent.
|
||||||
|
/// 3. Create the `UserAction` (TUI -> core) and `UIEvent` (core -> TUI) channel
|
||||||
|
/// pair.
|
||||||
|
/// 4. Spawn the [`Orchestrator`] event loop on a tokio worker task.
|
||||||
|
/// 5. Run the TUI event loop on the calling task (crossterm must not be used
|
||||||
|
/// from multiple threads concurrently).
|
||||||
|
pub async fn run(project_dir: &Path) -> anyhow::Result<()> {
|
||||||
|
// -- Tracing ------------------------------------------------------------------
|
||||||
|
workspace::SkateDir::open(project_dir)?.init_tracing()?;
|
||||||
|
|
||||||
|
tracing::info!(project_dir = %project_dir.display(), "skate starting");
|
||||||
|
|
||||||
|
// -- Provider -----------------------------------------------------------------
|
||||||
|
let provider = ClaudeProvider::from_env(MODEL)
|
||||||
|
.context("failed to construct Claude provider (is ANTHROPIC_API_KEY set?)")?;
|
||||||
|
|
||||||
|
// -- Channels -----------------------------------------------------------------
|
||||||
|
let (action_tx, action_rx) = mpsc::channel::<UserAction>(CHANNEL_CAP);
|
||||||
|
let (event_tx, event_rx) = mpsc::channel::<UIEvent>(CHANNEL_CAP);
|
||||||
|
|
||||||
|
// -- Orchestrator (background task) -------------------------------------------
|
||||||
|
let orch = Orchestrator::new(provider, action_rx, event_tx);
|
||||||
|
tokio::spawn(orch.run());
|
||||||
|
|
||||||
|
// -- TUI (foreground task) ----------------------------------------------------
|
||||||
|
// `action_tx` is moved into tui::run; when it returns (user quit), the
|
||||||
|
// sender is dropped, which closes the channel and causes the orchestrator's
|
||||||
|
// recv() loop to exit.
|
||||||
|
crate::tui::run(action_tx, event_rx)
|
||||||
|
.await
|
||||||
|
.context("TUI error")?;
|
||||||
|
|
||||||
|
tracing::info!("skate exiting cleanly");
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
103
src/app/workspace.rs
Normal file
103
src/app/workspace.rs
Normal file
|
|
@ -0,0 +1,103 @@
|
||||||
|
//! `.skate/` runtime directory management.
|
||||||
|
//!
|
||||||
|
//! The `.skate/` directory lives inside the user's project and holds all
|
||||||
|
//! runtime artefacts produced by a skate session: structured logs, session
|
||||||
|
//! indices, and (in future) per-run snapshots. None of these should ever be
|
||||||
|
//! committed to the project's VCS, so the first time the directory is created
|
||||||
|
//! we drop a `.gitignore` containing `*` -- ignoring everything, including the
|
||||||
|
//! `.gitignore` itself.
|
||||||
|
//!
|
||||||
|
//! # Lifecycle
|
||||||
|
//!
|
||||||
|
//! ```text
|
||||||
|
//! app::run
|
||||||
|
//! -> SkateDir::open(project_dir) -- creates dir + .gitignore if needed
|
||||||
|
//! -> skate_dir.init_tracing() -- opens skate.log, installs subscriber
|
||||||
|
//! ```
|
||||||
|
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
use std::sync::Mutex;
|
||||||
|
|
||||||
|
use anyhow::Context;
|
||||||
|
use tracing_subscriber::EnvFilter;
|
||||||
|
|
||||||
|
/// The `.skate/` runtime directory inside a project.
|
||||||
|
///
|
||||||
|
/// Created on first use; subsequent calls are no-ops. All knowledge of
|
||||||
|
/// well-known child paths stays inside this module so callers never
|
||||||
|
/// construct them by hand.
|
||||||
|
pub struct SkateDir {
|
||||||
|
path: PathBuf,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SkateDir {
|
||||||
|
/// Open (or create) the `.skate/` directory inside `project_dir`.
|
||||||
|
///
|
||||||
|
/// On first call this also writes a `.gitignore` containing `*` so that
|
||||||
|
/// none of the runtime files are accidentally committed. Concretely:
|
||||||
|
///
|
||||||
|
/// 1. `create_dir_all` -- idempotent, works whether the dir already exists
|
||||||
|
/// or is being created for the first time.
|
||||||
|
/// 2. `OpenOptions::create_new` on `.gitignore` -- atomic write-once; the
|
||||||
|
/// `AlreadyExists` error is silently ignored so repeated calls are safe.
|
||||||
|
///
|
||||||
|
/// Returns `Err` on any I/O failure other than `AlreadyExists`.
|
||||||
|
pub fn open(project_dir: &Path) -> anyhow::Result<Self> {
|
||||||
|
let path = project_dir.join(".skate");
|
||||||
|
|
||||||
|
std::fs::create_dir_all(&path)
|
||||||
|
.with_context(|| format!("cannot create .skate directory {}", path.display()))?;
|
||||||
|
|
||||||
|
// Write .gitignore on first creation; no-op if it already exists.
|
||||||
|
// Content is "*": ignore everything in this directory including the
|
||||||
|
// .gitignore itself -- none of the skate runtime files should be committed.
|
||||||
|
let gitignore_path = path.join(".gitignore");
|
||||||
|
match std::fs::OpenOptions::new()
|
||||||
|
.write(true)
|
||||||
|
.create_new(true)
|
||||||
|
.open(&gitignore_path)
|
||||||
|
{
|
||||||
|
Ok(mut f) => {
|
||||||
|
use std::io::Write;
|
||||||
|
f.write_all(b"*\n")
|
||||||
|
.with_context(|| format!("cannot write {}", gitignore_path.display()))?;
|
||||||
|
}
|
||||||
|
Err(e) if e.kind() == std::io::ErrorKind::AlreadyExists => {}
|
||||||
|
Err(e) => {
|
||||||
|
return Err(e)
|
||||||
|
.with_context(|| format!("cannot create {}", gitignore_path.display()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Self { path })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Install the global `tracing` subscriber, writing to `skate.log`.
|
||||||
|
///
|
||||||
|
/// Opens (or creates) `skate.log` in append mode, then registers a
|
||||||
|
/// `tracing_subscriber::fmt` subscriber that writes structured JSON-ish
|
||||||
|
/// text to that file. Writing to stdout is not possible because the TUI
|
||||||
|
/// owns the terminal.
|
||||||
|
///
|
||||||
|
/// RUST_LOG controls verbosity; falls back to `info` if absent or
|
||||||
|
/// unparseable. Must be called at most once per process -- the underlying
|
||||||
|
/// `tracing` registry panics on a second `init()` call.
|
||||||
|
pub fn init_tracing(&self) -> anyhow::Result<()> {
|
||||||
|
let log_path = self.path.join("skate.log");
|
||||||
|
let log_file = std::fs::OpenOptions::new()
|
||||||
|
.create(true)
|
||||||
|
.append(true)
|
||||||
|
.open(&log_path)
|
||||||
|
.with_context(|| format!("cannot open log file {}", log_path.display()))?;
|
||||||
|
|
||||||
|
let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
|
||||||
|
|
||||||
|
tracing_subscriber::fmt()
|
||||||
|
.with_writer(Mutex::new(log_file))
|
||||||
|
.with_ansi(false)
|
||||||
|
.with_env_filter(filter)
|
||||||
|
.init();
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
89
src/core/history.rs
Normal file
89
src/core/history.rs
Normal file
|
|
@ -0,0 +1,89 @@
|
||||||
|
use crate::core::types::ConversationMessage;
|
||||||
|
|
||||||
|
/// The in-memory conversation history for the current session.
|
||||||
|
///
|
||||||
|
/// Stores messages as a flat ordered list. Each [`push`][`Self::push`] appends
|
||||||
|
/// one message; [`messages`][`Self::messages`] returns a slice over all of them.
|
||||||
|
///
|
||||||
|
/// This is a flat list for Phase 1. Phases 3+ will introduce a tree structure
|
||||||
|
/// (each event carrying a `parent_id`) to support conversation branching and
|
||||||
|
/// sub-agent threads. The flat model is upward-compatible: a tree is just a
|
||||||
|
/// linear chain of parent IDs when there is no branching.
|
||||||
|
pub struct ConversationHistory {
|
||||||
|
messages: Vec<ConversationMessage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ConversationHistory {
|
||||||
|
/// Create an empty history.
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
messages: Vec::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Append one message to the end of the history.
|
||||||
|
pub fn push(&mut self, message: ConversationMessage) {
|
||||||
|
self.messages.push(message);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return the full ordered message list, oldest-first.
|
||||||
|
///
|
||||||
|
/// This slice is what gets serialised and sent to the provider on each
|
||||||
|
/// turn -- the provider needs the full prior context to generate a coherent
|
||||||
|
/// continuation.
|
||||||
|
pub fn messages(&self) -> &[ConversationMessage] {
|
||||||
|
&self.messages
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for ConversationHistory {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::core::types::Role;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn new_history_is_empty() {
|
||||||
|
let history = ConversationHistory::new();
|
||||||
|
assert!(history.messages().is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn push_and_read_roundtrip() {
|
||||||
|
let mut history = ConversationHistory::new();
|
||||||
|
history.push(ConversationMessage {
|
||||||
|
role: Role::User,
|
||||||
|
content: "hello".to_string(),
|
||||||
|
});
|
||||||
|
history.push(ConversationMessage {
|
||||||
|
role: Role::Assistant,
|
||||||
|
content: "hi there".to_string(),
|
||||||
|
});
|
||||||
|
|
||||||
|
let msgs = history.messages();
|
||||||
|
assert_eq!(msgs.len(), 2);
|
||||||
|
assert_eq!(msgs[0].role, Role::User);
|
||||||
|
assert_eq!(msgs[0].content, "hello");
|
||||||
|
assert_eq!(msgs[1].role, Role::Assistant);
|
||||||
|
assert_eq!(msgs[1].content, "hi there");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn messages_preserves_insertion_order() {
|
||||||
|
let mut history = ConversationHistory::new();
|
||||||
|
for i in 0u32..5 {
|
||||||
|
history.push(ConversationMessage {
|
||||||
|
role: Role::User,
|
||||||
|
content: format!("msg {i}"),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
for (i, msg) in history.messages().iter().enumerate() {
|
||||||
|
assert_eq!(msg.content, format!("msg {i}"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
3
src/core/mod.rs
Normal file
3
src/core/mod.rs
Normal file
|
|
@ -0,0 +1,3 @@
|
||||||
|
pub mod history;
|
||||||
|
pub mod orchestrator;
|
||||||
|
pub mod types;
|
||||||
350
src/core/orchestrator.rs
Normal file
350
src/core/orchestrator.rs
Normal file
|
|
@ -0,0 +1,350 @@
|
||||||
|
use futures::StreamExt;
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
use tracing::debug;
|
||||||
|
|
||||||
|
use crate::core::history::ConversationHistory;
|
||||||
|
use crate::core::types::{ConversationMessage, Role, StreamEvent, UIEvent, UserAction};
|
||||||
|
use crate::provider::ModelProvider;
|
||||||
|
|
||||||
|
/// Drives the conversation loop between the TUI frontend and the model provider.
|
||||||
|
///
|
||||||
|
/// The orchestrator owns [`ConversationHistory`] and acts as the bridge between
|
||||||
|
/// [`UserAction`]s arriving from the TUI and the [`ModelProvider`] whose output
|
||||||
|
/// is forwarded back to the TUI as [`UIEvent`]s.
|
||||||
|
///
|
||||||
|
/// # Channel topology
|
||||||
|
///
|
||||||
|
/// ```text
|
||||||
|
/// TUI --UserAction--> Orchestrator --UIEvent--> TUI
|
||||||
|
/// |
|
||||||
|
/// v
|
||||||
|
/// ModelProvider (SSE stream)
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// # Event loop
|
||||||
|
///
|
||||||
|
/// ```text
|
||||||
|
/// loop:
|
||||||
|
/// 1. await UserAction from action_rx (blocks until user sends input or quits)
|
||||||
|
/// 2. SendMessage:
|
||||||
|
/// a. Append user message to history
|
||||||
|
/// b. Call provider.stream(history) -- starts an SSE request
|
||||||
|
/// c. For each StreamEvent:
|
||||||
|
/// TextDelta -> forward as UIEvent::StreamDelta; accumulate locally
|
||||||
|
/// Done -> append accumulated text as assistant message;
|
||||||
|
/// send UIEvent::TurnComplete; break inner loop
|
||||||
|
/// Error(msg) -> send UIEvent::Error(msg); break inner loop
|
||||||
|
/// InputTokens -> log at debug level (future: per-turn token tracking)
|
||||||
|
/// OutputTokens -> log at debug level
|
||||||
|
/// 3. Quit -> return
|
||||||
|
/// ```
|
||||||
|
pub struct Orchestrator<P> {
|
||||||
|
history: ConversationHistory,
|
||||||
|
provider: P,
|
||||||
|
action_rx: mpsc::Receiver<UserAction>,
|
||||||
|
event_tx: mpsc::Sender<UIEvent>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<P: ModelProvider> Orchestrator<P> {
|
||||||
|
/// Construct an orchestrator using the given provider and channel endpoints.
|
||||||
|
pub fn new(
|
||||||
|
provider: P,
|
||||||
|
action_rx: mpsc::Receiver<UserAction>,
|
||||||
|
event_tx: mpsc::Sender<UIEvent>,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
history: ConversationHistory::new(),
|
||||||
|
provider,
|
||||||
|
action_rx,
|
||||||
|
event_tx,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run the orchestrator until the user quits or the `action_rx` channel closes.
|
||||||
|
pub async fn run(mut self) {
|
||||||
|
while let Some(action) = self.action_rx.recv().await {
|
||||||
|
match action {
|
||||||
|
UserAction::Quit => break,
|
||||||
|
|
||||||
|
UserAction::SendMessage(text) => {
|
||||||
|
// Push the user message before snapshotting, so providers
|
||||||
|
// see the full conversation including the new message.
|
||||||
|
self.history.push(ConversationMessage {
|
||||||
|
role: Role::User,
|
||||||
|
content: text,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Snapshot history into an owned Vec so the stream does not
|
||||||
|
// borrow from `self.history` -- this lets us mutably update
|
||||||
|
// `self.history` once the stream loop finishes.
|
||||||
|
let messages: Vec<ConversationMessage> = self.history.messages().to_vec();
|
||||||
|
|
||||||
|
let mut accumulated = String::new();
|
||||||
|
// Capture terminal stream state outside the loop so we can
|
||||||
|
// act on it after `stream` is dropped.
|
||||||
|
let mut turn_done = false;
|
||||||
|
let mut turn_error: Option<String> = None;
|
||||||
|
|
||||||
|
{
|
||||||
|
let mut stream = Box::pin(self.provider.stream(&messages));
|
||||||
|
|
||||||
|
while let Some(event) = stream.next().await {
|
||||||
|
match event {
|
||||||
|
StreamEvent::TextDelta(chunk) => {
|
||||||
|
accumulated.push_str(&chunk);
|
||||||
|
let _ = self.event_tx.send(UIEvent::StreamDelta(chunk)).await;
|
||||||
|
}
|
||||||
|
StreamEvent::Done => {
|
||||||
|
turn_done = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
StreamEvent::Error(msg) => {
|
||||||
|
turn_error = Some(msg);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
StreamEvent::InputTokens(n) => {
|
||||||
|
debug!(input_tokens = n, "turn input token count");
|
||||||
|
}
|
||||||
|
StreamEvent::OutputTokens(n) => {
|
||||||
|
debug!(output_tokens = n, "turn output token count");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// `stream` is dropped here, releasing the borrow on
|
||||||
|
// `self.provider` and `messages`.
|
||||||
|
}
|
||||||
|
|
||||||
|
if turn_done {
|
||||||
|
self.history.push(ConversationMessage {
|
||||||
|
role: Role::Assistant,
|
||||||
|
content: accumulated,
|
||||||
|
});
|
||||||
|
let _ = self.event_tx.send(UIEvent::TurnComplete).await;
|
||||||
|
} else if let Some(msg) = turn_error {
|
||||||
|
let _ = self.event_tx.send(UIEvent::Error(msg)).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use futures::Stream;
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
|
||||||
|
/// A provider that replays a fixed sequence of [`StreamEvent`]s.
|
||||||
|
///
|
||||||
|
/// Used to drive the orchestrator in tests without making any network calls.
|
||||||
|
struct MockProvider {
|
||||||
|
events: Vec<StreamEvent>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MockProvider {
|
||||||
|
fn new(events: Vec<StreamEvent>) -> Self {
|
||||||
|
Self { events }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ModelProvider for MockProvider {
|
||||||
|
fn stream<'a>(
|
||||||
|
&'a self,
|
||||||
|
_messages: &'a [ConversationMessage],
|
||||||
|
) -> impl Stream<Item = StreamEvent> + Send + 'a {
|
||||||
|
futures::stream::iter(self.events.clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Collect all UIEvents that arrive within one orchestrator turn, stopping
|
||||||
|
/// when the channel is drained after a `TurnComplete` or `Error`.
|
||||||
|
async fn collect_events(rx: &mut mpsc::Receiver<UIEvent>) -> Vec<UIEvent> {
|
||||||
|
let mut out = Vec::new();
|
||||||
|
while let Ok(ev) = rx.try_recv() {
|
||||||
|
let done = matches!(ev, UIEvent::TurnComplete | UIEvent::Error(_));
|
||||||
|
out.push(ev);
|
||||||
|
if done {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out
|
||||||
|
}
|
||||||
|
|
||||||
|
// -- happy-path turn ----------------------------------------------------------
|
||||||
|
|
||||||
|
/// A full successful turn: text chunks followed by Done.
|
||||||
|
///
|
||||||
|
/// After the turn:
|
||||||
|
/// - The TUI channel receives two `StreamDelta`s and one `TurnComplete`.
|
||||||
|
/// - The conversation history holds the user message and the accumulated
|
||||||
|
/// assistant message as its two entries.
|
||||||
|
#[tokio::test]
|
||||||
|
async fn happy_path_turn_produces_correct_ui_events_and_history() {
|
||||||
|
let provider = MockProvider::new(vec![
|
||||||
|
StreamEvent::InputTokens(10),
|
||||||
|
StreamEvent::TextDelta("Hello".to_string()),
|
||||||
|
StreamEvent::TextDelta(", world!".to_string()),
|
||||||
|
StreamEvent::OutputTokens(5),
|
||||||
|
StreamEvent::Done,
|
||||||
|
]);
|
||||||
|
|
||||||
|
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
||||||
|
let (event_tx, mut event_rx) = mpsc::channel::<UIEvent>(16);
|
||||||
|
|
||||||
|
let orch = Orchestrator::new(provider, action_rx, event_tx);
|
||||||
|
let handle = tokio::spawn(orch.run());
|
||||||
|
|
||||||
|
action_tx
|
||||||
|
.send(UserAction::SendMessage("hi".to_string()))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Give the orchestrator time to process the stream.
|
||||||
|
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||||
|
|
||||||
|
let events = collect_events(&mut event_rx).await;
|
||||||
|
|
||||||
|
// Verify the UIEvent sequence.
|
||||||
|
assert_eq!(events.len(), 3);
|
||||||
|
assert!(matches!(&events[0], UIEvent::StreamDelta(s) if s == "Hello"));
|
||||||
|
assert!(matches!(&events[1], UIEvent::StreamDelta(s) if s == ", world!"));
|
||||||
|
assert!(matches!(events[2], UIEvent::TurnComplete));
|
||||||
|
|
||||||
|
// Shut down the orchestrator and verify history.
|
||||||
|
action_tx.send(UserAction::Quit).await.unwrap();
|
||||||
|
handle.await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
// -- error path ---------------------------------------------------------------
|
||||||
|
|
||||||
|
/// When the provider emits `Error`, the orchestrator forwards it to the TUI
|
||||||
|
/// and does NOT append an assistant message to history.
|
||||||
|
#[tokio::test]
|
||||||
|
async fn error_event_forwarded_to_tui_and_no_assistant_message_in_history() {
|
||||||
|
let provider = MockProvider::new(vec![
|
||||||
|
StreamEvent::TextDelta("partial".to_string()),
|
||||||
|
StreamEvent::Error("network timeout".to_string()),
|
||||||
|
]);
|
||||||
|
|
||||||
|
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
||||||
|
let (event_tx, mut event_rx) = mpsc::channel::<UIEvent>(16);
|
||||||
|
|
||||||
|
let orch = Orchestrator::new(provider, action_rx, event_tx);
|
||||||
|
let handle = tokio::spawn(orch.run());
|
||||||
|
|
||||||
|
action_tx
|
||||||
|
.send(UserAction::SendMessage("hello".to_string()))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||||
|
|
||||||
|
let events = collect_events(&mut event_rx).await;
|
||||||
|
|
||||||
|
assert_eq!(events.len(), 2);
|
||||||
|
assert!(matches!(&events[0], UIEvent::StreamDelta(s) if s == "partial"));
|
||||||
|
assert!(matches!(&events[1], UIEvent::Error(msg) if msg == "network timeout"));
|
||||||
|
|
||||||
|
action_tx.send(UserAction::Quit).await.unwrap();
|
||||||
|
handle.await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
// -- quit ---------------------------------------------------------------------
|
||||||
|
|
||||||
|
/// Sending `Quit` immediately terminates the orchestrator loop.
|
||||||
|
#[tokio::test]
|
||||||
|
async fn quit_terminates_run() {
|
||||||
|
// A provider that panics if called, to prove stream() is never invoked.
|
||||||
|
struct NeverCalledProvider;
|
||||||
|
impl ModelProvider for NeverCalledProvider {
|
||||||
|
fn stream<'a>(
|
||||||
|
&'a self,
|
||||||
|
_messages: &'a [ConversationMessage],
|
||||||
|
) -> impl Stream<Item = StreamEvent> + Send + 'a {
|
||||||
|
panic!("stream() must not be called after Quit");
|
||||||
|
#[allow(unreachable_code)]
|
||||||
|
futures::stream::empty()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
||||||
|
let (event_tx, _event_rx) = mpsc::channel::<UIEvent>(8);
|
||||||
|
|
||||||
|
let orch = Orchestrator::new(NeverCalledProvider, action_rx, event_tx);
|
||||||
|
let handle = tokio::spawn(orch.run());
|
||||||
|
|
||||||
|
action_tx.send(UserAction::Quit).await.unwrap();
|
||||||
|
handle.await.unwrap(); // completes without panic
|
||||||
|
}
|
||||||
|
|
||||||
|
// -- multi-turn history accumulation ------------------------------------------
|
||||||
|
|
||||||
|
/// Two sequential SendMessage turns each append a user message and the
|
||||||
|
/// accumulated assistant response, leaving four messages in history order.
|
||||||
|
///
|
||||||
|
/// This validates that history is passed to the provider on every turn and
|
||||||
|
/// that delta accumulation resets correctly between turns.
|
||||||
|
#[tokio::test]
|
||||||
|
async fn two_turns_accumulate_history_correctly() {
|
||||||
|
// Both turns produce the same simple response for simplicity.
|
||||||
|
let make_turn_events = || {
|
||||||
|
vec![
|
||||||
|
StreamEvent::TextDelta("reply".to_string()),
|
||||||
|
StreamEvent::Done,
|
||||||
|
]
|
||||||
|
};
|
||||||
|
|
||||||
|
// We need to serve two different turns from the same provider.
|
||||||
|
// Use an `Arc<Mutex<VecDeque>>` so the provider can pop event sets.
|
||||||
|
use std::collections::VecDeque;
|
||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
|
struct MultiTurnMock {
|
||||||
|
turns: Arc<Mutex<VecDeque<Vec<StreamEvent>>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ModelProvider for MultiTurnMock {
|
||||||
|
fn stream<'a>(
|
||||||
|
&'a self,
|
||||||
|
_messages: &'a [ConversationMessage],
|
||||||
|
) -> impl Stream<Item = StreamEvent> + Send + 'a {
|
||||||
|
let events = self.turns.lock().unwrap().pop_front().unwrap_or_default();
|
||||||
|
futures::stream::iter(events)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let turns = Arc::new(Mutex::new(VecDeque::from([
|
||||||
|
make_turn_events(),
|
||||||
|
make_turn_events(),
|
||||||
|
])));
|
||||||
|
let provider = MultiTurnMock { turns };
|
||||||
|
|
||||||
|
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
||||||
|
let (event_tx, mut event_rx) = mpsc::channel::<UIEvent>(32);
|
||||||
|
|
||||||
|
let orch = Orchestrator::new(provider, action_rx, event_tx);
|
||||||
|
let handle = tokio::spawn(orch.run());
|
||||||
|
|
||||||
|
// First turn.
|
||||||
|
action_tx
|
||||||
|
.send(UserAction::SendMessage("turn one".to_string()))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||||
|
let ev1 = collect_events(&mut event_rx).await;
|
||||||
|
assert!(matches!(ev1.last(), Some(UIEvent::TurnComplete)));
|
||||||
|
|
||||||
|
// Second turn.
|
||||||
|
action_tx
|
||||||
|
.send(UserAction::SendMessage("turn two".to_string()))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||||
|
let ev2 = collect_events(&mut event_rx).await;
|
||||||
|
assert!(matches!(ev2.last(), Some(UIEvent::TurnComplete)));
|
||||||
|
|
||||||
|
action_tx.send(UserAction::Quit).await.unwrap();
|
||||||
|
handle.await.unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
53
src/core/types.rs
Normal file
53
src/core/types.rs
Normal file
|
|
@ -0,0 +1,53 @@
|
||||||
|
/// A streaming event emitted by the model provider.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum StreamEvent {
|
||||||
|
/// A text chunk from the assistant's response.
|
||||||
|
TextDelta(String),
|
||||||
|
/// Number of input tokens used in this request.
|
||||||
|
InputTokens(u32),
|
||||||
|
/// Number of output tokens generated so far.
|
||||||
|
OutputTokens(u32),
|
||||||
|
/// The stream has completed successfully.
|
||||||
|
Done,
|
||||||
|
/// An error occurred during streaming.
|
||||||
|
Error(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An action sent from the TUI to the core orchestrator.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum UserAction {
|
||||||
|
/// The user has submitted a message.
|
||||||
|
SendMessage(String),
|
||||||
|
/// The user has requested to quit.
|
||||||
|
Quit,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An event sent from the core orchestrator to the TUI.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum UIEvent {
|
||||||
|
/// A text chunk to append to the current assistant message.
|
||||||
|
StreamDelta(String),
|
||||||
|
/// The current assistant turn has completed.
|
||||||
|
TurnComplete,
|
||||||
|
/// An error to display to the user.
|
||||||
|
Error(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The role of a participant in a conversation.
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum Role {
|
||||||
|
/// A message from the human user.
|
||||||
|
User,
|
||||||
|
/// A message from the AI assistant.
|
||||||
|
Assistant,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A single message in the conversation history.
|
||||||
|
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||||
|
pub struct ConversationMessage {
|
||||||
|
/// The role of the message author.
|
||||||
|
pub role: Role,
|
||||||
|
/// The text content of the message.
|
||||||
|
pub content: String,
|
||||||
|
}
|
||||||
39
src/main.rs
Normal file
39
src/main.rs
Normal file
|
|
@ -0,0 +1,39 @@
|
||||||
|
mod app;
|
||||||
|
mod core;
|
||||||
|
mod provider;
|
||||||
|
mod tui;
|
||||||
|
|
||||||
|
use std::path::PathBuf;
|
||||||
|
|
||||||
|
use anyhow::Context;
|
||||||
|
|
||||||
|
/// Run skate against a project directory.
|
||||||
|
///
|
||||||
|
/// ```text
|
||||||
|
/// Usage: skate --project-dir <path>
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// `ANTHROPIC_API_KEY` must be set in the environment.
|
||||||
|
/// `RUST_LOG` controls log verbosity (default: `info`); logs go to
|
||||||
|
/// `<project-dir>/.skate/skate.log`.
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> anyhow::Result<()> {
|
||||||
|
let project_dir = parse_project_dir()?;
|
||||||
|
app::run(&project_dir).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extract the value of `--project-dir` from `argv`.
|
||||||
|
///
|
||||||
|
/// Returns an error if the flag is absent or is not followed by a value.
|
||||||
|
fn parse_project_dir() -> anyhow::Result<PathBuf> {
|
||||||
|
let mut args = std::env::args().skip(1); // skip the binary name
|
||||||
|
while let Some(arg) = args.next() {
|
||||||
|
if arg == "--project-dir" {
|
||||||
|
let value = args
|
||||||
|
.next()
|
||||||
|
.context("--project-dir requires a path argument")?;
|
||||||
|
return Ok(PathBuf::from(value));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
anyhow::bail!("Usage: skate --project-dir <path>")
|
||||||
|
}
|
||||||
443
src/provider/claude.rs
Normal file
443
src/provider/claude.rs
Normal file
|
|
@ -0,0 +1,443 @@
|
||||||
|
use futures::{SinkExt, Stream, StreamExt};
|
||||||
|
use reqwest::Client;
|
||||||
|
use serde::Deserialize;
|
||||||
|
|
||||||
|
use crate::core::types::{ConversationMessage, StreamEvent};
|
||||||
|
|
||||||
|
use super::ModelProvider;
|
||||||
|
|
||||||
|
/// Errors that can occur when constructing or using a [`ClaudeProvider`].
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum ClaudeProviderError {
|
||||||
|
/// The `ANTHROPIC_API_KEY` environment variable is not set.
|
||||||
|
#[error("ANTHROPIC_API_KEY environment variable not set")]
|
||||||
|
MissingApiKey,
|
||||||
|
/// An HTTP-level error from the reqwest client.
|
||||||
|
#[error("HTTP request failed: {0}")]
|
||||||
|
Http(#[from] reqwest::Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// [`ModelProvider`] implementation that streams responses from the Anthropic Messages API.
|
||||||
|
///
|
||||||
|
/// Calls `POST /v1/messages` with `"stream": true` and parses the resulting
|
||||||
|
/// [Server-Sent Events][sse] stream into [`StreamEvent`]s.
|
||||||
|
///
|
||||||
|
/// # Authentication
|
||||||
|
///
|
||||||
|
/// Reads the API key from the `ANTHROPIC_API_KEY` environment variable.
|
||||||
|
/// See the [Anthropic authentication docs][auth] for how to obtain a key.
|
||||||
|
///
|
||||||
|
/// # API version
|
||||||
|
///
|
||||||
|
/// Sends the `anthropic-version: 2023-06-01` header on every request, which is
|
||||||
|
/// the stable baseline version required by the API. See the
|
||||||
|
/// [versioning docs][versioning] for details on how Anthropic handles API versions.
|
||||||
|
///
|
||||||
|
/// [sse]: https://html.spec.whatwg.org/multipage/server-sent-events.html
|
||||||
|
/// [auth]: https://docs.anthropic.com/en/api/getting-started#authentication
|
||||||
|
/// [versioning]: https://docs.anthropic.com/en/api/versioning
|
||||||
|
pub struct ClaudeProvider {
|
||||||
|
api_key: String,
|
||||||
|
client: Client,
|
||||||
|
model: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ClaudeProvider {
|
||||||
|
/// Create a `ClaudeProvider` reading `ANTHROPIC_API_KEY` from the environment.
|
||||||
|
/// The caller must supply the model ID (e.g. `"claude-opus-4-6"`).
|
||||||
|
///
|
||||||
|
/// See the [models overview][models] for available model IDs.
|
||||||
|
///
|
||||||
|
/// [models]: https://docs.anthropic.com/en/docs/about-claude/models/overview
|
||||||
|
pub fn from_env(model: impl Into<String>) -> Result<Self, ClaudeProviderError> {
|
||||||
|
let api_key =
|
||||||
|
std::env::var("ANTHROPIC_API_KEY").map_err(|_| ClaudeProviderError::MissingApiKey)?;
|
||||||
|
Ok(Self {
|
||||||
|
api_key,
|
||||||
|
client: Client::new(),
|
||||||
|
model: model.into(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ModelProvider for ClaudeProvider {
|
||||||
|
fn stream<'a>(
|
||||||
|
&'a self,
|
||||||
|
messages: &'a [ConversationMessage],
|
||||||
|
) -> impl Stream<Item = StreamEvent> + Send + 'a {
|
||||||
|
let (mut tx, rx) = futures::channel::mpsc::channel(32);
|
||||||
|
let client = self.client.clone();
|
||||||
|
let api_key = self.api_key.clone();
|
||||||
|
let model = self.model.clone();
|
||||||
|
let messages = messages.to_vec();
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
run_stream(client, api_key, model, messages, &mut tx).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
rx
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// POST to `/v1/messages` with `stream: true`, then parse the SSE response into
|
||||||
|
/// [`StreamEvent`]s and forward them to `tx`.
|
||||||
|
///
|
||||||
|
/// # Request shape
|
||||||
|
///
|
||||||
|
/// ```json
|
||||||
|
/// {
|
||||||
|
/// "model": "<model-id>",
|
||||||
|
/// "max_tokens": 8192,
|
||||||
|
/// "stream": true,
|
||||||
|
/// "messages": [{ "role": "user"|"assistant", "content": "<text>" }, ...]
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// See the [Messages API reference][messages-api] for the full schema.
|
||||||
|
///
|
||||||
|
/// # SSE stream lifecycle
|
||||||
|
///
|
||||||
|
/// With streaming enabled the API sends a sequence of
|
||||||
|
/// [Server-Sent Events][sse] separated by blank lines (`\n\n`). Each event
|
||||||
|
/// has an `event:` line naming its type and a `data:` line containing a JSON
|
||||||
|
/// object. The full event sequence for a successful turn is:
|
||||||
|
///
|
||||||
|
/// ```text
|
||||||
|
/// event: message_start -> InputTokens(n)
|
||||||
|
/// event: content_block_start -> (ignored -- signals a new content block)
|
||||||
|
/// event: ping -> (ignored -- keepalive)
|
||||||
|
/// event: content_block_delta -> TextDelta(chunk) (repeated)
|
||||||
|
/// event: content_block_stop -> (ignored -- signals end of content block)
|
||||||
|
/// event: message_delta -> OutputTokens(n)
|
||||||
|
/// event: message_stop -> Done
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// We stop reading as soon as `Done` is emitted; any bytes arriving after
|
||||||
|
/// `message_stop` are discarded.
|
||||||
|
///
|
||||||
|
/// See the [streaming reference][streaming] for the authoritative description
|
||||||
|
/// of each event type and its JSON payload.
|
||||||
|
///
|
||||||
|
/// [messages-api]: https://docs.anthropic.com/en/api/messages
|
||||||
|
/// [streaming]: https://docs.anthropic.com/en/api/messages-streaming
|
||||||
|
/// [sse]: https://html.spec.whatwg.org/multipage/server-sent-events.html
|
||||||
|
async fn run_stream(
|
||||||
|
client: Client,
|
||||||
|
api_key: String,
|
||||||
|
model: String,
|
||||||
|
messages: Vec<ConversationMessage>,
|
||||||
|
tx: &mut futures::channel::mpsc::Sender<StreamEvent>,
|
||||||
|
) {
|
||||||
|
let body = serde_json::json!({
|
||||||
|
"model": model,
|
||||||
|
"max_tokens": 8192,
|
||||||
|
"stream": true,
|
||||||
|
"messages": messages,
|
||||||
|
});
|
||||||
|
|
||||||
|
let response = match client
|
||||||
|
.post("https://api.anthropic.com/v1/messages")
|
||||||
|
.header("x-api-key", &api_key)
|
||||||
|
.header("anthropic-version", "2023-06-01")
|
||||||
|
.header("content-type", "application/json")
|
||||||
|
.json(&body)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(r) => r,
|
||||||
|
Err(e) => {
|
||||||
|
let _ = tx.send(StreamEvent::Error(e.to_string())).await;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if !response.status().is_success() {
|
||||||
|
let status = response.status();
|
||||||
|
let body_text = match response.text().await {
|
||||||
|
Ok(t) => t,
|
||||||
|
Err(e) => format!("(failed to read error body: {e})"),
|
||||||
|
};
|
||||||
|
let _ = tx
|
||||||
|
.send(StreamEvent::Error(format!("HTTP {status}: {body_text}")))
|
||||||
|
.await;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut stream = response.bytes_stream();
|
||||||
|
let mut buffer: Vec<u8> = Vec::new();
|
||||||
|
|
||||||
|
while let Some(chunk) = stream.next().await {
|
||||||
|
match chunk {
|
||||||
|
Err(e) => {
|
||||||
|
let _ = tx.send(StreamEvent::Error(e.to_string())).await;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
Ok(bytes) => {
|
||||||
|
buffer.extend_from_slice(&bytes);
|
||||||
|
// Drain complete SSE events (delimited by blank lines).
|
||||||
|
loop {
|
||||||
|
match find_double_newline(&buffer) {
|
||||||
|
None => break,
|
||||||
|
Some(pos) => {
|
||||||
|
let event_bytes: Vec<u8> = buffer.drain(..pos + 2).collect();
|
||||||
|
let event_str = String::from_utf8_lossy(&event_bytes);
|
||||||
|
if let Some(event) = parse_sse_event(&event_str) {
|
||||||
|
let is_done = matches!(event, StreamEvent::Done);
|
||||||
|
let _ = tx.send(event).await;
|
||||||
|
if is_done {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let _ = tx.send(StreamEvent::Done).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return the byte offset of the first `\n\n` in `buf`, or `None`.
|
||||||
|
///
|
||||||
|
/// SSE uses a blank line (two consecutive newlines) as the event boundary.
|
||||||
|
/// See [Section 9.2.6 of the SSE spec][sse-dispatch].
|
||||||
|
///
|
||||||
|
/// [sse-dispatch]: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation
|
||||||
|
fn find_double_newline(buf: &[u8]) -> Option<usize> {
|
||||||
|
buf.windows(2).position(|w| w == b"\n\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// -- SSE JSON types -----------------------------------------------------------
|
||||||
|
//
|
||||||
|
// These structs mirror the subset of the Anthropic SSE payload we actually
|
||||||
|
// consume. Unknown fields are silently ignored by serde. Full schemas are
|
||||||
|
// documented in the [streaming reference][streaming].
|
||||||
|
//
|
||||||
|
// [streaming]: https://docs.anthropic.com/en/api/messages-streaming
|
||||||
|
|
||||||
|
/// Top-level SSE data object. The `type` field selects which other fields
|
||||||
|
/// are present; we use `Option` for all of them so a single struct covers
|
||||||
|
/// every event type without needing an enum.
|
||||||
|
#[derive(Deserialize, Debug)]
|
||||||
|
struct SseEvent {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
event_type: String,
|
||||||
|
/// Present on `content_block_delta` events.
|
||||||
|
delta: Option<SseDelta>,
|
||||||
|
/// Present on `message_start` events; carries initial token usage.
|
||||||
|
message: Option<SseMessageStart>,
|
||||||
|
/// Present on `message_delta` events; carries final output token count.
|
||||||
|
usage: Option<SseUsage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The `delta` object inside a `content_block_delta` event.
|
||||||
|
///
|
||||||
|
/// `type` is `"text_delta"` for plain text chunks; other delta types
|
||||||
|
/// (e.g. `"input_json_delta"` for tool-use blocks) are not yet handled.
|
||||||
|
#[derive(Deserialize, Debug)]
|
||||||
|
struct SseDelta {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
delta_type: Option<String>,
|
||||||
|
/// The text chunk; present when `delta_type == "text_delta"`.
|
||||||
|
text: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The `message` object inside a `message_start` event.
|
||||||
|
#[derive(Deserialize, Debug)]
|
||||||
|
struct SseMessageStart {
|
||||||
|
usage: Option<SseUsage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Token counts reported at the start and end of a turn.
|
||||||
|
///
|
||||||
|
/// `input_tokens` is set in the `message_start` event;
|
||||||
|
/// `output_tokens` is set in the `message_delta` event.
|
||||||
|
/// Both fields are `Option` so the same struct works for both events.
|
||||||
|
#[derive(Deserialize, Debug)]
|
||||||
|
struct SseUsage {
|
||||||
|
input_tokens: Option<u32>,
|
||||||
|
output_tokens: Option<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse a single SSE event string into a [`StreamEvent`], returning `None` for
|
||||||
|
/// event types we don't care about (`ping`, `content_block_start`,
|
||||||
|
/// `content_block_stop`).
|
||||||
|
///
|
||||||
|
/// # SSE format
|
||||||
|
///
|
||||||
|
/// Each event is a block of `field: value` lines. We only read the `data:`
|
||||||
|
/// field; the `event:` line is redundant with the `type` key inside the JSON
|
||||||
|
/// payload so we ignore it. See the [SSE spec][sse-fields] for the full field
|
||||||
|
/// grammar.
|
||||||
|
///
|
||||||
|
/// # Mapping to [`StreamEvent`]
|
||||||
|
///
|
||||||
|
/// | API event type | JSON path | Emits |
|
||||||
|
/// |----------------------|------------------------------------|------------------------------|
|
||||||
|
/// | `message_start` | `.message.usage.input_tokens` | `InputTokens(n)` |
|
||||||
|
/// | `content_block_delta`| `.delta.type == "text_delta"` | `TextDelta(chunk)` |
|
||||||
|
/// | `message_delta` | `.usage.output_tokens` | `OutputTokens(n)` |
|
||||||
|
/// | `message_stop` | n/a | `Done` |
|
||||||
|
/// | everything else | n/a | `None` (caller skips) |
|
||||||
|
///
|
||||||
|
/// [sse-fields]: https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream
|
||||||
|
fn parse_sse_event(event_str: &str) -> Option<StreamEvent> {
|
||||||
|
// SSE events may have multiple fields; we only need `data:`.
|
||||||
|
let data = event_str
|
||||||
|
.lines()
|
||||||
|
.find_map(|line| line.strip_prefix("data: "))?;
|
||||||
|
|
||||||
|
let event: SseEvent = serde_json::from_str(data).ok()?;
|
||||||
|
|
||||||
|
match event.event_type.as_str() {
|
||||||
|
"message_start" => event
|
||||||
|
.message
|
||||||
|
.and_then(|m| m.usage)
|
||||||
|
.and_then(|u| u.input_tokens)
|
||||||
|
.map(StreamEvent::InputTokens),
|
||||||
|
|
||||||
|
"content_block_delta" => {
|
||||||
|
let delta = event.delta?;
|
||||||
|
if delta.delta_type.as_deref() == Some("text_delta") {
|
||||||
|
delta.text.map(StreamEvent::TextDelta)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// usage lives at the top level of message_delta, not inside delta.
|
||||||
|
"message_delta" => event
|
||||||
|
.usage
|
||||||
|
.and_then(|u| u.output_tokens)
|
||||||
|
.map(StreamEvent::OutputTokens),
|
||||||
|
|
||||||
|
"message_stop" => Some(StreamEvent::Done),
|
||||||
|
|
||||||
|
// error, ping, content_block_start, content_block_stop -- ignored or
|
||||||
|
// handled by the caller.
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -- Tests --------------------------------------------------------------------
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::core::types::Role;
|
||||||
|
|
||||||
|
/// A minimal but complete Anthropic SSE fixture.
|
||||||
|
const SSE_FIXTURE: &str = concat!(
|
||||||
|
"event: message_start\n",
|
||||||
|
"data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_123\",\"type\":\"message\",",
|
||||||
|
"\"role\":\"assistant\",\"content\":[],\"model\":\"claude-opus-4-6\",",
|
||||||
|
"\"stop_reason\":null,\"stop_sequence\":null,",
|
||||||
|
"\"usage\":{\"input_tokens\":10,\"cache_creation_input_tokens\":0,",
|
||||||
|
"\"cache_read_input_tokens\":0,\"output_tokens\":1}}}\n",
|
||||||
|
"\n",
|
||||||
|
"event: content_block_start\n",
|
||||||
|
"data: {\"type\":\"content_block_start\",\"index\":0,",
|
||||||
|
"\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n",
|
||||||
|
"\n",
|
||||||
|
"event: ping\n",
|
||||||
|
"data: {\"type\":\"ping\"}\n",
|
||||||
|
"\n",
|
||||||
|
"event: content_block_delta\n",
|
||||||
|
"data: {\"type\":\"content_block_delta\",\"index\":0,",
|
||||||
|
"\"delta\":{\"type\":\"text_delta\",\"text\":\"Hello\"}}\n",
|
||||||
|
"\n",
|
||||||
|
"event: content_block_delta\n",
|
||||||
|
"data: {\"type\":\"content_block_delta\",\"index\":0,",
|
||||||
|
"\"delta\":{\"type\":\"text_delta\",\"text\":\", world!\"}}\n",
|
||||||
|
"\n",
|
||||||
|
"event: content_block_stop\n",
|
||||||
|
"data: {\"type\":\"content_block_stop\",\"index\":0}\n",
|
||||||
|
"\n",
|
||||||
|
"event: message_delta\n",
|
||||||
|
"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\",",
|
||||||
|
"\"stop_sequence\":null},\"usage\":{\"output_tokens\":5}}\n",
|
||||||
|
"\n",
|
||||||
|
"event: message_stop\n",
|
||||||
|
"data: {\"type\":\"message_stop\"}\n",
|
||||||
|
"\n",
|
||||||
|
);
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_sse_events_from_fixture() {
|
||||||
|
let events: Vec<StreamEvent> = SSE_FIXTURE
|
||||||
|
.split("\n\n")
|
||||||
|
.filter(|s| !s.trim().is_empty())
|
||||||
|
.filter_map(parse_sse_event)
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// content_block_start, ping, content_block_stop -> None (filtered out)
|
||||||
|
assert_eq!(events.len(), 5);
|
||||||
|
assert!(matches!(events[0], StreamEvent::InputTokens(10)));
|
||||||
|
assert!(matches!(&events[1], StreamEvent::TextDelta(s) if s == "Hello"));
|
||||||
|
assert!(matches!(&events[2], StreamEvent::TextDelta(s) if s == ", world!"));
|
||||||
|
assert!(matches!(events[3], StreamEvent::OutputTokens(5)));
|
||||||
|
assert!(matches!(events[4], StreamEvent::Done));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_message_stop_yields_done() {
|
||||||
|
let event_str = "event: message_stop\ndata: {\"type\":\"message_stop\"}\n";
|
||||||
|
assert!(matches!(
|
||||||
|
parse_sse_event(event_str),
|
||||||
|
Some(StreamEvent::Done)
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_ping_yields_none() {
|
||||||
|
let event_str = "event: ping\ndata: {\"type\":\"ping\"}\n";
|
||||||
|
assert!(parse_sse_event(event_str).is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_content_block_start_yields_none() {
|
||||||
|
let event_str = concat!(
|
||||||
|
"event: content_block_start\n",
|
||||||
|
"data: {\"type\":\"content_block_start\",\"index\":0,",
|
||||||
|
"\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n",
|
||||||
|
);
|
||||||
|
assert!(parse_sse_event(event_str).is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_messages_serialize_to_anthropic_format() {
|
||||||
|
let messages = vec![
|
||||||
|
ConversationMessage {
|
||||||
|
role: Role::User,
|
||||||
|
content: "Hello".to_string(),
|
||||||
|
},
|
||||||
|
ConversationMessage {
|
||||||
|
role: Role::Assistant,
|
||||||
|
content: "Hi there!".to_string(),
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
let json = serde_json::json!({
|
||||||
|
"model": "claude-opus-4-6",
|
||||||
|
"max_tokens": 8192,
|
||||||
|
"stream": true,
|
||||||
|
"messages": messages,
|
||||||
|
});
|
||||||
|
|
||||||
|
assert_eq!(json["messages"][0]["role"], "user");
|
||||||
|
assert_eq!(json["messages"][0]["content"], "Hello");
|
||||||
|
assert_eq!(json["messages"][1]["role"], "assistant");
|
||||||
|
assert_eq!(json["messages"][1]["content"], "Hi there!");
|
||||||
|
assert_eq!(json["stream"], true);
|
||||||
|
assert!(json["max_tokens"].as_u64().unwrap() > 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_find_double_newline() {
|
||||||
|
assert_eq!(find_double_newline(b"abc\n\ndef"), Some(3));
|
||||||
|
assert_eq!(find_double_newline(b"abc\ndef"), None);
|
||||||
|
assert_eq!(find_double_newline(b"\n\n"), Some(0));
|
||||||
|
assert_eq!(find_double_newline(b""), None);
|
||||||
|
}
|
||||||
|
}
|
||||||
19
src/provider/mod.rs
Normal file
19
src/provider/mod.rs
Normal file
|
|
@ -0,0 +1,19 @@
|
||||||
|
mod claude;
|
||||||
|
|
||||||
|
pub use claude::ClaudeProvider;
|
||||||
|
|
||||||
|
use futures::Stream;
|
||||||
|
|
||||||
|
use crate::core::types::{ConversationMessage, StreamEvent};
|
||||||
|
|
||||||
|
/// Trait for model providers that can stream conversation responses.
|
||||||
|
///
|
||||||
|
/// Implementors take a conversation history and return a stream of [`StreamEvent`]s.
|
||||||
|
/// The trait is provider-agnostic -- no Claude-specific types appear here.
|
||||||
|
pub trait ModelProvider: Send + Sync {
|
||||||
|
/// Stream a response from the model given the conversation history.
|
||||||
|
fn stream<'a>(
|
||||||
|
&'a self,
|
||||||
|
messages: &'a [ConversationMessage],
|
||||||
|
) -> impl Stream<Item = StreamEvent> + Send + 'a;
|
||||||
|
}
|
||||||
460
src/tui/mod.rs
Normal file
460
src/tui/mod.rs
Normal file
|
|
@ -0,0 +1,460 @@
|
||||||
|
//! TUI frontend: terminal lifecycle, rendering, and input handling.
|
||||||
|
//!
|
||||||
|
//! All communication with the core orchestrator flows through channels:
|
||||||
|
//! - [`UserAction`] sent via `action_tx` when the user submits input or quits
|
||||||
|
//! - [`UIEvent`] received via `event_rx` to display streaming assistant responses
|
||||||
|
//!
|
||||||
|
//! The terminal lifecycle follows the standard crossterm pattern:
|
||||||
|
//! 1. Enable raw mode
|
||||||
|
//! 2. Enter alternate screen
|
||||||
|
//! 3. On exit (or panic), disable raw mode and leave the alternate screen
|
||||||
|
|
||||||
|
use std::io::{self, Stdout};
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use crossterm::event::{Event, EventStream, KeyCode, KeyEvent, KeyModifiers};
|
||||||
|
use crossterm::execute;
|
||||||
|
use crossterm::terminal::{
|
||||||
|
EnterAlternateScreen, LeaveAlternateScreen, disable_raw_mode, enable_raw_mode,
|
||||||
|
};
|
||||||
|
use futures::StreamExt;
|
||||||
|
use ratatui::backend::CrosstermBackend;
|
||||||
|
use ratatui::layout::{Constraint, Layout, Rect};
|
||||||
|
use ratatui::style::{Color, Style};
|
||||||
|
use ratatui::text::{Line, Span};
|
||||||
|
use ratatui::widgets::{Block, Paragraph, Wrap};
|
||||||
|
use ratatui::{Frame, Terminal};
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
use tracing::debug;
|
||||||
|
|
||||||
|
use crate::core::types::{Role, UIEvent, UserAction};
|
||||||
|
|
||||||
|
/// Errors that can occur in the TUI layer.
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum TuiError {
|
||||||
|
/// An underlying terminal I/O error.
|
||||||
|
#[error("terminal IO error: {0}")]
|
||||||
|
Io(#[from] std::io::Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The UI-layer view of a conversation: rendered messages and the current input buffer.
|
||||||
|
pub struct AppState {
|
||||||
|
/// All conversation turns rendered as (role, content) pairs.
|
||||||
|
pub messages: Vec<(Role, String)>,
|
||||||
|
/// The current contents of the input box.
|
||||||
|
pub input: String,
|
||||||
|
/// Vertical scroll offset for the output pane (lines from top).
|
||||||
|
pub scroll: u16,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AppState {
|
||||||
|
fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
messages: Vec::new(),
|
||||||
|
input: String::new(),
|
||||||
|
scroll: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Initialise the terminal: enable raw mode and switch to the alternate screen.
|
||||||
|
///
|
||||||
|
/// Callers must pair this with [`restore_terminal`] (and [`install_panic_hook`]) to
|
||||||
|
/// guarantee cleanup even on abnormal exit.
|
||||||
|
pub fn init_terminal() -> Result<Terminal<CrosstermBackend<Stdout>>, TuiError> {
|
||||||
|
enable_raw_mode()?;
|
||||||
|
let mut stdout = io::stdout();
|
||||||
|
execute!(stdout, EnterAlternateScreen)?;
|
||||||
|
let backend = CrosstermBackend::new(stdout);
|
||||||
|
let terminal = Terminal::new(backend)?;
|
||||||
|
Ok(terminal)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Restore the terminal to its pre-launch state: disable raw mode and leave the
|
||||||
|
/// alternate screen.
|
||||||
|
pub fn restore_terminal() -> io::Result<()> {
|
||||||
|
disable_raw_mode()?;
|
||||||
|
execute!(io::stdout(), LeaveAlternateScreen)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Install a panic hook that restores the terminal before printing the panic message.
|
||||||
|
///
|
||||||
|
/// Without this, a panic leaves the terminal in raw mode with the alternate screen
|
||||||
|
/// active, making the shell unusable until the user runs `reset`.
|
||||||
|
pub fn install_panic_hook() {
|
||||||
|
let original = std::panic::take_hook();
|
||||||
|
std::panic::set_hook(Box::new(move |info| {
|
||||||
|
// Best-effort restore; if it fails the original hook still runs.
|
||||||
|
let _ = restore_terminal();
|
||||||
|
original(info);
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Internal control flow signal returned by [`handle_key`].
|
||||||
|
enum LoopControl {
|
||||||
|
/// The user pressed Enter with non-empty input; send this message to the core.
|
||||||
|
SendMessage(String),
|
||||||
|
/// The user pressed Ctrl+C or Ctrl+D; exit the event loop.
|
||||||
|
Quit,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Map a key event to a [`LoopControl`] signal, mutating `state` as a side-effect.
|
||||||
|
///
|
||||||
|
/// Returns `None` when the key is consumed with no further loop-level action needed.
|
||||||
|
///
|
||||||
|
/// | Key | Effect |
|
||||||
|
/// |------------------|-------------------------------------------------|
|
||||||
|
/// | Printable (no CTRL) | `state.input.push(c)` |
|
||||||
|
/// | Backspace | `state.input.pop()` |
|
||||||
|
/// | Enter (non-empty)| Take input, push User message, return `SendMessage` |
|
||||||
|
/// | Enter (empty) | No-op |
|
||||||
|
/// | Ctrl+C / Ctrl+D | Return `Quit` |
|
||||||
|
fn handle_key(key: Option<KeyEvent>, state: &mut AppState) -> Option<LoopControl> {
|
||||||
|
let key = key?;
|
||||||
|
match key.code {
|
||||||
|
KeyCode::Char('c') if key.modifiers.contains(KeyModifiers::CONTROL) => {
|
||||||
|
Some(LoopControl::Quit)
|
||||||
|
}
|
||||||
|
KeyCode::Char('d') if key.modifiers.contains(KeyModifiers::CONTROL) => {
|
||||||
|
Some(LoopControl::Quit)
|
||||||
|
}
|
||||||
|
KeyCode::Char(c) if !key.modifiers.contains(KeyModifiers::CONTROL) => {
|
||||||
|
state.input.push(c);
|
||||||
|
None
|
||||||
|
}
|
||||||
|
KeyCode::Backspace => {
|
||||||
|
state.input.pop();
|
||||||
|
None
|
||||||
|
}
|
||||||
|
KeyCode::Enter => {
|
||||||
|
let msg = std::mem::take(&mut state.input);
|
||||||
|
if msg.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
state.messages.push((Role::User, msg.clone()));
|
||||||
|
Some(LoopControl::SendMessage(msg))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Drain all pending [`UIEvent`]s from `event_rx` and apply them to `state`.
|
||||||
|
///
|
||||||
|
/// This is non-blocking: it processes all currently-available events and returns
|
||||||
|
/// immediately when the channel is empty.
|
||||||
|
///
|
||||||
|
/// | Event | Effect |
|
||||||
|
/// |--------------------|------------------------------------------------------------|
|
||||||
|
/// | `StreamDelta(s)` | Append `s` to last message if it's `Assistant`; else push new |
|
||||||
|
/// | `TurnComplete` | No structural change; logged at debug level |
|
||||||
|
/// | `Error(msg)` | Push `(Assistant, "[error] {msg}")` |
|
||||||
|
fn drain_ui_events(event_rx: &mut mpsc::Receiver<UIEvent>, state: &mut AppState) {
|
||||||
|
while let Ok(event) = event_rx.try_recv() {
|
||||||
|
match event {
|
||||||
|
UIEvent::StreamDelta(chunk) => {
|
||||||
|
if let Some((Role::Assistant, content)) = state.messages.last_mut() {
|
||||||
|
content.push_str(&chunk);
|
||||||
|
} else {
|
||||||
|
state.messages.push((Role::Assistant, chunk));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
UIEvent::TurnComplete => {
|
||||||
|
debug!("turn complete");
|
||||||
|
}
|
||||||
|
UIEvent::Error(msg) => {
|
||||||
|
state
|
||||||
|
.messages
|
||||||
|
.push((Role::Assistant, format!("[error] {msg}")));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Estimate the total rendered line count for all messages and update `state.scroll`
|
||||||
|
/// so the bottom of the content is visible.
|
||||||
|
///
|
||||||
|
/// When content fits within the viewport, `state.scroll` is set to 0.
|
||||||
|
///
|
||||||
|
/// Line estimation per message:
|
||||||
|
/// - 1 line for the role header
|
||||||
|
/// - `ceil(chars / width).max(1)` lines per newline-separated content line
|
||||||
|
/// - 1 blank separator line
|
||||||
|
fn update_scroll(state: &mut AppState, area: Rect) {
|
||||||
|
// 3 = height of the input pane (border top + content + border bottom)
|
||||||
|
let viewport_height = area.height.saturating_sub(3);
|
||||||
|
let width = area.width.max(1) as usize;
|
||||||
|
|
||||||
|
let mut total_lines: u16 = 0;
|
||||||
|
for (_, content) in &state.messages {
|
||||||
|
total_lines = total_lines.saturating_add(1); // role header
|
||||||
|
for line in content.lines() {
|
||||||
|
let chars = line.chars().count();
|
||||||
|
let wrapped = chars.div_ceil(width).max(1) as u16;
|
||||||
|
total_lines = total_lines.saturating_add(wrapped);
|
||||||
|
}
|
||||||
|
total_lines = total_lines.saturating_add(1); // blank separator
|
||||||
|
}
|
||||||
|
|
||||||
|
state.scroll = total_lines.saturating_sub(viewport_height);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Render the full TUI into `frame`.
|
||||||
|
///
|
||||||
|
/// Layout (top to bottom):
|
||||||
|
/// ```text
|
||||||
|
/// +--------------------------------+
|
||||||
|
/// | conversation history | Fill(1)
|
||||||
|
/// | |
|
||||||
|
/// +--------------------------------+
|
||||||
|
/// | Input | Length(3)
|
||||||
|
/// | > _ |
|
||||||
|
/// +--------------------------------+
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// Role headers are coloured: `"You:"` in cyan, `"Assistant:"` in green.
|
||||||
|
fn render(frame: &mut Frame, state: &AppState) {
|
||||||
|
let chunks = Layout::vertical([Constraint::Fill(1), Constraint::Length(3)]).split(frame.area());
|
||||||
|
|
||||||
|
// --- Output pane ---
|
||||||
|
let mut lines: Vec<Line> = Vec::new();
|
||||||
|
for (role, content) in &state.messages {
|
||||||
|
let (label, color) = match role {
|
||||||
|
Role::User => ("You:", Color::Cyan),
|
||||||
|
Role::Assistant => ("Assistant:", Color::Green),
|
||||||
|
};
|
||||||
|
lines.push(Line::from(Span::styled(label, Style::default().fg(color))));
|
||||||
|
for body_line in content.lines() {
|
||||||
|
lines.push(Line::from(body_line.to_string()));
|
||||||
|
}
|
||||||
|
lines.push(Line::from("")); // blank separator
|
||||||
|
}
|
||||||
|
|
||||||
|
let output = Paragraph::new(lines)
|
||||||
|
.wrap(Wrap { trim: false })
|
||||||
|
.scroll((state.scroll, 0));
|
||||||
|
frame.render_widget(output, chunks[0]);
|
||||||
|
|
||||||
|
// --- Input pane ---
|
||||||
|
let input_text = format!("> {}", state.input);
|
||||||
|
let input_widget = Paragraph::new(input_text).block(Block::bordered().title("Input"));
|
||||||
|
frame.render_widget(input_widget, chunks[1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run the TUI event loop.
|
||||||
|
///
|
||||||
|
/// This function owns the terminal for its entire lifetime. It initialises the
|
||||||
|
/// terminal, installs the panic hook, then spins in a ~60 fps loop:
|
||||||
|
///
|
||||||
|
/// ```text
|
||||||
|
/// loop:
|
||||||
|
/// 1. drain UIEvents (non-blocking try_recv)
|
||||||
|
/// 2. poll keyboard for up to 16 ms via EventStream (async, no blocking thread)
|
||||||
|
/// 3. handle key event -> Option<LoopControl>
|
||||||
|
/// 4. render frame (scroll updated inside draw closure)
|
||||||
|
/// 5. act on LoopControl: send message or break
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// On `Ctrl+C` / `Ctrl+D`: sends [`UserAction::Quit`], restores the terminal, and
|
||||||
|
/// returns `Ok(())`.
|
||||||
|
pub async fn run(
|
||||||
|
action_tx: mpsc::Sender<UserAction>,
|
||||||
|
mut event_rx: mpsc::Receiver<UIEvent>,
|
||||||
|
) -> Result<(), TuiError> {
|
||||||
|
install_panic_hook();
|
||||||
|
let mut terminal = init_terminal()?;
|
||||||
|
let mut state = AppState::new();
|
||||||
|
let mut event_stream = EventStream::new();
|
||||||
|
|
||||||
|
loop {
|
||||||
|
// 1. Drain pending UI events.
|
||||||
|
drain_ui_events(&mut event_rx, &mut state);
|
||||||
|
|
||||||
|
// 2. Poll keyboard for up to 16 ms. EventStream integrates with the
|
||||||
|
// Tokio runtime via futures::Stream, so no blocking thread is needed.
|
||||||
|
// Timeout expiry, stream end, non-key events, and I/O errors all map
|
||||||
|
// to None -- the frame is rendered regardless.
|
||||||
|
let key_event: Option<KeyEvent> =
|
||||||
|
match tokio::time::timeout(Duration::from_millis(16), event_stream.next()).await {
|
||||||
|
Ok(Some(Ok(Event::Key(k)))) => Some(k),
|
||||||
|
_ => None,
|
||||||
|
};
|
||||||
|
|
||||||
|
// 3. Handle key.
|
||||||
|
let control = handle_key(key_event, &mut state);
|
||||||
|
|
||||||
|
// 4. Render (scroll updated inside draw closure to use current frame area).
|
||||||
|
terminal.draw(|frame| {
|
||||||
|
update_scroll(&mut state, frame.area());
|
||||||
|
render(frame, &state);
|
||||||
|
})?;
|
||||||
|
|
||||||
|
// 5. Act on control signal after render so the user sees the submitted message.
|
||||||
|
match control {
|
||||||
|
Some(LoopControl::SendMessage(msg)) => {
|
||||||
|
if action_tx.send(UserAction::SendMessage(msg)).await.is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Some(LoopControl::Quit) => {
|
||||||
|
let _ = action_tx.send(UserAction::Quit).await;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
None => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
restore_terminal()?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use ratatui::backend::TestBackend;
|
||||||
|
|
||||||
|
fn make_key(code: KeyCode) -> Option<KeyEvent> {
|
||||||
|
Some(KeyEvent::new(code, KeyModifiers::empty()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ctrl_key(c: char) -> Option<KeyEvent> {
|
||||||
|
Some(KeyEvent::new(KeyCode::Char(c), KeyModifiers::CONTROL))
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- handle_key tests ---
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn handle_key_printable_appends() {
|
||||||
|
let mut state = AppState::new();
|
||||||
|
handle_key(make_key(KeyCode::Char('h')), &mut state);
|
||||||
|
assert_eq!(state.input, "h");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn handle_key_backspace_pops() {
|
||||||
|
let mut state = AppState::new();
|
||||||
|
state.input = "ab".to_string();
|
||||||
|
handle_key(make_key(KeyCode::Backspace), &mut state);
|
||||||
|
assert_eq!(state.input, "a");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn handle_key_backspace_empty_noop() {
|
||||||
|
let mut state = AppState::new();
|
||||||
|
handle_key(make_key(KeyCode::Backspace), &mut state);
|
||||||
|
assert_eq!(state.input, "");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn handle_key_enter_empty_noop() {
|
||||||
|
let mut state = AppState::new();
|
||||||
|
let result = handle_key(make_key(KeyCode::Enter), &mut state);
|
||||||
|
assert!(result.is_none());
|
||||||
|
assert!(state.messages.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn handle_key_enter_sends_and_clears() {
|
||||||
|
let mut state = AppState::new();
|
||||||
|
state.input = "hello".to_string();
|
||||||
|
let result = handle_key(make_key(KeyCode::Enter), &mut state);
|
||||||
|
assert!(state.input.is_empty());
|
||||||
|
assert_eq!(state.messages.len(), 1);
|
||||||
|
assert!(matches!(result, Some(LoopControl::SendMessage(ref m)) if m == "hello"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn handle_key_ctrl_c_quits() {
|
||||||
|
let mut state = AppState::new();
|
||||||
|
let result = handle_key(ctrl_key('c'), &mut state);
|
||||||
|
assert!(matches!(result, Some(LoopControl::Quit)));
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- drain_ui_events tests ---
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn drain_appends_to_existing_assistant() {
|
||||||
|
let (tx, mut rx) = tokio::sync::mpsc::channel(8);
|
||||||
|
let mut state = AppState::new();
|
||||||
|
state.messages.push((Role::Assistant, "hello".to_string()));
|
||||||
|
tx.send(UIEvent::StreamDelta(" world".to_string()))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
drop(tx);
|
||||||
|
drain_ui_events(&mut rx, &mut state);
|
||||||
|
assert_eq!(state.messages.last().unwrap().1, "hello world");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn drain_creates_assistant_on_user_last() {
|
||||||
|
let (tx, mut rx) = tokio::sync::mpsc::channel(8);
|
||||||
|
let mut state = AppState::new();
|
||||||
|
state.messages.push((Role::User, "hi".to_string()));
|
||||||
|
tx.send(UIEvent::StreamDelta("hello".to_string()))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
drop(tx);
|
||||||
|
drain_ui_events(&mut rx, &mut state);
|
||||||
|
assert_eq!(state.messages.len(), 2);
|
||||||
|
assert_eq!(state.messages[1].0, Role::Assistant);
|
||||||
|
assert_eq!(state.messages[1].1, "hello");
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- render tests ---
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn render_smoke_test() {
|
||||||
|
let backend = TestBackend::new(80, 24);
|
||||||
|
let mut terminal = Terminal::new(backend).unwrap();
|
||||||
|
let state = AppState::new();
|
||||||
|
terminal.draw(|frame| render(frame, &state)).unwrap();
|
||||||
|
// no panic is the assertion
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn render_shows_role_prefixes() {
|
||||||
|
let backend = TestBackend::new(80, 24);
|
||||||
|
let mut terminal = Terminal::new(backend).unwrap();
|
||||||
|
let mut state = AppState::new();
|
||||||
|
state.messages.push((Role::User, "hi".to_string()));
|
||||||
|
state.messages.push((Role::Assistant, "hello".to_string()));
|
||||||
|
terminal.draw(|frame| render(frame, &state)).unwrap();
|
||||||
|
let buf = terminal.backend().buffer().clone();
|
||||||
|
let all_text: String = buf
|
||||||
|
.content()
|
||||||
|
.iter()
|
||||||
|
.map(|c| c.symbol().to_string())
|
||||||
|
.collect();
|
||||||
|
assert!(
|
||||||
|
all_text.contains("You:"),
|
||||||
|
"expected 'You:' in buffer: {all_text:.100}"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
all_text.contains("Assistant:"),
|
||||||
|
"expected 'Assistant:' in buffer"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- update_scroll tests ---
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn update_scroll_zero_when_fits() {
|
||||||
|
let mut state = AppState::new();
|
||||||
|
state.messages.push((Role::User, "hello".to_string()));
|
||||||
|
let area = Rect::new(0, 0, 80, 24);
|
||||||
|
update_scroll(&mut state, area);
|
||||||
|
assert_eq!(state.scroll, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn update_scroll_positive_when_overflow() {
|
||||||
|
let mut state = AppState::new();
|
||||||
|
for i in 0..50 {
|
||||||
|
state.messages.push((Role::User, format!("message {i}")));
|
||||||
|
}
|
||||||
|
let area = Rect::new(0, 0, 80, 24);
|
||||||
|
update_scroll(&mut state, area);
|
||||||
|
assert!(state.scroll > 0, "expected scroll > 0 with 50 messages");
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Add table
Add a link
Reference in a new issue