Move tool executor behind tarpc. (#12)
Implement tool executor as a separate tarpc service to improve isolation and create sandboxing opportunities. Reviewed-on: #12 Co-authored-by: Drew Galbraith <drew@tiramisu.one> Co-committed-by: Drew Galbraith <drew@tiramisu.one>
This commit is contained in:
parent
7420755800
commit
312a5866f7
17 changed files with 465 additions and 476 deletions
178
Cargo.lock
generated
178
Cargo.lock
generated
|
|
@ -253,6 +253,21 @@ dependencies = [
|
||||||
"libc",
|
"libc",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "crossbeam-channel"
|
||||||
|
version = "0.5.15"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2"
|
||||||
|
dependencies = [
|
||||||
|
"crossbeam-utils",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "crossbeam-utils"
|
||||||
|
version = "0.8.21"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "crossterm"
|
name = "crossterm"
|
||||||
version = "0.29.0"
|
version = "0.29.0"
|
||||||
|
|
@ -696,13 +711,19 @@ dependencies = [
|
||||||
"futures-core",
|
"futures-core",
|
||||||
"futures-sink",
|
"futures-sink",
|
||||||
"http",
|
"http",
|
||||||
"indexmap",
|
"indexmap 2.13.0",
|
||||||
"slab",
|
"slab",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tokio-util",
|
"tokio-util",
|
||||||
"tracing",
|
"tracing",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "hashbrown"
|
||||||
|
version = "0.12.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "hashbrown"
|
name = "hashbrown"
|
||||||
version = "0.15.5"
|
version = "0.15.5"
|
||||||
|
|
@ -774,6 +795,12 @@ version = "1.10.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87"
|
checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "humantime"
|
||||||
|
version = "2.3.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "135b12329e5e3ce057a9f972339ea52bc954fe1e9358ef27f95e89716fbc5424"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "hyper"
|
name = "hyper"
|
||||||
version = "1.8.1"
|
version = "1.8.1"
|
||||||
|
|
@ -951,6 +978,16 @@ dependencies = [
|
||||||
"icu_properties",
|
"icu_properties",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "indexmap"
|
||||||
|
version = "1.9.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99"
|
||||||
|
dependencies = [
|
||||||
|
"autocfg",
|
||||||
|
"hashbrown 0.12.3",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "indexmap"
|
name = "indexmap"
|
||||||
version = "2.13.0"
|
version = "2.13.0"
|
||||||
|
|
@ -1304,6 +1341,49 @@ version = "0.2.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe"
|
checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "opentelemetry"
|
||||||
|
version = "0.18.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "69d6c3d7288a106c0a363e4b0e8d308058d56902adefb16f4936f417ffef086e"
|
||||||
|
dependencies = [
|
||||||
|
"opentelemetry_api",
|
||||||
|
"opentelemetry_sdk",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "opentelemetry_api"
|
||||||
|
version = "0.18.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "c24f96e21e7acc813c7a8394ee94978929db2bcc46cf6b5014fc612bf7760c22"
|
||||||
|
dependencies = [
|
||||||
|
"futures-channel",
|
||||||
|
"futures-util",
|
||||||
|
"indexmap 1.9.3",
|
||||||
|
"js-sys",
|
||||||
|
"once_cell",
|
||||||
|
"pin-project-lite",
|
||||||
|
"thiserror 1.0.69",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "opentelemetry_sdk"
|
||||||
|
version = "0.18.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "1ca41c4933371b61c2a2f214bf16931499af4ec90543604ec828f7a625c09113"
|
||||||
|
dependencies = [
|
||||||
|
"async-trait",
|
||||||
|
"crossbeam-channel",
|
||||||
|
"futures-channel",
|
||||||
|
"futures-executor",
|
||||||
|
"futures-util",
|
||||||
|
"once_cell",
|
||||||
|
"opentelemetry_api",
|
||||||
|
"percent-encoding",
|
||||||
|
"rand 0.8.5",
|
||||||
|
"thiserror 1.0.69",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ordered-float"
|
name = "ordered-float"
|
||||||
version = "4.6.0"
|
version = "4.6.0"
|
||||||
|
|
@ -1437,6 +1517,26 @@ dependencies = [
|
||||||
"siphasher",
|
"siphasher",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pin-project"
|
||||||
|
version = "1.1.11"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "f1749c7ed4bcaf4c3d0a3efc28538844fb29bcdd7d2b67b2be7e20ba861ff517"
|
||||||
|
dependencies = [
|
||||||
|
"pin-project-internal",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pin-project-internal"
|
||||||
|
version = "1.1.11"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d9b20ed30f105399776b9c883e68e536ef602a16ae6f596d2c473591d6ad64c6"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn 2.0.117",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pin-project-lite"
|
name = "pin-project-lite"
|
||||||
version = "0.2.16"
|
version = "0.2.16"
|
||||||
|
|
@ -1575,6 +1675,8 @@ version = "0.8.5"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
|
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"libc",
|
||||||
|
"rand_chacha 0.3.1",
|
||||||
"rand_core 0.6.4",
|
"rand_core 0.6.4",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -1584,10 +1686,20 @@ version = "0.9.2"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1"
|
checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"rand_chacha",
|
"rand_chacha 0.9.0",
|
||||||
"rand_core 0.9.5",
|
"rand_core 0.9.5",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rand_chacha"
|
||||||
|
version = "0.3.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
|
||||||
|
dependencies = [
|
||||||
|
"ppv-lite86",
|
||||||
|
"rand_core 0.6.4",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rand_chacha"
|
name = "rand_chacha"
|
||||||
version = "0.9.0"
|
version = "0.9.0"
|
||||||
|
|
@ -1603,6 +1715,9 @@ name = "rand_core"
|
||||||
version = "0.6.4"
|
version = "0.6.4"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
|
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
|
||||||
|
dependencies = [
|
||||||
|
"getrandom 0.2.17",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rand_core"
|
name = "rand_core"
|
||||||
|
|
@ -2087,6 +2202,7 @@ dependencies = [
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"similar",
|
"similar",
|
||||||
|
"tarpc",
|
||||||
"tempfile",
|
"tempfile",
|
||||||
"thiserror 2.0.18",
|
"thiserror 2.0.18",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
|
@ -2224,6 +2340,39 @@ dependencies = [
|
||||||
"libc",
|
"libc",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tarpc"
|
||||||
|
version = "0.34.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "93a1870169fb9490fb3b37df7f50782986475c33cb90955f9f9b9ae659124200"
|
||||||
|
dependencies = [
|
||||||
|
"anyhow",
|
||||||
|
"fnv",
|
||||||
|
"futures",
|
||||||
|
"humantime",
|
||||||
|
"opentelemetry",
|
||||||
|
"pin-project",
|
||||||
|
"rand 0.8.5",
|
||||||
|
"static_assertions",
|
||||||
|
"tarpc-plugins",
|
||||||
|
"thiserror 1.0.69",
|
||||||
|
"tokio",
|
||||||
|
"tokio-util",
|
||||||
|
"tracing",
|
||||||
|
"tracing-opentelemetry",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tarpc-plugins"
|
||||||
|
version = "0.13.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "ad8302bea2fb8a2b01b025d23414b0b4ed32a783b95e5d818c3320a8bc4baada"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn 1.0.109",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tempfile"
|
name = "tempfile"
|
||||||
version = "3.26.0"
|
version = "3.26.0"
|
||||||
|
|
@ -2443,6 +2592,7 @@ dependencies = [
|
||||||
"futures-core",
|
"futures-core",
|
||||||
"futures-sink",
|
"futures-sink",
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
|
"slab",
|
||||||
"tokio",
|
"tokio",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -2497,6 +2647,7 @@ version = "0.1.44"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100"
|
checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"log",
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
"tracing-attributes",
|
"tracing-attributes",
|
||||||
"tracing-core",
|
"tracing-core",
|
||||||
|
|
@ -2534,6 +2685,19 @@ dependencies = [
|
||||||
"tracing-core",
|
"tracing-core",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tracing-opentelemetry"
|
||||||
|
version = "0.18.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "21ebb87a95ea13271332df069020513ab70bdb5637ca42d6e492dc3bbbad48de"
|
||||||
|
dependencies = [
|
||||||
|
"once_cell",
|
||||||
|
"opentelemetry",
|
||||||
|
"tracing",
|
||||||
|
"tracing-core",
|
||||||
|
"tracing-subscriber",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tracing-subscriber"
|
name = "tracing-subscriber"
|
||||||
version = "0.3.22"
|
version = "0.3.22"
|
||||||
|
|
@ -2787,7 +2951,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "bb0e353e6a2fbdc176932bbaab493762eb1255a7900fe0fea1a2f96c296cc909"
|
checksum = "bb0e353e6a2fbdc176932bbaab493762eb1255a7900fe0fea1a2f96c296cc909"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"indexmap",
|
"indexmap 2.13.0",
|
||||||
"wasm-encoder",
|
"wasm-encoder",
|
||||||
"wasmparser",
|
"wasmparser",
|
||||||
]
|
]
|
||||||
|
|
@ -2813,7 +2977,7 @@ checksum = "47b807c72e1bac69382b3a6fb3dbe8ea4c0ed87ff5629b8685ae6b9a611028fe"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"bitflags 2.11.0",
|
"bitflags 2.11.0",
|
||||||
"hashbrown 0.15.5",
|
"hashbrown 0.15.5",
|
||||||
"indexmap",
|
"indexmap 2.13.0",
|
||||||
"semver",
|
"semver",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -3234,7 +3398,7 @@ checksum = "b7c566e0f4b284dd6561c786d9cb0142da491f46a9fbed79ea69cdad5db17f21"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"heck",
|
"heck",
|
||||||
"indexmap",
|
"indexmap 2.13.0",
|
||||||
"prettyplease",
|
"prettyplease",
|
||||||
"syn 2.0.117",
|
"syn 2.0.117",
|
||||||
"wasm-metadata",
|
"wasm-metadata",
|
||||||
|
|
@ -3265,7 +3429,7 @@ checksum = "9d66ea20e9553b30172b5e831994e35fbde2d165325bec84fc43dbf6f4eb9cb2"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"bitflags 2.11.0",
|
"bitflags 2.11.0",
|
||||||
"indexmap",
|
"indexmap 2.13.0",
|
||||||
"log",
|
"log",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_derive",
|
"serde_derive",
|
||||||
|
|
@ -3284,7 +3448,7 @@ checksum = "ecc8ac4bc1dc3381b7f59c34f00b67e18f910c2c0f50015669dde7def656a736"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"id-arena",
|
"id-arena",
|
||||||
"indexmap",
|
"indexmap 2.13.0",
|
||||||
"log",
|
"log",
|
||||||
"semver",
|
"semver",
|
||||||
"serde",
|
"serde",
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ futures = "0.3"
|
||||||
async-trait = "0.1"
|
async-trait = "0.1"
|
||||||
similar = "2"
|
similar = "2"
|
||||||
landlock = "0.4"
|
landlock = "0.4"
|
||||||
|
tarpc = { version = "0.34", features = ["tokio1"] }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
tempfile = "3.26.0"
|
tempfile = "3.26.0"
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,7 @@ use tokio::sync::mpsc;
|
||||||
|
|
||||||
use crate::core::orchestrator::Orchestrator;
|
use crate::core::orchestrator::Orchestrator;
|
||||||
use crate::core::types::{StampedEvent, UserAction};
|
use crate::core::types::{StampedEvent, UserAction};
|
||||||
|
use crate::executor::{ExecutorServer, spawn_local};
|
||||||
use crate::provider::ClaudeProvider;
|
use crate::provider::ClaudeProvider;
|
||||||
use crate::sandbox::policy::SandboxPolicy;
|
use crate::sandbox::policy::SandboxPolicy;
|
||||||
use crate::sandbox::{EnforcementMode, Sandbox};
|
use crate::sandbox::{EnforcementMode, Sandbox};
|
||||||
|
|
@ -87,9 +88,10 @@ pub async fn run(project_dir: &Path, yolo: bool) -> anyhow::Result<()> {
|
||||||
let (action_tx, action_rx) = mpsc::channel::<UserAction>(CHANNEL_CAP);
|
let (action_tx, action_rx) = mpsc::channel::<UserAction>(CHANNEL_CAP);
|
||||||
let (event_tx, event_rx) = mpsc::channel::<StampedEvent>(CHANNEL_CAP);
|
let (event_tx, event_rx) = mpsc::channel::<StampedEvent>(CHANNEL_CAP);
|
||||||
|
|
||||||
// -- Tools & Orchestrator (background task) ------------------------------------
|
// -- Executor & Orchestrator (background task) ---------------------------------
|
||||||
let tool_registry = ToolRegistry::default_tools();
|
let registry = ToolRegistry::default_tools();
|
||||||
let orch = Orchestrator::new(provider, tool_registry, sandbox, action_rx, event_tx);
|
let executor = spawn_local(ExecutorServer::new(registry, sandbox));
|
||||||
|
let orch = Orchestrator::new(provider, executor, action_rx, event_tx).await;
|
||||||
tokio::spawn(orch.run());
|
tokio::spawn(orch.run());
|
||||||
|
|
||||||
// -- TUI (foreground task) ----------------------------------------------------
|
// -- TUI (foreground task) ----------------------------------------------------
|
||||||
|
|
|
||||||
|
|
@ -7,9 +7,9 @@ use crate::core::types::{
|
||||||
ContentBlock, ConversationMessage, Role, StampedEvent, StreamEvent, ToolDefinition,
|
ContentBlock, ConversationMessage, Role, StampedEvent, StreamEvent, ToolDefinition,
|
||||||
ToolDisplay, UIEvent, UserAction,
|
ToolDisplay, UIEvent, UserAction,
|
||||||
};
|
};
|
||||||
|
use crate::executor::ToolExecutorClient;
|
||||||
use crate::provider::ModelProvider;
|
use crate::provider::ModelProvider;
|
||||||
use crate::sandbox::Sandbox;
|
use crate::tools::ToolOutput;
|
||||||
use crate::tools::{RiskLevel, ToolOutput, ToolRegistry};
|
|
||||||
|
|
||||||
/// Accumulates data for a single tool-use block while it is being streamed.
|
/// Accumulates data for a single tool-use block while it is being streamed.
|
||||||
///
|
///
|
||||||
|
|
@ -64,6 +64,9 @@ impl TryFrom<ActiveToolUse> for ContentBlock {
|
||||||
/// |
|
/// |
|
||||||
/// v
|
/// v
|
||||||
/// ModelProvider (SSE stream)
|
/// ModelProvider (SSE stream)
|
||||||
|
/// |
|
||||||
|
/// v
|
||||||
|
/// ToolExecutorClient (in-process tarpc)
|
||||||
/// ```
|
/// ```
|
||||||
///
|
///
|
||||||
/// # Event loop
|
/// # Event loop
|
||||||
|
|
@ -107,35 +110,39 @@ fn truncate(s: &str, max_len: usize) -> String {
|
||||||
pub struct Orchestrator<P> {
|
pub struct Orchestrator<P> {
|
||||||
history: ConversationHistory,
|
history: ConversationHistory,
|
||||||
provider: P,
|
provider: P,
|
||||||
tool_registry: ToolRegistry,
|
/// tarpc client to the tool executor (in-process channel).
|
||||||
sandbox: Sandbox,
|
executor: ToolExecutorClient,
|
||||||
|
/// Cached tool definitions fetched from the executor at construction.
|
||||||
|
tool_definitions: Vec<ToolDefinition>,
|
||||||
action_rx: mpsc::Receiver<UserAction>,
|
action_rx: mpsc::Receiver<UserAction>,
|
||||||
event_tx: mpsc::Sender<StampedEvent>,
|
event_tx: mpsc::Sender<StampedEvent>,
|
||||||
/// Messages typed by the user while an approval prompt is open. They are
|
|
||||||
/// queued here and replayed as new turns once the current turn completes.
|
|
||||||
queued_messages: Vec<String>,
|
|
||||||
/// Monotonic epoch incremented on each `:clear`. Events are tagged with
|
/// Monotonic epoch incremented on each `:clear`. Events are tagged with
|
||||||
/// this value so the TUI can discard stale in-flight messages.
|
/// this value so the TUI can discard stale in-flight messages.
|
||||||
epoch: u64,
|
epoch: u64,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<P: ModelProvider> Orchestrator<P> {
|
impl<P: ModelProvider> Orchestrator<P> {
|
||||||
/// Construct an orchestrator using the given provider and channel endpoints.
|
/// Construct an orchestrator using the given provider and executor client.
|
||||||
pub fn new(
|
///
|
||||||
|
/// Calls `executor.list_tools` once at startup to populate the cached
|
||||||
|
/// `tool_definitions` that are passed to the model on every turn. If the RPC fails, the error is discarded (`unwrap_or_default`) and the cached list is left empty.
|
||||||
|
pub async fn new(
|
||||||
provider: P,
|
provider: P,
|
||||||
tool_registry: ToolRegistry,
|
executor: ToolExecutorClient,
|
||||||
sandbox: Sandbox,
|
|
||||||
action_rx: mpsc::Receiver<UserAction>,
|
action_rx: mpsc::Receiver<UserAction>,
|
||||||
event_tx: mpsc::Sender<StampedEvent>,
|
event_tx: mpsc::Sender<StampedEvent>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
|
let tool_definitions = executor
|
||||||
|
.list_tools(tarpc::context::current())
|
||||||
|
.await
|
||||||
|
.unwrap_or_default();
|
||||||
Self {
|
Self {
|
||||||
history: ConversationHistory::new(),
|
history: ConversationHistory::new(),
|
||||||
provider,
|
provider,
|
||||||
tool_registry,
|
executor,
|
||||||
sandbox,
|
tool_definitions,
|
||||||
action_rx,
|
action_rx,
|
||||||
event_tx,
|
event_tx,
|
||||||
queued_messages: Vec::new(),
|
|
||||||
epoch: 0,
|
epoch: 0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -241,11 +248,9 @@ impl<P: ModelProvider> Orchestrator<P> {
|
||||||
/// loop back until the model produces a text-only response or we hit the
|
/// loop back until the model produces a text-only response or we hit the
|
||||||
/// iteration limit.
|
/// iteration limit.
|
||||||
async fn run_turn(&mut self) {
|
async fn run_turn(&mut self) {
|
||||||
let tool_defs = self.tool_registry.definitions();
|
|
||||||
|
|
||||||
for _ in 0..MAX_TOOL_ITERATIONS {
|
for _ in 0..MAX_TOOL_ITERATIONS {
|
||||||
let messages = self.history.messages().to_vec();
|
let messages = self.history.messages().to_vec();
|
||||||
let result = self.consume_stream(&messages, &tool_defs).await;
|
let result = self.consume_stream(&messages, &self.tool_definitions).await;
|
||||||
|
|
||||||
match result {
|
match result {
|
||||||
StreamResult::Error(msg) => {
|
StreamResult::Error(msg) => {
|
||||||
|
|
@ -273,12 +278,11 @@ impl<P: ModelProvider> Orchestrator<P> {
|
||||||
// mutability (Arc<Mutex<...>> around event_tx, action_rx,
|
// mutability (Arc<Mutex<...>> around event_tx, action_rx,
|
||||||
// history) so that multiple futures can borrow self
|
// history) so that multiple futures can borrow self
|
||||||
// simultaneously via futures::future::join_all.
|
// simultaneously via futures::future::join_all.
|
||||||
// Execute each tool-use block and collect results.
|
|
||||||
let mut tool_results: Vec<ContentBlock> = Vec::new();
|
let mut tool_results: Vec<ContentBlock> = Vec::new();
|
||||||
|
|
||||||
for block in &blocks {
|
for block in &blocks {
|
||||||
if let ContentBlock::ToolUse { id, name, input } = block {
|
if let ContentBlock::ToolUse { id, name, input } = block {
|
||||||
let result = self.execute_tool_with_approval(id, name, input).await;
|
let result = self.execute_tool(id, name, input).await;
|
||||||
tool_results.push(ContentBlock::ToolResult {
|
tool_results.push(ContentBlock::ToolResult {
|
||||||
tool_use_id: id.clone(),
|
tool_use_id: id.clone(),
|
||||||
content: result.content,
|
content: result.content,
|
||||||
|
|
@ -292,7 +296,6 @@ impl<P: ModelProvider> Orchestrator<P> {
|
||||||
role: Role::User,
|
role: Role::User,
|
||||||
content: tool_results,
|
content: tool_results,
|
||||||
});
|
});
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -307,16 +310,17 @@ impl<P: ModelProvider> Orchestrator<P> {
|
||||||
///
|
///
|
||||||
/// Matches on known tool names to extract structured fields; falls back to
|
/// Matches on known tool names to extract structured fields; falls back to
|
||||||
/// `Generic` with a JSON summary for anything else.
|
/// `Generic` with a JSON summary for anything else.
|
||||||
async fn build_tool_display(&self, tool_name: &str, input: &serde_json::Value) -> ToolDisplay {
|
fn build_tool_display(&self, tool_name: &str, input: &serde_json::Value) -> ToolDisplay {
|
||||||
match tool_name {
|
match tool_name {
|
||||||
"write_file" => {
|
"write_file" => {
|
||||||
let path = input["path"].as_str().unwrap_or("<unknown>").to_string();
|
let path = input["path"].as_str().unwrap_or("<unknown>").to_string();
|
||||||
let new_content = input["content"].as_str().unwrap_or("").to_string();
|
let new_content = input["content"].as_str().unwrap_or("").to_string();
|
||||||
// Try to read existing content for diffing.
|
// old_content is not available here (sandbox lives in the executor);
|
||||||
let old_content = self.sandbox.read_file(&path).await.ok();
|
// the diff is omitted at the display stage. The executor still
|
||||||
|
// reads and writes through the sandbox as normal.
|
||||||
ToolDisplay::WriteFile {
|
ToolDisplay::WriteFile {
|
||||||
path,
|
path,
|
||||||
old_content,
|
old_content: None,
|
||||||
new_content,
|
new_content,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -349,9 +353,6 @@ impl<P: ModelProvider> Orchestrator<P> {
|
||||||
"write_file" => {
|
"write_file" => {
|
||||||
let path = input["path"].as_str().unwrap_or("<unknown>").to_string();
|
let path = input["path"].as_str().unwrap_or("<unknown>").to_string();
|
||||||
let new_content = input["content"].as_str().unwrap_or("").to_string();
|
let new_content = input["content"].as_str().unwrap_or("").to_string();
|
||||||
// For results, old_content isn't available post-write. We already
|
|
||||||
// showed the diff at approval/executing time; the result just
|
|
||||||
// confirms the byte count via the display formatter.
|
|
||||||
ToolDisplay::WriteFile {
|
ToolDisplay::WriteFile {
|
||||||
path,
|
path,
|
||||||
old_content: None,
|
old_content: None,
|
||||||
|
|
@ -383,65 +384,33 @@ impl<P: ModelProvider> Orchestrator<P> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Execute a single tool, handling approval if needed.
|
/// Execute a single tool via the executor client, emitting `ToolExecuting`
|
||||||
///
|
/// before the call and `ToolResult` after.
|
||||||
/// For auto-approve tools, executes immediately. For tools requiring
|
async fn execute_tool(
|
||||||
/// approval, sends a request to the TUI and waits for a response.
|
&self,
|
||||||
async fn execute_tool_with_approval(
|
|
||||||
&mut self,
|
|
||||||
tool_use_id: &str,
|
tool_use_id: &str,
|
||||||
tool_name: &str,
|
tool_name: &str,
|
||||||
input: &serde_json::Value,
|
input: &serde_json::Value,
|
||||||
) -> ToolOutput {
|
) -> ToolOutput {
|
||||||
// Extract tool info upfront to avoid holding a borrow on self across
|
let display = self.build_tool_display(tool_name, input);
|
||||||
// the mutable wait_for_approval call.
|
self.send(UIEvent::ToolExecuting {
|
||||||
let risk = match self.tool_registry.get(tool_name) {
|
tool_use_id: tool_use_id.to_string(),
|
||||||
Some(t) => t.risk_level(),
|
tool_name: tool_name.to_string(),
|
||||||
None => {
|
display,
|
||||||
return ToolOutput {
|
})
|
||||||
content: format!("unknown tool: {tool_name}"),
|
.await;
|
||||||
is_error: true,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let display = self.build_tool_display(tool_name, input).await;
|
let rpc_result = self
|
||||||
|
.executor
|
||||||
|
.call_tool(
|
||||||
|
tarpc::context::current(),
|
||||||
|
tool_name.to_string(),
|
||||||
|
input.clone(),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
// Check approval.
|
match rpc_result {
|
||||||
let approved = match risk {
|
Ok(Ok(output)) => {
|
||||||
RiskLevel::AutoApprove => {
|
|
||||||
self.send(UIEvent::ToolExecuting {
|
|
||||||
tool_use_id: tool_use_id.to_string(),
|
|
||||||
tool_name: tool_name.to_string(),
|
|
||||||
display,
|
|
||||||
})
|
|
||||||
.await;
|
|
||||||
true
|
|
||||||
}
|
|
||||||
RiskLevel::RequiresApproval => {
|
|
||||||
self.send(UIEvent::ToolApprovalRequest {
|
|
||||||
tool_use_id: tool_use_id.to_string(),
|
|
||||||
tool_name: tool_name.to_string(),
|
|
||||||
display,
|
|
||||||
})
|
|
||||||
.await;
|
|
||||||
|
|
||||||
// Wait for approval response from TUI.
|
|
||||||
self.wait_for_approval(tool_use_id).await
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
if !approved {
|
|
||||||
return ToolOutput {
|
|
||||||
content: "tool execution denied by user".to_string(),
|
|
||||||
is_error: true,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Re-fetch tool for execution (borrow was released above).
|
|
||||||
let tool = self.tool_registry.get(tool_name).unwrap();
|
|
||||||
match tool.execute(input, &self.sandbox).await {
|
|
||||||
Ok(output) => {
|
|
||||||
let result_display = self.build_result_display(tool_name, input, &output.content);
|
let result_display = self.build_result_display(tool_name, input, &output.content);
|
||||||
self.send(UIEvent::ToolResult {
|
self.send(UIEvent::ToolResult {
|
||||||
tool_use_id: tool_use_id.to_string(),
|
tool_use_id: tool_use_id.to_string(),
|
||||||
|
|
@ -452,8 +421,23 @@ impl<P: ModelProvider> Orchestrator<P> {
|
||||||
.await;
|
.await;
|
||||||
output
|
output
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Ok(Err(msg)) => {
|
||||||
let msg = e.to_string();
|
self.send(UIEvent::ToolResult {
|
||||||
|
tool_use_id: tool_use_id.to_string(),
|
||||||
|
tool_name: tool_name.to_string(),
|
||||||
|
display: ToolDisplay::Generic {
|
||||||
|
summary: msg.clone(),
|
||||||
|
},
|
||||||
|
is_error: true,
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
ToolOutput {
|
||||||
|
content: msg,
|
||||||
|
is_error: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(rpc_err) => {
|
||||||
|
let msg = format!("executor RPC error: {rpc_err}");
|
||||||
self.send(UIEvent::ToolResult {
|
self.send(UIEvent::ToolResult {
|
||||||
tool_use_id: tool_use_id.to_string(),
|
tool_use_id: tool_use_id.to_string(),
|
||||||
tool_name: tool_name.to_string(),
|
tool_name: tool_name.to_string(),
|
||||||
|
|
@ -471,40 +455,6 @@ impl<P: ModelProvider> Orchestrator<P> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Wait for a `ToolApprovalResponse` matching `tool_use_id` from the TUI.
|
|
||||||
///
|
|
||||||
/// Any `SendMessage` actions that arrive during the wait are pushed onto
|
|
||||||
/// `self.queued_messages` rather than discarded. They are replayed as new
|
|
||||||
/// turns once the current tool-use loop completes (see `run()`).
|
|
||||||
///
|
|
||||||
/// Returns `true` if approved, `false` if denied or the channel closes.
|
|
||||||
async fn wait_for_approval(&mut self, tool_use_id: &str) -> bool {
|
|
||||||
while let Some(action) = self.action_rx.recv().await {
|
|
||||||
match action {
|
|
||||||
UserAction::ToolApprovalResponse {
|
|
||||||
tool_use_id: id,
|
|
||||||
approved,
|
|
||||||
} if id == tool_use_id => {
|
|
||||||
return approved;
|
|
||||||
}
|
|
||||||
UserAction::Quit => return false,
|
|
||||||
UserAction::SendMessage(text) => {
|
|
||||||
self.queued_messages.push(text);
|
|
||||||
}
|
|
||||||
UserAction::ClearHistory { epoch } => {
|
|
||||||
// Keep epoch in sync even while blocked on an approval
|
|
||||||
// prompt. Without this, events emitted after the approval
|
|
||||||
// resolves would carry the pre-clear epoch and be silently
|
|
||||||
// discarded by the TUI.
|
|
||||||
self.epoch = epoch;
|
|
||||||
self.history.clear();
|
|
||||||
}
|
|
||||||
_ => {} // discard stale approvals
|
|
||||||
}
|
|
||||||
}
|
|
||||||
false
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Run the orchestrator until the user quits or the `action_rx` channel closes.
|
/// Run the orchestrator until the user quits or the `action_rx` channel closes.
|
||||||
pub async fn run(mut self) {
|
pub async fn run(mut self) {
|
||||||
while let Some(action) = self.action_rx.recv().await {
|
while let Some(action) = self.action_rx.recv().await {
|
||||||
|
|
@ -514,31 +464,16 @@ impl<P: ModelProvider> Orchestrator<P> {
|
||||||
self.epoch = epoch;
|
self.epoch = epoch;
|
||||||
self.history.clear();
|
self.history.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Approval responses are handled inline during tool execution,
|
|
||||||
// not in the main action loop. If one arrives here it's stale.
|
|
||||||
UserAction::ToolApprovalResponse { .. } => {}
|
|
||||||
|
|
||||||
UserAction::SetNetworkPolicy(allowed) => {
|
UserAction::SetNetworkPolicy(allowed) => {
|
||||||
self.sandbox.set_network_allowed(allowed);
|
// The sandbox lives inside the executor; there is currently
|
||||||
|
// no RPC to change its policy at runtime. The event is still
|
||||||
|
// forwarded to the TUI for status-bar display purposes, so the displayed policy may not match what the executor's sandbox actually enforces.
|
||||||
self.send(UIEvent::NetworkPolicyChanged(allowed)).await;
|
self.send(UIEvent::NetworkPolicyChanged(allowed)).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
UserAction::SendMessage(text) => {
|
UserAction::SendMessage(text) => {
|
||||||
self.history
|
self.history
|
||||||
.push(ConversationMessage::text(Role::User, text));
|
.push(ConversationMessage::text(Role::User, text));
|
||||||
self.run_turn().await;
|
self.run_turn().await;
|
||||||
|
|
||||||
// Drain any messages queued while an approval prompt was
|
|
||||||
// open. Each queued message is a full turn in sequence.
|
|
||||||
while !self.queued_messages.is_empty() {
|
|
||||||
let queued = std::mem::take(&mut self.queued_messages);
|
|
||||||
for msg in queued {
|
|
||||||
self.history
|
|
||||||
.push(ConversationMessage::text(Role::User, msg));
|
|
||||||
self.run_turn().await;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -549,6 +484,7 @@ impl<P: ModelProvider> Orchestrator<P> {
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::core::types::ToolDefinition;
|
use crate::core::types::ToolDefinition;
|
||||||
|
use crate::executor::{ExecutorServer, spawn_local};
|
||||||
use crate::tools::ToolRegistry;
|
use crate::tools::ToolRegistry;
|
||||||
use futures::Stream;
|
use futures::Stream;
|
||||||
use tokio::sync::mpsc;
|
use tokio::sync::mpsc;
|
||||||
|
|
@ -576,8 +512,8 @@ mod tests {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create an Orchestrator with no tools for testing text-only flows.
|
/// Create an `Orchestrator` with an empty tool registry for testing text-only flows.
|
||||||
fn test_orchestrator<P: ModelProvider>(
|
async fn test_orchestrator<P: ModelProvider>(
|
||||||
provider: P,
|
provider: P,
|
||||||
action_rx: mpsc::Receiver<UserAction>,
|
action_rx: mpsc::Receiver<UserAction>,
|
||||||
event_tx: mpsc::Sender<StampedEvent>,
|
event_tx: mpsc::Sender<StampedEvent>,
|
||||||
|
|
@ -589,13 +525,8 @@ mod tests {
|
||||||
std::path::PathBuf::from("/tmp"),
|
std::path::PathBuf::from("/tmp"),
|
||||||
EnforcementMode::Yolo,
|
EnforcementMode::Yolo,
|
||||||
);
|
);
|
||||||
Orchestrator::new(
|
let executor = spawn_local(ExecutorServer::new(ToolRegistry::empty(), sandbox));
|
||||||
provider,
|
Orchestrator::new(provider, executor, action_rx, event_tx).await
|
||||||
ToolRegistry::empty(),
|
|
||||||
sandbox,
|
|
||||||
action_rx,
|
|
||||||
event_tx,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Collect all UIEvents that arrive within one orchestrator turn, stopping
|
/// Collect all UIEvents that arrive within one orchestrator turn, stopping
|
||||||
|
|
@ -633,7 +564,7 @@ mod tests {
|
||||||
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
||||||
let (event_tx, mut event_rx) = mpsc::channel::<StampedEvent>(16);
|
let (event_tx, mut event_rx) = mpsc::channel::<StampedEvent>(16);
|
||||||
|
|
||||||
let orch = test_orchestrator(provider, action_rx, event_tx);
|
let orch = test_orchestrator(provider, action_rx, event_tx).await;
|
||||||
let handle = tokio::spawn(orch.run());
|
let handle = tokio::spawn(orch.run());
|
||||||
|
|
||||||
action_tx
|
action_tx
|
||||||
|
|
@ -652,7 +583,7 @@ mod tests {
|
||||||
assert!(matches!(&events[1], UIEvent::StreamDelta(s) if s == ", world!"));
|
assert!(matches!(&events[1], UIEvent::StreamDelta(s) if s == ", world!"));
|
||||||
assert!(matches!(events[2], UIEvent::TurnComplete));
|
assert!(matches!(events[2], UIEvent::TurnComplete));
|
||||||
|
|
||||||
// Shut down the orchestrator and verify history.
|
// Shut down the orchestrator.
|
||||||
action_tx.send(UserAction::Quit).await.unwrap();
|
action_tx.send(UserAction::Quit).await.unwrap();
|
||||||
handle.await.unwrap();
|
handle.await.unwrap();
|
||||||
}
|
}
|
||||||
|
|
@ -671,7 +602,7 @@ mod tests {
|
||||||
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
||||||
let (event_tx, mut event_rx) = mpsc::channel::<StampedEvent>(16);
|
let (event_tx, mut event_rx) = mpsc::channel::<StampedEvent>(16);
|
||||||
|
|
||||||
let orch = test_orchestrator(provider, action_rx, event_tx);
|
let orch = test_orchestrator(provider, action_rx, event_tx).await;
|
||||||
let handle = tokio::spawn(orch.run());
|
let handle = tokio::spawn(orch.run());
|
||||||
|
|
||||||
action_tx
|
action_tx
|
||||||
|
|
@ -713,7 +644,7 @@ mod tests {
|
||||||
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
||||||
let (event_tx, _event_rx) = mpsc::channel::<StampedEvent>(8);
|
let (event_tx, _event_rx) = mpsc::channel::<StampedEvent>(8);
|
||||||
|
|
||||||
let orch = test_orchestrator(NeverCalledProvider, action_rx, event_tx);
|
let orch = test_orchestrator(NeverCalledProvider, action_rx, event_tx).await;
|
||||||
let handle = tokio::spawn(orch.run());
|
let handle = tokio::spawn(orch.run());
|
||||||
|
|
||||||
action_tx.send(UserAction::Quit).await.unwrap();
|
action_tx.send(UserAction::Quit).await.unwrap();
|
||||||
|
|
@ -724,12 +655,8 @@ mod tests {
|
||||||
|
|
||||||
/// Two sequential SendMessage turns each append a user message and the
|
/// Two sequential SendMessage turns each append a user message and the
|
||||||
/// accumulated assistant response, leaving four messages in history order.
|
/// accumulated assistant response, leaving four messages in history order.
|
||||||
///
|
|
||||||
/// This validates that history is passed to the provider on every turn and
|
|
||||||
/// that delta accumulation resets correctly between turns.
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn two_turns_accumulate_history_correctly() {
|
async fn two_turns_accumulate_history_correctly() {
|
||||||
// Both turns produce the same simple response for simplicity.
|
|
||||||
let make_turn_events = || {
|
let make_turn_events = || {
|
||||||
vec![
|
vec![
|
||||||
StreamEvent::TextDelta("reply".to_string()),
|
StreamEvent::TextDelta("reply".to_string()),
|
||||||
|
|
@ -737,8 +664,6 @@ mod tests {
|
||||||
]
|
]
|
||||||
};
|
};
|
||||||
|
|
||||||
// We need to serve two different turns from the same provider.
|
|
||||||
// Use an `Arc<Mutex<VecDeque>>` so the provider can pop event sets.
|
|
||||||
use std::collections::VecDeque;
|
use std::collections::VecDeque;
|
||||||
use std::sync::{Arc, Mutex};
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
|
|
@ -766,7 +691,7 @@ mod tests {
|
||||||
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
||||||
let (event_tx, mut event_rx) = mpsc::channel::<StampedEvent>(32);
|
let (event_tx, mut event_rx) = mpsc::channel::<StampedEvent>(32);
|
||||||
|
|
||||||
let orch = test_orchestrator(provider, action_rx, event_tx);
|
let orch = test_orchestrator(provider, action_rx, event_tx).await;
|
||||||
let handle = tokio::spawn(orch.run());
|
let handle = tokio::spawn(orch.run());
|
||||||
|
|
||||||
// First turn.
|
// First turn.
|
||||||
|
|
@ -794,8 +719,8 @@ mod tests {
|
||||||
// -- tool-use accumulation ----------------------------------------------------
|
// -- tool-use accumulation ----------------------------------------------------
|
||||||
|
|
||||||
/// When the provider emits tool-use events, the orchestrator executes the
|
/// When the provider emits tool-use events, the orchestrator executes the
|
||||||
/// tool (auto-approve since read_file is AutoApprove), feeds the result back,
|
/// tool via the executor, feeds the result back, and the provider's second
|
||||||
/// and the provider's second call returns text.
|
/// call returns text.
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn tool_use_loop_executes_and_feeds_back() {
|
async fn tool_use_loop_executes_and_feeds_back() {
|
||||||
use std::collections::VecDeque;
|
use std::collections::VecDeque;
|
||||||
|
|
@ -839,6 +764,7 @@ mod tests {
|
||||||
let (event_tx, mut event_rx) = mpsc::channel::<StampedEvent>(32);
|
let (event_tx, mut event_rx) = mpsc::channel::<StampedEvent>(32);
|
||||||
|
|
||||||
// Use a real ToolRegistry so read_file works.
|
// Use a real ToolRegistry so read_file works.
|
||||||
|
use crate::executor::{ExecutorServer, spawn_local};
|
||||||
use crate::sandbox::policy::SandboxPolicy;
|
use crate::sandbox::policy::SandboxPolicy;
|
||||||
use crate::sandbox::{EnforcementMode, Sandbox};
|
use crate::sandbox::{EnforcementMode, Sandbox};
|
||||||
let dir = tempfile::TempDir::new().unwrap();
|
let dir = tempfile::TempDir::new().unwrap();
|
||||||
|
|
@ -848,13 +774,8 @@ mod tests {
|
||||||
dir.path().to_path_buf(),
|
dir.path().to_path_buf(),
|
||||||
EnforcementMode::Yolo,
|
EnforcementMode::Yolo,
|
||||||
);
|
);
|
||||||
let orch = Orchestrator::new(
|
let executor = spawn_local(ExecutorServer::new(ToolRegistry::default_tools(), sandbox));
|
||||||
MultiCallMock { turns },
|
let orch = Orchestrator::new(MultiCallMock { turns }, executor, action_rx, event_tx).await;
|
||||||
ToolRegistry::default_tools(),
|
|
||||||
sandbox,
|
|
||||||
action_rx,
|
|
||||||
event_tx,
|
|
||||||
);
|
|
||||||
let handle = tokio::spawn(orch.run());
|
let handle = tokio::spawn(orch.run());
|
||||||
|
|
||||||
action_tx
|
action_tx
|
||||||
|
|
@ -925,7 +846,7 @@ mod tests {
|
||||||
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
||||||
let (event_tx, mut event_rx) = mpsc::channel::<StampedEvent>(16);
|
let (event_tx, mut event_rx) = mpsc::channel::<StampedEvent>(16);
|
||||||
|
|
||||||
let orch = test_orchestrator(provider, action_rx, event_tx);
|
let orch = test_orchestrator(provider, action_rx, event_tx).await;
|
||||||
let handle = tokio::spawn(orch.run());
|
let handle = tokio::spawn(orch.run());
|
||||||
|
|
||||||
action_tx
|
action_tx
|
||||||
|
|
@ -946,116 +867,11 @@ mod tests {
|
||||||
handle.await.unwrap();
|
handle.await.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
// -- queued messages ----------------------------------------------------------
|
|
||||||
|
|
||||||
/// A SendMessage sent while an approval prompt is open must be processed
|
|
||||||
/// after the current turn completes, not silently dropped.
|
|
||||||
///
|
|
||||||
/// This test uses a RequiresApproval tool (write_file) so that the
|
|
||||||
/// orchestrator blocks in wait_for_approval. While blocked, we send a second
|
|
||||||
/// message. After approving (denied here for simplicity), the queued message
|
|
||||||
/// should still be processed.
|
|
||||||
#[tokio::test]
|
|
||||||
async fn send_message_during_approval_is_queued_and_processed() {
|
|
||||||
use std::collections::VecDeque;
|
|
||||||
use std::sync::{Arc, Mutex};
|
|
||||||
|
|
||||||
struct MultiCallMock {
|
|
||||||
turns: Arc<Mutex<VecDeque<Vec<StreamEvent>>>>,
|
|
||||||
}
|
|
||||||
impl ModelProvider for MultiCallMock {
|
|
||||||
fn stream<'a>(
|
|
||||||
&'a self,
|
|
||||||
_messages: &'a [ConversationMessage],
|
|
||||||
_tools: &'a [ToolDefinition],
|
|
||||||
) -> impl Stream<Item = StreamEvent> + Send + 'a {
|
|
||||||
let events = self.turns.lock().unwrap().pop_front().unwrap_or_default();
|
|
||||||
futures::stream::iter(events)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let turns = Arc::new(Mutex::new(VecDeque::from([
|
|
||||||
// Turn 1: requests write_file (RequiresApproval)
|
|
||||||
vec![
|
|
||||||
StreamEvent::ToolUseStart {
|
|
||||||
id: "toolu_w".to_string(),
|
|
||||||
name: "write_file".to_string(),
|
|
||||||
},
|
|
||||||
StreamEvent::ToolUseInputDelta(
|
|
||||||
"{\"path\":\"x.txt\",\"content\":\"hi\"}".to_string(),
|
|
||||||
),
|
|
||||||
StreamEvent::ToolUseDone,
|
|
||||||
StreamEvent::Done,
|
|
||||||
],
|
|
||||||
// Turn 1 second iteration after tool result (denied)
|
|
||||||
vec![
|
|
||||||
StreamEvent::TextDelta("ok denied".to_string()),
|
|
||||||
StreamEvent::Done,
|
|
||||||
],
|
|
||||||
// Turn 2: the queued message
|
|
||||||
vec![
|
|
||||||
StreamEvent::TextDelta("queued reply".to_string()),
|
|
||||||
StreamEvent::Done,
|
|
||||||
],
|
|
||||||
])));
|
|
||||||
|
|
||||||
let (action_tx, action_rx) = mpsc::channel::<UserAction>(16);
|
|
||||||
let (event_tx, mut event_rx) = mpsc::channel::<StampedEvent>(64);
|
|
||||||
|
|
||||||
let orch = test_orchestrator(MultiCallMock { turns }, action_rx, event_tx);
|
|
||||||
let handle = tokio::spawn(orch.run());
|
|
||||||
|
|
||||||
// Start turn 1 -- orchestrator will block on approval.
|
|
||||||
action_tx
|
|
||||||
.send(UserAction::SendMessage("turn1".to_string()))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
// Let the orchestrator reach wait_for_approval.
|
|
||||||
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
|
||||||
|
|
||||||
// Send a message while blocked -- should be queued.
|
|
||||||
action_tx
|
|
||||||
.send(UserAction::SendMessage("queued".to_string()))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
// Deny the tool -- unblocks wait_for_approval.
|
|
||||||
action_tx
|
|
||||||
.send(UserAction::ToolApprovalResponse {
|
|
||||||
tool_use_id: "toolu_w".to_string(),
|
|
||||||
approved: false,
|
|
||||||
})
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
// Wait for both turns to complete.
|
|
||||||
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
|
|
||||||
|
|
||||||
// Collect everything.
|
|
||||||
let mut all_events = Vec::new();
|
|
||||||
while let Ok(ev) = event_rx.try_recv() {
|
|
||||||
all_events.push(ev.event);
|
|
||||||
}
|
|
||||||
|
|
||||||
// The queued message must have produced "queued reply".
|
|
||||||
assert!(
|
|
||||||
all_events
|
|
||||||
.iter()
|
|
||||||
.any(|e| matches!(e, UIEvent::StreamDelta(s) if s == "queued reply")),
|
|
||||||
"queued message was not processed; events: {all_events:?}"
|
|
||||||
);
|
|
||||||
|
|
||||||
action_tx.send(UserAction::Quit).await.unwrap();
|
|
||||||
handle.await.unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
// -- network policy toggle ------------------------------------------------
|
// -- network policy toggle ------------------------------------------------
|
||||||
|
|
||||||
/// SetNetworkPolicy sends a NetworkPolicyChanged event back to the TUI.
|
/// SetNetworkPolicy sends a NetworkPolicyChanged event back to the TUI.
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn set_network_policy_sends_event() {
|
async fn set_network_policy_sends_event() {
|
||||||
// Provider that returns immediately to avoid blocking.
|
|
||||||
struct NeverCalledProvider;
|
struct NeverCalledProvider;
|
||||||
impl ModelProvider for NeverCalledProvider {
|
impl ModelProvider for NeverCalledProvider {
|
||||||
fn stream<'a>(
|
fn stream<'a>(
|
||||||
|
|
@ -1070,7 +886,7 @@ mod tests {
|
||||||
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
||||||
let (event_tx, mut event_rx) = mpsc::channel::<StampedEvent>(8);
|
let (event_tx, mut event_rx) = mpsc::channel::<StampedEvent>(8);
|
||||||
|
|
||||||
let orch = test_orchestrator(NeverCalledProvider, action_rx, event_tx);
|
let orch = test_orchestrator(NeverCalledProvider, action_rx, event_tx).await;
|
||||||
let handle = tokio::spawn(orch.run());
|
let handle = tokio::spawn(orch.run());
|
||||||
|
|
||||||
action_tx
|
action_tx
|
||||||
|
|
|
||||||
|
|
@ -26,8 +26,6 @@ pub enum StreamEvent {
|
||||||
pub enum UserAction {
|
pub enum UserAction {
|
||||||
/// The user has submitted a message.
|
/// The user has submitted a message.
|
||||||
SendMessage(String),
|
SendMessage(String),
|
||||||
/// The user has responded to a tool approval request.
|
|
||||||
ToolApprovalResponse { tool_use_id: String, approved: bool },
|
|
||||||
/// The user has requested to quit.
|
/// The user has requested to quit.
|
||||||
Quit,
|
Quit,
|
||||||
/// The user has requested to clear conversation history.
|
/// The user has requested to clear conversation history.
|
||||||
|
|
@ -67,13 +65,7 @@ pub enum ToolDisplay {
|
||||||
pub enum UIEvent {
|
pub enum UIEvent {
|
||||||
/// A text chunk to append to the current assistant message.
|
/// A text chunk to append to the current assistant message.
|
||||||
StreamDelta(String),
|
StreamDelta(String),
|
||||||
/// A tool requires user approval before execution.
|
/// A tool is being executed.
|
||||||
ToolApprovalRequest {
|
|
||||||
tool_use_id: String,
|
|
||||||
tool_name: String,
|
|
||||||
display: ToolDisplay,
|
|
||||||
},
|
|
||||||
/// A tool is being executed (informational, after approval or auto-approve).
|
|
||||||
ToolExecuting {
|
ToolExecuting {
|
||||||
tool_use_id: String,
|
tool_use_id: String,
|
||||||
tool_name: String,
|
tool_name: String,
|
||||||
|
|
@ -121,7 +113,9 @@ pub enum Role {
|
||||||
///
|
///
|
||||||
/// This is provider-agnostic -- the provider serializes it into the
|
/// This is provider-agnostic -- the provider serializes it into the
|
||||||
/// format required by its API (e.g. the `tools` array for Anthropic).
|
/// format required by its API (e.g. the `tools` array for Anthropic).
|
||||||
#[derive(Debug, Clone, Serialize)]
|
/// `Deserialize` is required so that `ToolDefinition` can cross the tarpc
|
||||||
|
/// channel as a return value from `ExecutorServer::list_tools`.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct ToolDefinition {
|
pub struct ToolDefinition {
|
||||||
/// The tool name the model will use in `tool_use` blocks.
|
/// The tool name the model will use in `tool_use` blocks.
|
||||||
pub name: String,
|
pub name: String,
|
||||||
|
|
|
||||||
176
src/executor/mod.rs
Normal file
176
src/executor/mod.rs
Normal file
|
|
@ -0,0 +1,176 @@
|
||||||
|
//! tarpc-based tool executor: owns the [`ToolRegistry`] and [`Sandbox`],
|
||||||
|
//! exposes them to the orchestrator through an in-process RPC channel.
|
||||||
|
//!
|
||||||
|
//! # Architecture
|
||||||
|
//!
|
||||||
|
//! ```text
|
||||||
|
//! Orchestrator --call_tool(name, input)--> ExecutorServer --execute()--> Tool
|
||||||
|
//! |
|
||||||
|
//! v
|
||||||
|
//! Sandbox
|
||||||
|
//! ```
|
||||||
|
//!
|
||||||
|
//! The split exists so that each sub-agent (Phase 7) can get its own
|
||||||
|
//! `ExecutorClient` backed by an independent sandbox policy. For now there is
|
||||||
|
//! one executor shared by the single orchestrator.
|
||||||
|
//!
|
||||||
|
//! Transport: [`tarpc::transport::channel::unbounded`] (in-process, zero
|
||||||
|
//! serialization overhead for benchmarks, no network required).
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use futures::StreamExt as _;
|
||||||
|
use tarpc::server::Channel as _;
|
||||||
|
|
||||||
|
use crate::core::types::ToolDefinition;
|
||||||
|
use crate::sandbox::Sandbox;
|
||||||
|
use crate::tools::{ToolOutput, ToolRegistry};
|
||||||
|
|
||||||
|
/// tarpc service definition for the tool executor.
|
||||||
|
///
|
||||||
|
/// The two methods map cleanly to the two operations the orchestrator needs:
|
||||||
|
/// list_tools -> populate the provider's `tools` array once at startup
|
||||||
|
/// call_tool -> dispatch a single tool-use block during a turn
|
||||||
|
#[tarpc::service]
|
||||||
|
pub trait ToolExecutor {
|
||||||
|
/// Return the definitions of all registered tools.
|
||||||
|
async fn list_tools() -> Vec<ToolDefinition>;
|
||||||
|
/// Execute a tool by name with the given JSON input.
|
||||||
|
///
|
||||||
|
/// Returns `Ok(output)` on success or a human-readable `Err(msg)` when
|
||||||
|
/// the tool name is unknown or execution fails.
|
||||||
|
async fn call_tool(name: String, input: serde_json::Value) -> Result<ToolOutput, String>;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Server-side state: the registry of tools and the sandbox they run in.
|
||||||
|
///
|
||||||
|
/// `Clone` is derived because tarpc clones the server struct for each
|
||||||
|
/// in-flight request. Both fields are `Arc`-wrapped so clones are cheap.
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct ExecutorServer {
|
||||||
|
registry: Arc<ToolRegistry>,
|
||||||
|
sandbox: Arc<Sandbox>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ExecutorServer {
|
||||||
|
/// Wrap `registry` and `sandbox` in an `ExecutorServer` ready to serve.
|
||||||
|
pub fn new(registry: ToolRegistry, sandbox: Sandbox) -> Self {
|
||||||
|
Self {
|
||||||
|
registry: Arc::new(registry),
|
||||||
|
sandbox: Arc::new(sandbox),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ToolExecutor for ExecutorServer {
|
||||||
|
async fn list_tools(self, _ctx: tarpc::context::Context) -> Vec<ToolDefinition> {
|
||||||
|
self.registry.definitions()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn call_tool(
|
||||||
|
self,
|
||||||
|
_ctx: tarpc::context::Context,
|
||||||
|
name: String,
|
||||||
|
input: serde_json::Value,
|
||||||
|
) -> Result<ToolOutput, String> {
|
||||||
|
match self.registry.get(&name) {
|
||||||
|
Some(tool) => tool
|
||||||
|
.execute(&input, self.sandbox.as_ref())
|
||||||
|
.await
|
||||||
|
.map_err(|e| e.to_string()),
|
||||||
|
None => Err(format!("unknown tool: {name}")),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Spawn an `ExecutorServer` on an in-process tarpc channel and return the
|
||||||
|
/// client end.
|
||||||
|
///
|
||||||
|
/// The server task runs as a detached tokio task and lives as long as the
|
||||||
|
/// returned client holds the channel open. No network sockets are involved.
|
||||||
|
pub fn spawn_local(server: ExecutorServer) -> ToolExecutorClient {
|
||||||
|
let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
|
||||||
|
let server_channel = tarpc::server::BaseChannel::with_defaults(server_transport);
|
||||||
|
tokio::spawn(
|
||||||
|
server_channel
|
||||||
|
.execute(server.serve())
|
||||||
|
.for_each(|response| async move {
|
||||||
|
tokio::spawn(response);
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
ToolExecutorClient::new(tarpc::client::Config::default(), client_transport).spawn()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::sandbox::test_sandbox;
|
||||||
|
use tempfile::TempDir;
|
||||||
|
|
||||||
|
// -- direct server (no tarpc) -------------------------------------------------
|
||||||
|
|
||||||
|
/// `list_tools` returns a definition for every tool in the registry.
|
||||||
|
#[tokio::test]
|
||||||
|
async fn list_tools_returns_all_definitions() {
|
||||||
|
let dir = TempDir::new().unwrap();
|
||||||
|
let server = ExecutorServer::new(ToolRegistry::default_tools(), test_sandbox(dir.path()));
|
||||||
|
let ctx = tarpc::context::current();
|
||||||
|
let defs = server.list_tools(ctx).await;
|
||||||
|
assert_eq!(defs.len(), 4);
|
||||||
|
assert!(defs.iter().any(|d| d.name == "read_file"));
|
||||||
|
assert!(defs.iter().any(|d| d.name == "write_file"));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `call_tool` for `read_file` returns the file contents.
|
||||||
|
#[tokio::test]
|
||||||
|
async fn call_tool_read_file_returns_content() {
|
||||||
|
let dir = TempDir::new().unwrap();
|
||||||
|
std::fs::write(dir.path().join("hello.txt"), "world").unwrap();
|
||||||
|
let server = ExecutorServer::new(ToolRegistry::default_tools(), test_sandbox(dir.path()));
|
||||||
|
let ctx = tarpc::context::current();
|
||||||
|
let result = server
|
||||||
|
.call_tool(
|
||||||
|
ctx,
|
||||||
|
"read_file".to_string(),
|
||||||
|
serde_json::json!({"path": "hello.txt"}),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
let output = result.unwrap();
|
||||||
|
assert_eq!(output.content, "world");
|
||||||
|
assert!(!output.is_error);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `call_tool` for an unknown tool returns an `Err` with a descriptive message.
|
||||||
|
#[tokio::test]
|
||||||
|
async fn call_tool_unknown_tool_returns_err() {
|
||||||
|
let dir = TempDir::new().unwrap();
|
||||||
|
let server = ExecutorServer::new(ToolRegistry::default_tools(), test_sandbox(dir.path()));
|
||||||
|
let ctx = tarpc::context::current();
|
||||||
|
let result = server
|
||||||
|
.call_tool(ctx, "nonexistent".to_string(), serde_json::json!({}))
|
||||||
|
.await;
|
||||||
|
assert!(result.is_err());
|
||||||
|
assert!(result.unwrap_err().contains("unknown tool"));
|
||||||
|
}
|
||||||
|
|
||||||
|
// -- in-process tarpc channel (spawn_local) ------------------------------------
|
||||||
|
|
||||||
|
/// `spawn_local` + `call_tool` over the in-process channel reads a real file.
|
||||||
|
#[tokio::test]
|
||||||
|
async fn spawn_local_call_tool_read_file() {
|
||||||
|
let dir = TempDir::new().unwrap();
|
||||||
|
std::fs::write(dir.path().join("data.txt"), "from channel").unwrap();
|
||||||
|
let server = ExecutorServer::new(ToolRegistry::default_tools(), test_sandbox(dir.path()));
|
||||||
|
let client = spawn_local(server);
|
||||||
|
let result = client
|
||||||
|
.call_tool(
|
||||||
|
tarpc::context::current(),
|
||||||
|
"read_file".to_string(),
|
||||||
|
serde_json::json!({"path": "data.txt"}),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("RPC call failed");
|
||||||
|
let output = result.unwrap();
|
||||||
|
assert_eq!(output.content, "from channel");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
mod app;
|
mod app;
|
||||||
mod core;
|
mod core;
|
||||||
|
mod executor;
|
||||||
mod provider;
|
mod provider;
|
||||||
mod sandbox;
|
mod sandbox;
|
||||||
mod tools;
|
mod tools;
|
||||||
|
|
|
||||||
|
|
@ -119,6 +119,10 @@ impl Sandbox {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Set the network access policy.
|
/// Set the network access policy.
|
||||||
|
///
|
||||||
|
/// Currently not wired to a runtime RPC -- will be exposed via the executor
|
||||||
|
/// in a future phase.
|
||||||
|
#[allow(dead_code)]
|
||||||
pub fn set_network_allowed(&mut self, allowed: bool) {
|
pub fn set_network_allowed(&mut self, allowed: bool) {
|
||||||
tracing::info!(network_allowed = allowed, "sandbox network policy updated");
|
tracing::info!(network_allowed = allowed, "sandbox network policy updated");
|
||||||
self.policy.network_allowed = allowed;
|
self.policy.network_allowed = allowed;
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ use crate::sandbox::Sandbox;
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
|
||||||
use super::{RiskLevel, Tool, ToolError, ToolOutput};
|
use super::{Tool, ToolError, ToolOutput};
|
||||||
|
|
||||||
/// Lists directory contents. Auto-approved (read-only).
|
/// Lists directory contents. Auto-approved (read-only).
|
||||||
pub struct ListDirectory;
|
pub struct ListDirectory;
|
||||||
|
|
@ -32,10 +32,6 @@ impl Tool for ListDirectory {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn risk_level(&self) -> RiskLevel {
|
|
||||||
RiskLevel::AutoApprove
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn execute(
|
async fn execute(
|
||||||
&self,
|
&self,
|
||||||
input: &serde_json::Value,
|
input: &serde_json::Value,
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
//! Tool system: trait, registry, risk classification, and built-in tools.
|
//! Tool system: trait, registry, and built-in tools.
|
||||||
//!
|
//!
|
||||||
//! All tools implement the [`Tool`] trait. The [`ToolRegistry`] collects them
|
//! All tools implement the [`Tool`] trait. The [`ToolRegistry`] collects them
|
||||||
//! and provides lookup by name plus generation of [`ToolDefinition`]s for the
|
//! and provides lookup by name plus generation of [`ToolDefinition`]s for the
|
||||||
|
|
@ -13,12 +13,16 @@ mod shell_exec;
|
||||||
mod write_file;
|
mod write_file;
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::core::types::ToolDefinition;
|
use crate::core::types::ToolDefinition;
|
||||||
use crate::sandbox::Sandbox;
|
use crate::sandbox::Sandbox;
|
||||||
|
|
||||||
/// The output of a tool execution.
|
/// The output of a tool execution.
|
||||||
#[derive(Debug)]
|
///
|
||||||
|
/// Derives `Serialize + Deserialize` so it can cross the tarpc channel
|
||||||
|
/// between [`crate::executor::ExecutorServer`] and the orchestrator.
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
pub struct ToolOutput {
|
pub struct ToolOutput {
|
||||||
/// The text content returned to the model.
|
/// The text content returned to the model.
|
||||||
pub content: String,
|
pub content: String,
|
||||||
|
|
@ -26,15 +30,6 @@ pub struct ToolOutput {
|
||||||
pub is_error: bool,
|
pub is_error: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Risk classification for tool approval gating.
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
|
||||||
pub enum RiskLevel {
|
|
||||||
/// Safe to execute without user confirmation (e.g. read-only operations).
|
|
||||||
AutoApprove,
|
|
||||||
/// Requires explicit user approval before execution (e.g. writes, shell).
|
|
||||||
RequiresApproval,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A tool that the model can invoke.
|
/// A tool that the model can invoke.
|
||||||
///
|
///
|
||||||
/// All file I/O and process spawning must go through the [`Sandbox`] passed
|
/// All file I/O and process spawning must go through the [`Sandbox`] passed
|
||||||
|
|
@ -53,8 +48,6 @@ pub trait Tool: Send + Sync {
|
||||||
fn description(&self) -> &str;
|
fn description(&self) -> &str;
|
||||||
/// JSON Schema for the tool's input parameters.
|
/// JSON Schema for the tool's input parameters.
|
||||||
fn input_schema(&self) -> serde_json::Value;
|
fn input_schema(&self) -> serde_json::Value;
|
||||||
/// The risk level of this tool.
|
|
||||||
fn risk_level(&self) -> RiskLevel;
|
|
||||||
/// Execute the tool with the given input, confined by `sandbox`.
|
/// Execute the tool with the given input, confined by `sandbox`.
|
||||||
async fn execute(
|
async fn execute(
|
||||||
&self,
|
&self,
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ use crate::sandbox::Sandbox;
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
|
||||||
use super::{RiskLevel, Tool, ToolError, ToolOutput};
|
use super::{Tool, ToolError, ToolOutput};
|
||||||
|
|
||||||
/// Reads file contents. Auto-approved (read-only).
|
/// Reads file contents. Auto-approved (read-only).
|
||||||
pub struct ReadFile;
|
pub struct ReadFile;
|
||||||
|
|
@ -32,10 +32,6 @@ impl Tool for ReadFile {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn risk_level(&self) -> RiskLevel {
|
|
||||||
RiskLevel::AutoApprove
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn execute(
|
async fn execute(
|
||||||
&self,
|
&self,
|
||||||
input: &serde_json::Value,
|
input: &serde_json::Value,
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ use crate::sandbox::Sandbox;
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
|
||||||
use super::{RiskLevel, Tool, ToolError, ToolOutput};
|
use super::{Tool, ToolError, ToolOutput};
|
||||||
|
|
||||||
/// Executes a shell command. Requires user approval.
|
/// Executes a shell command. Requires user approval.
|
||||||
pub struct ShellExec;
|
pub struct ShellExec;
|
||||||
|
|
@ -32,10 +32,6 @@ impl Tool for ShellExec {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn risk_level(&self) -> RiskLevel {
|
|
||||||
RiskLevel::RequiresApproval
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn execute(
|
async fn execute(
|
||||||
&self,
|
&self,
|
||||||
input: &serde_json::Value,
|
input: &serde_json::Value,
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ use crate::sandbox::Sandbox;
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
|
||||||
use super::{RiskLevel, Tool, ToolError, ToolOutput};
|
use super::{Tool, ToolError, ToolOutput};
|
||||||
|
|
||||||
/// Writes content to a file. Requires user approval.
|
/// Writes content to a file. Requires user approval.
|
||||||
pub struct WriteFile;
|
pub struct WriteFile;
|
||||||
|
|
@ -36,10 +36,6 @@ impl Tool for WriteFile {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn risk_level(&self) -> RiskLevel {
|
|
||||||
RiskLevel::RequiresApproval
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn execute(
|
async fn execute(
|
||||||
&self,
|
&self,
|
||||||
input: &serde_json::Value,
|
input: &serde_json::Value,
|
||||||
|
|
|
||||||
|
|
@ -16,14 +16,13 @@ use crate::tui::tool_display;
|
||||||
/// arrives, the handler searches `state.messages` for an existing entry with the
|
/// arrives, the handler searches `state.messages` for an existing entry with the
|
||||||
/// same `tool_use_id` and replaces its content rather than appending a new row.
|
/// same `tool_use_id` and replaces its content rather than appending a new row.
|
||||||
///
|
///
|
||||||
/// | Event | Effect |
|
/// | Event | Effect |
|
||||||
/// |------------------------|------------------------------------------------------------|
|
/// |-------------------|------------------------------------------------------------|
|
||||||
/// | `StreamDelta(s)` | Append `s` to last message if it's `Assistant`; else push |
|
/// | `StreamDelta(s)` | Append `s` to last message if it's `Assistant`; else push |
|
||||||
/// | `ToolApprovalRequest` | Push inline message with approval prompt, set pending |
|
/// | `ToolExecuting` | Push new executing message (or replace in-place) |
|
||||||
/// | `ToolExecuting` | Replace approval message in-place (or push new) |
|
/// | `ToolResult` | Replace executing message in-place (or push new) |
|
||||||
/// | `ToolResult` | Replace executing message in-place (or push new) |
|
/// | `TurnComplete` | No structural change; logged at debug level |
|
||||||
/// | `TurnComplete` | No structural change; logged at debug level |
|
/// | `Error(msg)` | Push `(Assistant, "[error] {msg}")` |
|
||||||
/// | `Error(msg)` | Push `(Assistant, "[error] {msg}")` |
|
|
||||||
pub(super) fn drain_ui_events(event_rx: &mut mpsc::Receiver<StampedEvent>, state: &mut AppState) {
|
pub(super) fn drain_ui_events(event_rx: &mut mpsc::Receiver<StampedEvent>, state: &mut AppState) {
|
||||||
while let Ok(stamped) = event_rx.try_recv() {
|
while let Ok(stamped) = event_rx.try_recv() {
|
||||||
// Discard events from before the most recent :clear.
|
// Discard events from before the most recent :clear.
|
||||||
|
|
@ -56,21 +55,6 @@ pub(super) fn drain_ui_events(event_rx: &mut mpsc::Receiver<StampedEvent>, state
|
||||||
}
|
}
|
||||||
state.content_changed = true;
|
state.content_changed = true;
|
||||||
}
|
}
|
||||||
UIEvent::ToolApprovalRequest {
|
|
||||||
tool_use_id,
|
|
||||||
tool_name,
|
|
||||||
display,
|
|
||||||
} => {
|
|
||||||
let mut content = tool_display::format_executing(&tool_name, &display);
|
|
||||||
content.push_str("\n[y] approve [n] deny");
|
|
||||||
state.messages.push(DisplayMessage {
|
|
||||||
role: Role::Assistant,
|
|
||||||
content,
|
|
||||||
tool_use_id: Some(tool_use_id.clone()),
|
|
||||||
});
|
|
||||||
state.pending_approval = Some(PendingApproval { tool_use_id });
|
|
||||||
state.content_changed = true;
|
|
||||||
}
|
|
||||||
UIEvent::ToolExecuting {
|
UIEvent::ToolExecuting {
|
||||||
tool_use_id,
|
tool_use_id,
|
||||||
tool_name,
|
tool_name,
|
||||||
|
|
@ -127,12 +111,6 @@ fn replace_or_push(state: &mut AppState, tool_use_id: &str, content: String) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A pending tool approval request waiting for user input.
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct PendingApproval {
|
|
||||||
pub tool_use_id: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
@ -179,29 +157,6 @@ mod tests {
|
||||||
assert_eq!(state.messages[1].content, "hello");
|
assert_eq!(state.messages[1].content, "hello");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn drain_tool_approval_sets_pending_and_adds_message() {
|
|
||||||
let (tx, mut rx) = tokio::sync::mpsc::channel(8);
|
|
||||||
let mut state = AppState::new();
|
|
||||||
tx.send(stamp(UIEvent::ToolApprovalRequest {
|
|
||||||
tool_use_id: "t1".to_string(),
|
|
||||||
tool_name: "shell_exec".to_string(),
|
|
||||||
display: ToolDisplay::ShellExec {
|
|
||||||
command: "cargo test".to_string(),
|
|
||||||
},
|
|
||||||
}))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
drop(tx);
|
|
||||||
drain_ui_events(&mut rx, &mut state);
|
|
||||||
assert!(state.pending_approval.is_some());
|
|
||||||
assert_eq!(state.pending_approval.as_ref().unwrap().tool_use_id, "t1");
|
|
||||||
// Message should be inline with approval prompt.
|
|
||||||
assert_eq!(state.messages.len(), 1);
|
|
||||||
assert!(state.messages[0].content.contains("[y] approve"));
|
|
||||||
assert_eq!(state.messages[0].tool_use_id.as_deref(), Some("t1"));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn drain_tool_result_replaces_existing_message() {
|
async fn drain_tool_result_replaces_existing_message() {
|
||||||
let (tx, mut rx) = tokio::sync::mpsc::channel(8);
|
let (tx, mut rx) = tokio::sync::mpsc::channel(8);
|
||||||
|
|
|
||||||
|
|
@ -13,8 +13,6 @@ pub(super) enum LoopControl {
|
||||||
Quit,
|
Quit,
|
||||||
/// The user ran `:clear`; wipe the conversation.
|
/// The user ran `:clear`; wipe the conversation.
|
||||||
ClearHistory,
|
ClearHistory,
|
||||||
/// The user responded to a tool approval prompt.
|
|
||||||
ToolApproval { tool_use_id: String, approved: bool },
|
|
||||||
/// The user ran `:net on` or `:net off`.
|
/// The user ran `:net on` or `:net off`.
|
||||||
SetNetworkPolicy(bool),
|
SetNetworkPolicy(bool),
|
||||||
}
|
}
|
||||||
|
|
@ -28,60 +26,6 @@ pub(super) fn handle_key(key: Option<KeyEvent>, state: &mut AppState) -> Option<
|
||||||
// Clear any transient status error on the next keypress.
|
// Clear any transient status error on the next keypress.
|
||||||
state.status_error = None;
|
state.status_error = None;
|
||||||
|
|
||||||
// If a tool approval is pending, intercept y/n but allow scrolling.
|
|
||||||
if let Some(approval) = &state.pending_approval {
|
|
||||||
let tool_use_id = approval.tool_use_id.clone();
|
|
||||||
let is_ctrl = key.modifiers.contains(KeyModifiers::CONTROL);
|
|
||||||
match key.code {
|
|
||||||
KeyCode::Char('y') | KeyCode::Char('Y') if !is_ctrl => {
|
|
||||||
state.pending_approval = None;
|
|
||||||
return Some(LoopControl::ToolApproval {
|
|
||||||
tool_use_id,
|
|
||||||
approved: true,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
KeyCode::Char('n') | KeyCode::Char('N') if !is_ctrl => {
|
|
||||||
state.pending_approval = None;
|
|
||||||
return Some(LoopControl::ToolApproval {
|
|
||||||
tool_use_id,
|
|
||||||
approved: false,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
// Allow scrolling while approval is pending.
|
|
||||||
KeyCode::Char('j') if !is_ctrl => {
|
|
||||||
state.scroll = state.scroll.saturating_add(1);
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
KeyCode::Char('k') if !is_ctrl => {
|
|
||||||
state.scroll = state.scroll.saturating_sub(1);
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
KeyCode::Char('G') if !is_ctrl => {
|
|
||||||
state.scroll = u16::MAX;
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
KeyCode::Char('g') if !is_ctrl => {
|
|
||||||
state.pending_keys.push('g');
|
|
||||||
if state.pending_keys.ends_with(&['g', 'g']) {
|
|
||||||
state.scroll = 0;
|
|
||||||
state.pending_keys.clear();
|
|
||||||
}
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
KeyCode::Char('d') if is_ctrl => {
|
|
||||||
let half = (state.viewport_height / 2).max(1);
|
|
||||||
state.scroll = state.scroll.saturating_add(half);
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
KeyCode::Char('u') if is_ctrl => {
|
|
||||||
let half = (state.viewport_height / 2).max(1);
|
|
||||||
state.scroll = state.scroll.saturating_sub(half);
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
_ => return None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ctrl+C quits from any mode.
|
// Ctrl+C quits from any mode.
|
||||||
if key.modifiers.contains(KeyModifiers::CONTROL) && key.code == KeyCode::Char('c') {
|
if key.modifiers.contains(KeyModifiers::CONTROL) && key.code == KeyCode::Char('c') {
|
||||||
return Some(LoopControl::Quit);
|
return Some(LoopControl::Quit);
|
||||||
|
|
|
||||||
|
|
@ -90,8 +90,6 @@ pub struct AppState {
|
||||||
pub viewport_height: u16,
|
pub viewport_height: u16,
|
||||||
/// Transient error message shown in the status bar, cleared on next keypress.
|
/// Transient error message shown in the status bar, cleared on next keypress.
|
||||||
pub status_error: Option<String>,
|
pub status_error: Option<String>,
|
||||||
/// A tool approval request waiting for user input (y/n).
|
|
||||||
pub pending_approval: Option<events::PendingApproval>,
|
|
||||||
/// Whether the sandbox is in yolo (unsandboxed) mode.
|
/// Whether the sandbox is in yolo (unsandboxed) mode.
|
||||||
pub sandbox_yolo: bool,
|
pub sandbox_yolo: bool,
|
||||||
/// Whether network access is currently allowed.
|
/// Whether network access is currently allowed.
|
||||||
|
|
@ -115,7 +113,6 @@ impl AppState {
|
||||||
pending_keys: Vec::new(),
|
pending_keys: Vec::new(),
|
||||||
viewport_height: 0,
|
viewport_height: 0,
|
||||||
status_error: None,
|
status_error: None,
|
||||||
pending_approval: None,
|
|
||||||
sandbox_yolo: false,
|
sandbox_yolo: false,
|
||||||
network_allowed: false,
|
network_allowed: false,
|
||||||
epoch: 0,
|
epoch: 0,
|
||||||
|
|
@ -224,17 +221,6 @@ pub async fn run(
|
||||||
.send(UserAction::ClearHistory { epoch: state.epoch })
|
.send(UserAction::ClearHistory { epoch: state.epoch })
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
Some(input::LoopControl::ToolApproval {
|
|
||||||
tool_use_id,
|
|
||||||
approved,
|
|
||||||
}) => {
|
|
||||||
let _ = action_tx
|
|
||||||
.send(UserAction::ToolApprovalResponse {
|
|
||||||
tool_use_id,
|
|
||||||
approved,
|
|
||||||
})
|
|
||||||
.await;
|
|
||||||
}
|
|
||||||
Some(input::LoopControl::SetNetworkPolicy(allowed)) => {
|
Some(input::LoopControl::SetNetworkPolicy(allowed)) => {
|
||||||
let _ = action_tx.send(UserAction::SetNetworkPolicy(allowed)).await;
|
let _ = action_tx.send(UserAction::SetNetworkPolicy(allowed)).await;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -425,37 +425,6 @@ mod tests {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn render_approval_inline_visible() {
|
|
||||||
let backend = TestBackend::new(80, 24);
|
|
||||||
let mut terminal = Terminal::new(backend).unwrap();
|
|
||||||
let mut state = AppState::new();
|
|
||||||
// Inline approval message in the message stream.
|
|
||||||
state.messages.push(DisplayMessage {
|
|
||||||
role: Role::Assistant,
|
|
||||||
content: "$ cargo test\n[y] approve [n] deny".to_string(),
|
|
||||||
tool_use_id: Some("t1".to_string()),
|
|
||||||
});
|
|
||||||
state.pending_approval = Some(super::super::events::PendingApproval {
|
|
||||||
tool_use_id: "t1".to_string(),
|
|
||||||
});
|
|
||||||
terminal.draw(|frame| render(frame, &state)).unwrap();
|
|
||||||
let buf = terminal.backend().buffer().clone();
|
|
||||||
let all_text: String = buf
|
|
||||||
.content()
|
|
||||||
.iter()
|
|
||||||
.map(|c| c.symbol().to_string())
|
|
||||||
.collect();
|
|
||||||
assert!(
|
|
||||||
all_text.contains("approve"),
|
|
||||||
"expected approval prompt in buffer"
|
|
||||||
);
|
|
||||||
assert!(
|
|
||||||
all_text.contains("cargo test"),
|
|
||||||
"expected tool info in buffer"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn render_status_bar_shows_net_off() {
|
fn render_status_bar_shows_net_off() {
|
||||||
let backend = TestBackend::new(80, 24);
|
let backend = TestBackend::new(80, 24);
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue