Move tool executor behind tarpc. (#12)
Implement tool executor as a separate tarpc service to improve isolation and create sandboxing opportunities. Reviewed-on: #12 Co-authored-by: Drew Galbraith <drew@tiramisu.one> Co-committed-by: Drew Galbraith <drew@tiramisu.one>
This commit is contained in:
parent
7420755800
commit
312a5866f7
17 changed files with 465 additions and 476 deletions
178
Cargo.lock
generated
178
Cargo.lock
generated
|
|
@ -253,6 +253,21 @@ dependencies = [
|
|||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-channel"
|
||||
version = "0.5.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2"
|
||||
dependencies = [
|
||||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-utils"
|
||||
version = "0.8.21"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
|
||||
|
||||
[[package]]
|
||||
name = "crossterm"
|
||||
version = "0.29.0"
|
||||
|
|
@ -696,13 +711,19 @@ dependencies = [
|
|||
"futures-core",
|
||||
"futures-sink",
|
||||
"http",
|
||||
"indexmap",
|
||||
"indexmap 2.13.0",
|
||||
"slab",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hashbrown"
|
||||
version = "0.12.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
|
||||
|
||||
[[package]]
|
||||
name = "hashbrown"
|
||||
version = "0.15.5"
|
||||
|
|
@ -774,6 +795,12 @@ version = "1.10.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87"
|
||||
|
||||
[[package]]
|
||||
name = "humantime"
|
||||
version = "2.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "135b12329e5e3ce057a9f972339ea52bc954fe1e9358ef27f95e89716fbc5424"
|
||||
|
||||
[[package]]
|
||||
name = "hyper"
|
||||
version = "1.8.1"
|
||||
|
|
@ -951,6 +978,16 @@ dependencies = [
|
|||
"icu_properties",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "indexmap"
|
||||
version = "1.9.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"hashbrown 0.12.3",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "indexmap"
|
||||
version = "2.13.0"
|
||||
|
|
@ -1304,6 +1341,49 @@ version = "0.2.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe"
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry"
|
||||
version = "0.18.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "69d6c3d7288a106c0a363e4b0e8d308058d56902adefb16f4936f417ffef086e"
|
||||
dependencies = [
|
||||
"opentelemetry_api",
|
||||
"opentelemetry_sdk",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry_api"
|
||||
version = "0.18.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c24f96e21e7acc813c7a8394ee94978929db2bcc46cf6b5014fc612bf7760c22"
|
||||
dependencies = [
|
||||
"futures-channel",
|
||||
"futures-util",
|
||||
"indexmap 1.9.3",
|
||||
"js-sys",
|
||||
"once_cell",
|
||||
"pin-project-lite",
|
||||
"thiserror 1.0.69",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry_sdk"
|
||||
version = "0.18.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1ca41c4933371b61c2a2f214bf16931499af4ec90543604ec828f7a625c09113"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"crossbeam-channel",
|
||||
"futures-channel",
|
||||
"futures-executor",
|
||||
"futures-util",
|
||||
"once_cell",
|
||||
"opentelemetry_api",
|
||||
"percent-encoding",
|
||||
"rand 0.8.5",
|
||||
"thiserror 1.0.69",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ordered-float"
|
||||
version = "4.6.0"
|
||||
|
|
@ -1437,6 +1517,26 @@ dependencies = [
|
|||
"siphasher",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pin-project"
|
||||
version = "1.1.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f1749c7ed4bcaf4c3d0a3efc28538844fb29bcdd7d2b67b2be7e20ba861ff517"
|
||||
dependencies = [
|
||||
"pin-project-internal",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pin-project-internal"
|
||||
version = "1.1.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d9b20ed30f105399776b9c883e68e536ef602a16ae6f596d2c473591d6ad64c6"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.117",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pin-project-lite"
|
||||
version = "0.2.16"
|
||||
|
|
@ -1575,6 +1675,8 @@ version = "0.8.5"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"rand_chacha 0.3.1",
|
||||
"rand_core 0.6.4",
|
||||
]
|
||||
|
||||
|
|
@ -1584,10 +1686,20 @@ version = "0.9.2"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1"
|
||||
dependencies = [
|
||||
"rand_chacha",
|
||||
"rand_chacha 0.9.0",
|
||||
"rand_core 0.9.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_chacha"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
|
||||
dependencies = [
|
||||
"ppv-lite86",
|
||||
"rand_core 0.6.4",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_chacha"
|
||||
version = "0.9.0"
|
||||
|
|
@ -1603,6 +1715,9 @@ name = "rand_core"
|
|||
version = "0.6.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
|
||||
dependencies = [
|
||||
"getrandom 0.2.17",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_core"
|
||||
|
|
@ -2087,6 +2202,7 @@ dependencies = [
|
|||
"serde",
|
||||
"serde_json",
|
||||
"similar",
|
||||
"tarpc",
|
||||
"tempfile",
|
||||
"thiserror 2.0.18",
|
||||
"tokio",
|
||||
|
|
@ -2224,6 +2340,39 @@ dependencies = [
|
|||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tarpc"
|
||||
version = "0.34.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "93a1870169fb9490fb3b37df7f50782986475c33cb90955f9f9b9ae659124200"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"fnv",
|
||||
"futures",
|
||||
"humantime",
|
||||
"opentelemetry",
|
||||
"pin-project",
|
||||
"rand 0.8.5",
|
||||
"static_assertions",
|
||||
"tarpc-plugins",
|
||||
"thiserror 1.0.69",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
"tracing-opentelemetry",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tarpc-plugins"
|
||||
version = "0.13.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ad8302bea2fb8a2b01b025d23414b0b4ed32a783b95e5d818c3320a8bc4baada"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tempfile"
|
||||
version = "3.26.0"
|
||||
|
|
@ -2443,6 +2592,7 @@ dependencies = [
|
|||
"futures-core",
|
||||
"futures-sink",
|
||||
"pin-project-lite",
|
||||
"slab",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
|
|
@ -2497,6 +2647,7 @@ version = "0.1.44"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100"
|
||||
dependencies = [
|
||||
"log",
|
||||
"pin-project-lite",
|
||||
"tracing-attributes",
|
||||
"tracing-core",
|
||||
|
|
@ -2534,6 +2685,19 @@ dependencies = [
|
|||
"tracing-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tracing-opentelemetry"
|
||||
version = "0.18.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "21ebb87a95ea13271332df069020513ab70bdb5637ca42d6e492dc3bbbad48de"
|
||||
dependencies = [
|
||||
"once_cell",
|
||||
"opentelemetry",
|
||||
"tracing",
|
||||
"tracing-core",
|
||||
"tracing-subscriber",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tracing-subscriber"
|
||||
version = "0.3.22"
|
||||
|
|
@ -2787,7 +2951,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "bb0e353e6a2fbdc176932bbaab493762eb1255a7900fe0fea1a2f96c296cc909"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"indexmap",
|
||||
"indexmap 2.13.0",
|
||||
"wasm-encoder",
|
||||
"wasmparser",
|
||||
]
|
||||
|
|
@ -2813,7 +2977,7 @@ checksum = "47b807c72e1bac69382b3a6fb3dbe8ea4c0ed87ff5629b8685ae6b9a611028fe"
|
|||
dependencies = [
|
||||
"bitflags 2.11.0",
|
||||
"hashbrown 0.15.5",
|
||||
"indexmap",
|
||||
"indexmap 2.13.0",
|
||||
"semver",
|
||||
]
|
||||
|
||||
|
|
@ -3234,7 +3398,7 @@ checksum = "b7c566e0f4b284dd6561c786d9cb0142da491f46a9fbed79ea69cdad5db17f21"
|
|||
dependencies = [
|
||||
"anyhow",
|
||||
"heck",
|
||||
"indexmap",
|
||||
"indexmap 2.13.0",
|
||||
"prettyplease",
|
||||
"syn 2.0.117",
|
||||
"wasm-metadata",
|
||||
|
|
@ -3265,7 +3429,7 @@ checksum = "9d66ea20e9553b30172b5e831994e35fbde2d165325bec84fc43dbf6f4eb9cb2"
|
|||
dependencies = [
|
||||
"anyhow",
|
||||
"bitflags 2.11.0",
|
||||
"indexmap",
|
||||
"indexmap 2.13.0",
|
||||
"log",
|
||||
"serde",
|
||||
"serde_derive",
|
||||
|
|
@ -3284,7 +3448,7 @@ checksum = "ecc8ac4bc1dc3381b7f59c34f00b67e18f910c2c0f50015669dde7def656a736"
|
|||
dependencies = [
|
||||
"anyhow",
|
||||
"id-arena",
|
||||
"indexmap",
|
||||
"indexmap 2.13.0",
|
||||
"log",
|
||||
"semver",
|
||||
"serde",
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ futures = "0.3"
|
|||
async-trait = "0.1"
|
||||
similar = "2"
|
||||
landlock = "0.4"
|
||||
tarpc = { version = "0.34", features = ["tokio1"] }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3.26.0"
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ use tokio::sync::mpsc;
|
|||
|
||||
use crate::core::orchestrator::Orchestrator;
|
||||
use crate::core::types::{StampedEvent, UserAction};
|
||||
use crate::executor::{ExecutorServer, spawn_local};
|
||||
use crate::provider::ClaudeProvider;
|
||||
use crate::sandbox::policy::SandboxPolicy;
|
||||
use crate::sandbox::{EnforcementMode, Sandbox};
|
||||
|
|
@ -87,9 +88,10 @@ pub async fn run(project_dir: &Path, yolo: bool) -> anyhow::Result<()> {
|
|||
let (action_tx, action_rx) = mpsc::channel::<UserAction>(CHANNEL_CAP);
|
||||
let (event_tx, event_rx) = mpsc::channel::<StampedEvent>(CHANNEL_CAP);
|
||||
|
||||
// -- Tools & Orchestrator (background task) ------------------------------------
|
||||
let tool_registry = ToolRegistry::default_tools();
|
||||
let orch = Orchestrator::new(provider, tool_registry, sandbox, action_rx, event_tx);
|
||||
// -- Executor & Orchestrator (background task) ---------------------------------
|
||||
let registry = ToolRegistry::default_tools();
|
||||
let executor = spawn_local(ExecutorServer::new(registry, sandbox));
|
||||
let orch = Orchestrator::new(provider, executor, action_rx, event_tx).await;
|
||||
tokio::spawn(orch.run());
|
||||
|
||||
// -- TUI (foreground task) ----------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -7,9 +7,9 @@ use crate::core::types::{
|
|||
ContentBlock, ConversationMessage, Role, StampedEvent, StreamEvent, ToolDefinition,
|
||||
ToolDisplay, UIEvent, UserAction,
|
||||
};
|
||||
use crate::executor::ToolExecutorClient;
|
||||
use crate::provider::ModelProvider;
|
||||
use crate::sandbox::Sandbox;
|
||||
use crate::tools::{RiskLevel, ToolOutput, ToolRegistry};
|
||||
use crate::tools::ToolOutput;
|
||||
|
||||
/// Accumulates data for a single tool-use block while it is being streamed.
|
||||
///
|
||||
|
|
@ -64,6 +64,9 @@ impl TryFrom<ActiveToolUse> for ContentBlock {
|
|||
/// |
|
||||
/// v
|
||||
/// ModelProvider (SSE stream)
|
||||
/// |
|
||||
/// v
|
||||
/// ExecutorClient (in-process tarpc)
|
||||
/// ```
|
||||
///
|
||||
/// # Event loop
|
||||
|
|
@ -107,35 +110,39 @@ fn truncate(s: &str, max_len: usize) -> String {
|
|||
pub struct Orchestrator<P> {
|
||||
history: ConversationHistory,
|
||||
provider: P,
|
||||
tool_registry: ToolRegistry,
|
||||
sandbox: Sandbox,
|
||||
/// tarpc client to the tool executor (in-process channel).
|
||||
executor: ToolExecutorClient,
|
||||
/// Cached tool definitions fetched from the executor at construction.
|
||||
tool_definitions: Vec<ToolDefinition>,
|
||||
action_rx: mpsc::Receiver<UserAction>,
|
||||
event_tx: mpsc::Sender<StampedEvent>,
|
||||
/// Messages typed by the user while an approval prompt is open. They are
|
||||
/// queued here and replayed as new turns once the current turn completes.
|
||||
queued_messages: Vec<String>,
|
||||
/// Monotonic epoch incremented on each `:clear`. Events are tagged with
|
||||
/// this value so the TUI can discard stale in-flight messages.
|
||||
epoch: u64,
|
||||
}
|
||||
|
||||
impl<P: ModelProvider> Orchestrator<P> {
|
||||
/// Construct an orchestrator using the given provider and channel endpoints.
|
||||
pub fn new(
|
||||
/// Construct an orchestrator using the given provider and executor client.
|
||||
///
|
||||
/// Calls `executor.list_tools` once at startup to populate the cached
|
||||
/// `tool_definitions` that are passed to the model on every turn.
|
||||
pub async fn new(
|
||||
provider: P,
|
||||
tool_registry: ToolRegistry,
|
||||
sandbox: Sandbox,
|
||||
executor: ToolExecutorClient,
|
||||
action_rx: mpsc::Receiver<UserAction>,
|
||||
event_tx: mpsc::Sender<StampedEvent>,
|
||||
) -> Self {
|
||||
let tool_definitions = executor
|
||||
.list_tools(tarpc::context::current())
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
Self {
|
||||
history: ConversationHistory::new(),
|
||||
provider,
|
||||
tool_registry,
|
||||
sandbox,
|
||||
executor,
|
||||
tool_definitions,
|
||||
action_rx,
|
||||
event_tx,
|
||||
queued_messages: Vec::new(),
|
||||
epoch: 0,
|
||||
}
|
||||
}
|
||||
|
|
@ -241,11 +248,9 @@ impl<P: ModelProvider> Orchestrator<P> {
|
|||
/// loop back until the model produces a text-only response or we hit the
|
||||
/// iteration limit.
|
||||
async fn run_turn(&mut self) {
|
||||
let tool_defs = self.tool_registry.definitions();
|
||||
|
||||
for _ in 0..MAX_TOOL_ITERATIONS {
|
||||
let messages = self.history.messages().to_vec();
|
||||
let result = self.consume_stream(&messages, &tool_defs).await;
|
||||
let result = self.consume_stream(&messages, &self.tool_definitions).await;
|
||||
|
||||
match result {
|
||||
StreamResult::Error(msg) => {
|
||||
|
|
@ -273,12 +278,11 @@ impl<P: ModelProvider> Orchestrator<P> {
|
|||
// mutability (Arc<Mutex<...>> around event_tx, action_rx,
|
||||
// history) so that multiple futures can borrow self
|
||||
// simultaneously via futures::future::join_all.
|
||||
// Execute each tool-use block and collect results.
|
||||
let mut tool_results: Vec<ContentBlock> = Vec::new();
|
||||
|
||||
for block in &blocks {
|
||||
if let ContentBlock::ToolUse { id, name, input } = block {
|
||||
let result = self.execute_tool_with_approval(id, name, input).await;
|
||||
let result = self.execute_tool(id, name, input).await;
|
||||
tool_results.push(ContentBlock::ToolResult {
|
||||
tool_use_id: id.clone(),
|
||||
content: result.content,
|
||||
|
|
@ -292,7 +296,6 @@ impl<P: ModelProvider> Orchestrator<P> {
|
|||
role: Role::User,
|
||||
content: tool_results,
|
||||
});
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -307,16 +310,17 @@ impl<P: ModelProvider> Orchestrator<P> {
|
|||
///
|
||||
/// Matches on known tool names to extract structured fields; falls back to
|
||||
/// `Generic` with a JSON summary for anything else.
|
||||
async fn build_tool_display(&self, tool_name: &str, input: &serde_json::Value) -> ToolDisplay {
|
||||
fn build_tool_display(&self, tool_name: &str, input: &serde_json::Value) -> ToolDisplay {
|
||||
match tool_name {
|
||||
"write_file" => {
|
||||
let path = input["path"].as_str().unwrap_or("<unknown>").to_string();
|
||||
let new_content = input["content"].as_str().unwrap_or("").to_string();
|
||||
// Try to read existing content for diffing.
|
||||
let old_content = self.sandbox.read_file(&path).await.ok();
|
||||
// old_content is not available here (sandbox lives in the executor);
|
||||
// the diff is omitted at the display stage. The executor still
|
||||
// reads and writes through the sandbox as normal.
|
||||
ToolDisplay::WriteFile {
|
||||
path,
|
||||
old_content,
|
||||
old_content: None,
|
||||
new_content,
|
||||
}
|
||||
}
|
||||
|
|
@ -349,9 +353,6 @@ impl<P: ModelProvider> Orchestrator<P> {
|
|||
"write_file" => {
|
||||
let path = input["path"].as_str().unwrap_or("<unknown>").to_string();
|
||||
let new_content = input["content"].as_str().unwrap_or("").to_string();
|
||||
// For results, old_content isn't available post-write. We already
|
||||
// showed the diff at approval/executing time; the result just
|
||||
// confirms the byte count via the display formatter.
|
||||
ToolDisplay::WriteFile {
|
||||
path,
|
||||
old_content: None,
|
||||
|
|
@ -383,65 +384,33 @@ impl<P: ModelProvider> Orchestrator<P> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Execute a single tool, handling approval if needed.
|
||||
///
|
||||
/// For auto-approve tools, executes immediately. For tools requiring
|
||||
/// approval, sends a request to the TUI and waits for a response.
|
||||
async fn execute_tool_with_approval(
|
||||
&mut self,
|
||||
/// Execute a single tool via the executor client, emitting `ToolExecuting`
|
||||
/// before the call and `ToolResult` after.
|
||||
async fn execute_tool(
|
||||
&self,
|
||||
tool_use_id: &str,
|
||||
tool_name: &str,
|
||||
input: &serde_json::Value,
|
||||
) -> ToolOutput {
|
||||
// Extract tool info upfront to avoid holding a borrow on self across
|
||||
// the mutable wait_for_approval call.
|
||||
let risk = match self.tool_registry.get(tool_name) {
|
||||
Some(t) => t.risk_level(),
|
||||
None => {
|
||||
return ToolOutput {
|
||||
content: format!("unknown tool: {tool_name}"),
|
||||
is_error: true,
|
||||
};
|
||||
}
|
||||
};
|
||||
let display = self.build_tool_display(tool_name, input);
|
||||
self.send(UIEvent::ToolExecuting {
|
||||
tool_use_id: tool_use_id.to_string(),
|
||||
tool_name: tool_name.to_string(),
|
||||
display,
|
||||
})
|
||||
.await;
|
||||
|
||||
let display = self.build_tool_display(tool_name, input).await;
|
||||
let rpc_result = self
|
||||
.executor
|
||||
.call_tool(
|
||||
tarpc::context::current(),
|
||||
tool_name.to_string(),
|
||||
input.clone(),
|
||||
)
|
||||
.await;
|
||||
|
||||
// Check approval.
|
||||
let approved = match risk {
|
||||
RiskLevel::AutoApprove => {
|
||||
self.send(UIEvent::ToolExecuting {
|
||||
tool_use_id: tool_use_id.to_string(),
|
||||
tool_name: tool_name.to_string(),
|
||||
display,
|
||||
})
|
||||
.await;
|
||||
true
|
||||
}
|
||||
RiskLevel::RequiresApproval => {
|
||||
self.send(UIEvent::ToolApprovalRequest {
|
||||
tool_use_id: tool_use_id.to_string(),
|
||||
tool_name: tool_name.to_string(),
|
||||
display,
|
||||
})
|
||||
.await;
|
||||
|
||||
// Wait for approval response from TUI.
|
||||
self.wait_for_approval(tool_use_id).await
|
||||
}
|
||||
};
|
||||
|
||||
if !approved {
|
||||
return ToolOutput {
|
||||
content: "tool execution denied by user".to_string(),
|
||||
is_error: true,
|
||||
};
|
||||
}
|
||||
|
||||
// Re-fetch tool for execution (borrow was released above).
|
||||
let tool = self.tool_registry.get(tool_name).unwrap();
|
||||
match tool.execute(input, &self.sandbox).await {
|
||||
Ok(output) => {
|
||||
match rpc_result {
|
||||
Ok(Ok(output)) => {
|
||||
let result_display = self.build_result_display(tool_name, input, &output.content);
|
||||
self.send(UIEvent::ToolResult {
|
||||
tool_use_id: tool_use_id.to_string(),
|
||||
|
|
@ -452,8 +421,23 @@ impl<P: ModelProvider> Orchestrator<P> {
|
|||
.await;
|
||||
output
|
||||
}
|
||||
Err(e) => {
|
||||
let msg = e.to_string();
|
||||
Ok(Err(msg)) => {
|
||||
self.send(UIEvent::ToolResult {
|
||||
tool_use_id: tool_use_id.to_string(),
|
||||
tool_name: tool_name.to_string(),
|
||||
display: ToolDisplay::Generic {
|
||||
summary: msg.clone(),
|
||||
},
|
||||
is_error: true,
|
||||
})
|
||||
.await;
|
||||
ToolOutput {
|
||||
content: msg,
|
||||
is_error: true,
|
||||
}
|
||||
}
|
||||
Err(rpc_err) => {
|
||||
let msg = format!("executor RPC error: {rpc_err}");
|
||||
self.send(UIEvent::ToolResult {
|
||||
tool_use_id: tool_use_id.to_string(),
|
||||
tool_name: tool_name.to_string(),
|
||||
|
|
@ -471,40 +455,6 @@ impl<P: ModelProvider> Orchestrator<P> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Wait for a `ToolApprovalResponse` matching `tool_use_id` from the TUI.
|
||||
///
|
||||
/// Any `SendMessage` actions that arrive during the wait are pushed onto
|
||||
/// `self.queued_messages` rather than discarded. They are replayed as new
|
||||
/// turns once the current tool-use loop completes (see `run()`).
|
||||
///
|
||||
/// Returns `true` if approved, `false` if denied or the channel closes.
|
||||
async fn wait_for_approval(&mut self, tool_use_id: &str) -> bool {
|
||||
while let Some(action) = self.action_rx.recv().await {
|
||||
match action {
|
||||
UserAction::ToolApprovalResponse {
|
||||
tool_use_id: id,
|
||||
approved,
|
||||
} if id == tool_use_id => {
|
||||
return approved;
|
||||
}
|
||||
UserAction::Quit => return false,
|
||||
UserAction::SendMessage(text) => {
|
||||
self.queued_messages.push(text);
|
||||
}
|
||||
UserAction::ClearHistory { epoch } => {
|
||||
// Keep epoch in sync even while blocked on an approval
|
||||
// prompt. Without this, events emitted after the approval
|
||||
// resolves would carry the pre-clear epoch and be silently
|
||||
// discarded by the TUI.
|
||||
self.epoch = epoch;
|
||||
self.history.clear();
|
||||
}
|
||||
_ => {} // discard stale approvals
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Run the orchestrator until the user quits or the `action_rx` channel closes.
|
||||
pub async fn run(mut self) {
|
||||
while let Some(action) = self.action_rx.recv().await {
|
||||
|
|
@ -514,31 +464,16 @@ impl<P: ModelProvider> Orchestrator<P> {
|
|||
self.epoch = epoch;
|
||||
self.history.clear();
|
||||
}
|
||||
|
||||
// Approval responses are handled inline during tool execution,
|
||||
// not in the main action loop. If one arrives here it's stale.
|
||||
UserAction::ToolApprovalResponse { .. } => {}
|
||||
|
||||
UserAction::SetNetworkPolicy(allowed) => {
|
||||
self.sandbox.set_network_allowed(allowed);
|
||||
// The sandbox lives inside the executor; there is currently
|
||||
// no RPC to change its policy at runtime. The event is still
|
||||
// forwarded to the TUI for status-bar display purposes.
|
||||
self.send(UIEvent::NetworkPolicyChanged(allowed)).await;
|
||||
}
|
||||
|
||||
UserAction::SendMessage(text) => {
|
||||
self.history
|
||||
.push(ConversationMessage::text(Role::User, text));
|
||||
self.run_turn().await;
|
||||
|
||||
// Drain any messages queued while an approval prompt was
|
||||
// open. Each queued message is a full turn in sequence.
|
||||
while !self.queued_messages.is_empty() {
|
||||
let queued = std::mem::take(&mut self.queued_messages);
|
||||
for msg in queued {
|
||||
self.history
|
||||
.push(ConversationMessage::text(Role::User, msg));
|
||||
self.run_turn().await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -549,6 +484,7 @@ impl<P: ModelProvider> Orchestrator<P> {
|
|||
mod tests {
|
||||
use super::*;
|
||||
use crate::core::types::ToolDefinition;
|
||||
use crate::executor::{ExecutorServer, spawn_local};
|
||||
use crate::tools::ToolRegistry;
|
||||
use futures::Stream;
|
||||
use tokio::sync::mpsc;
|
||||
|
|
@ -576,8 +512,8 @@ mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
/// Create an Orchestrator with no tools for testing text-only flows.
|
||||
fn test_orchestrator<P: ModelProvider>(
|
||||
/// Create an `Orchestrator` with an empty tool registry for testing text-only flows.
|
||||
async fn test_orchestrator<P: ModelProvider>(
|
||||
provider: P,
|
||||
action_rx: mpsc::Receiver<UserAction>,
|
||||
event_tx: mpsc::Sender<StampedEvent>,
|
||||
|
|
@ -589,13 +525,8 @@ mod tests {
|
|||
std::path::PathBuf::from("/tmp"),
|
||||
EnforcementMode::Yolo,
|
||||
);
|
||||
Orchestrator::new(
|
||||
provider,
|
||||
ToolRegistry::empty(),
|
||||
sandbox,
|
||||
action_rx,
|
||||
event_tx,
|
||||
)
|
||||
let executor = spawn_local(ExecutorServer::new(ToolRegistry::empty(), sandbox));
|
||||
Orchestrator::new(provider, executor, action_rx, event_tx).await
|
||||
}
|
||||
|
||||
/// Collect all UIEvents that arrive within one orchestrator turn, stopping
|
||||
|
|
@ -633,7 +564,7 @@ mod tests {
|
|||
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
||||
let (event_tx, mut event_rx) = mpsc::channel::<StampedEvent>(16);
|
||||
|
||||
let orch = test_orchestrator(provider, action_rx, event_tx);
|
||||
let orch = test_orchestrator(provider, action_rx, event_tx).await;
|
||||
let handle = tokio::spawn(orch.run());
|
||||
|
||||
action_tx
|
||||
|
|
@ -652,7 +583,7 @@ mod tests {
|
|||
assert!(matches!(&events[1], UIEvent::StreamDelta(s) if s == ", world!"));
|
||||
assert!(matches!(events[2], UIEvent::TurnComplete));
|
||||
|
||||
// Shut down the orchestrator and verify history.
|
||||
// Shut down the orchestrator.
|
||||
action_tx.send(UserAction::Quit).await.unwrap();
|
||||
handle.await.unwrap();
|
||||
}
|
||||
|
|
@ -671,7 +602,7 @@ mod tests {
|
|||
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
||||
let (event_tx, mut event_rx) = mpsc::channel::<StampedEvent>(16);
|
||||
|
||||
let orch = test_orchestrator(provider, action_rx, event_tx);
|
||||
let orch = test_orchestrator(provider, action_rx, event_tx).await;
|
||||
let handle = tokio::spawn(orch.run());
|
||||
|
||||
action_tx
|
||||
|
|
@ -713,7 +644,7 @@ mod tests {
|
|||
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
||||
let (event_tx, _event_rx) = mpsc::channel::<StampedEvent>(8);
|
||||
|
||||
let orch = test_orchestrator(NeverCalledProvider, action_rx, event_tx);
|
||||
let orch = test_orchestrator(NeverCalledProvider, action_rx, event_tx).await;
|
||||
let handle = tokio::spawn(orch.run());
|
||||
|
||||
action_tx.send(UserAction::Quit).await.unwrap();
|
||||
|
|
@ -724,12 +655,8 @@ mod tests {
|
|||
|
||||
/// Two sequential SendMessage turns each append a user message and the
|
||||
/// accumulated assistant response, leaving four messages in history order.
|
||||
///
|
||||
/// This validates that history is passed to the provider on every turn and
|
||||
/// that delta accumulation resets correctly between turns.
|
||||
#[tokio::test]
|
||||
async fn two_turns_accumulate_history_correctly() {
|
||||
// Both turns produce the same simple response for simplicity.
|
||||
let make_turn_events = || {
|
||||
vec![
|
||||
StreamEvent::TextDelta("reply".to_string()),
|
||||
|
|
@ -737,8 +664,6 @@ mod tests {
|
|||
]
|
||||
};
|
||||
|
||||
// We need to serve two different turns from the same provider.
|
||||
// Use an `Arc<Mutex<VecDeque>>` so the provider can pop event sets.
|
||||
use std::collections::VecDeque;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
|
|
@ -766,7 +691,7 @@ mod tests {
|
|||
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
||||
let (event_tx, mut event_rx) = mpsc::channel::<StampedEvent>(32);
|
||||
|
||||
let orch = test_orchestrator(provider, action_rx, event_tx);
|
||||
let orch = test_orchestrator(provider, action_rx, event_tx).await;
|
||||
let handle = tokio::spawn(orch.run());
|
||||
|
||||
// First turn.
|
||||
|
|
@ -794,8 +719,8 @@ mod tests {
|
|||
// -- tool-use accumulation ----------------------------------------------------
|
||||
|
||||
/// When the provider emits tool-use events, the orchestrator executes the
|
||||
/// tool (auto-approve since read_file is AutoApprove), feeds the result back,
|
||||
/// and the provider's second call returns text.
|
||||
/// tool via the executor, feeds the result back, and the provider's second
|
||||
/// call returns text.
|
||||
#[tokio::test]
|
||||
async fn tool_use_loop_executes_and_feeds_back() {
|
||||
use std::collections::VecDeque;
|
||||
|
|
@ -839,6 +764,7 @@ mod tests {
|
|||
let (event_tx, mut event_rx) = mpsc::channel::<StampedEvent>(32);
|
||||
|
||||
// Use a real ToolRegistry so read_file works.
|
||||
use crate::executor::{ExecutorServer, spawn_local};
|
||||
use crate::sandbox::policy::SandboxPolicy;
|
||||
use crate::sandbox::{EnforcementMode, Sandbox};
|
||||
let dir = tempfile::TempDir::new().unwrap();
|
||||
|
|
@ -848,13 +774,8 @@ mod tests {
|
|||
dir.path().to_path_buf(),
|
||||
EnforcementMode::Yolo,
|
||||
);
|
||||
let orch = Orchestrator::new(
|
||||
MultiCallMock { turns },
|
||||
ToolRegistry::default_tools(),
|
||||
sandbox,
|
||||
action_rx,
|
||||
event_tx,
|
||||
);
|
||||
let executor = spawn_local(ExecutorServer::new(ToolRegistry::default_tools(), sandbox));
|
||||
let orch = Orchestrator::new(MultiCallMock { turns }, executor, action_rx, event_tx).await;
|
||||
let handle = tokio::spawn(orch.run());
|
||||
|
||||
action_tx
|
||||
|
|
@ -925,7 +846,7 @@ mod tests {
|
|||
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
||||
let (event_tx, mut event_rx) = mpsc::channel::<StampedEvent>(16);
|
||||
|
||||
let orch = test_orchestrator(provider, action_rx, event_tx);
|
||||
let orch = test_orchestrator(provider, action_rx, event_tx).await;
|
||||
let handle = tokio::spawn(orch.run());
|
||||
|
||||
action_tx
|
||||
|
|
@ -946,116 +867,11 @@ mod tests {
|
|||
handle.await.unwrap();
|
||||
}
|
||||
|
||||
// -- queued messages ----------------------------------------------------------
|
||||
|
||||
/// A SendMessage sent while an approval prompt is open must be processed
|
||||
/// after the current turn completes, not silently dropped.
|
||||
///
|
||||
/// This test uses a RequiresApproval tool (write_file) so that the
|
||||
/// orchestrator blocks in wait_for_approval. While blocked, we send a second
|
||||
/// message. After approving (denied here for simplicity), the queued message
|
||||
/// should still be processed.
|
||||
#[tokio::test]
|
||||
async fn send_message_during_approval_is_queued_and_processed() {
|
||||
use std::collections::VecDeque;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
struct MultiCallMock {
|
||||
turns: Arc<Mutex<VecDeque<Vec<StreamEvent>>>>,
|
||||
}
|
||||
impl ModelProvider for MultiCallMock {
|
||||
fn stream<'a>(
|
||||
&'a self,
|
||||
_messages: &'a [ConversationMessage],
|
||||
_tools: &'a [ToolDefinition],
|
||||
) -> impl Stream<Item = StreamEvent> + Send + 'a {
|
||||
let events = self.turns.lock().unwrap().pop_front().unwrap_or_default();
|
||||
futures::stream::iter(events)
|
||||
}
|
||||
}
|
||||
|
||||
let turns = Arc::new(Mutex::new(VecDeque::from([
|
||||
// Turn 1: requests write_file (RequiresApproval)
|
||||
vec![
|
||||
StreamEvent::ToolUseStart {
|
||||
id: "toolu_w".to_string(),
|
||||
name: "write_file".to_string(),
|
||||
},
|
||||
StreamEvent::ToolUseInputDelta(
|
||||
"{\"path\":\"x.txt\",\"content\":\"hi\"}".to_string(),
|
||||
),
|
||||
StreamEvent::ToolUseDone,
|
||||
StreamEvent::Done,
|
||||
],
|
||||
// Turn 1 second iteration after tool result (denied)
|
||||
vec![
|
||||
StreamEvent::TextDelta("ok denied".to_string()),
|
||||
StreamEvent::Done,
|
||||
],
|
||||
// Turn 2: the queued message
|
||||
vec![
|
||||
StreamEvent::TextDelta("queued reply".to_string()),
|
||||
StreamEvent::Done,
|
||||
],
|
||||
])));
|
||||
|
||||
let (action_tx, action_rx) = mpsc::channel::<UserAction>(16);
|
||||
let (event_tx, mut event_rx) = mpsc::channel::<StampedEvent>(64);
|
||||
|
||||
let orch = test_orchestrator(MultiCallMock { turns }, action_rx, event_tx);
|
||||
let handle = tokio::spawn(orch.run());
|
||||
|
||||
// Start turn 1 -- orchestrator will block on approval.
|
||||
action_tx
|
||||
.send(UserAction::SendMessage("turn1".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Let the orchestrator reach wait_for_approval.
|
||||
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||
|
||||
// Send a message while blocked -- should be queued.
|
||||
action_tx
|
||||
.send(UserAction::SendMessage("queued".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Deny the tool -- unblocks wait_for_approval.
|
||||
action_tx
|
||||
.send(UserAction::ToolApprovalResponse {
|
||||
tool_use_id: "toolu_w".to_string(),
|
||||
approved: false,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Wait for both turns to complete.
|
||||
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
|
||||
|
||||
// Collect everything.
|
||||
let mut all_events = Vec::new();
|
||||
while let Ok(ev) = event_rx.try_recv() {
|
||||
all_events.push(ev.event);
|
||||
}
|
||||
|
||||
// The queued message must have produced "queued reply".
|
||||
assert!(
|
||||
all_events
|
||||
.iter()
|
||||
.any(|e| matches!(e, UIEvent::StreamDelta(s) if s == "queued reply")),
|
||||
"queued message was not processed; events: {all_events:?}"
|
||||
);
|
||||
|
||||
action_tx.send(UserAction::Quit).await.unwrap();
|
||||
handle.await.unwrap();
|
||||
}
|
||||
|
||||
// -- network policy toggle ------------------------------------------------
|
||||
|
||||
/// SetNetworkPolicy sends a NetworkPolicyChanged event back to the TUI.
|
||||
#[tokio::test]
|
||||
async fn set_network_policy_sends_event() {
|
||||
// Provider that returns immediately to avoid blocking.
|
||||
struct NeverCalledProvider;
|
||||
impl ModelProvider for NeverCalledProvider {
|
||||
fn stream<'a>(
|
||||
|
|
@ -1070,7 +886,7 @@ mod tests {
|
|||
let (action_tx, action_rx) = mpsc::channel::<UserAction>(8);
|
||||
let (event_tx, mut event_rx) = mpsc::channel::<StampedEvent>(8);
|
||||
|
||||
let orch = test_orchestrator(NeverCalledProvider, action_rx, event_tx);
|
||||
let orch = test_orchestrator(NeverCalledProvider, action_rx, event_tx).await;
|
||||
let handle = tokio::spawn(orch.run());
|
||||
|
||||
action_tx
|
||||
|
|
|
|||
|
|
@ -26,8 +26,6 @@ pub enum StreamEvent {
|
|||
pub enum UserAction {
|
||||
/// The user has submitted a message.
|
||||
SendMessage(String),
|
||||
/// The user has responded to a tool approval request.
|
||||
ToolApprovalResponse { tool_use_id: String, approved: bool },
|
||||
/// The user has requested to quit.
|
||||
Quit,
|
||||
/// The user has requested to clear conversation history.
|
||||
|
|
@ -67,13 +65,7 @@ pub enum ToolDisplay {
|
|||
pub enum UIEvent {
|
||||
/// A text chunk to append to the current assistant message.
|
||||
StreamDelta(String),
|
||||
/// A tool requires user approval before execution.
|
||||
ToolApprovalRequest {
|
||||
tool_use_id: String,
|
||||
tool_name: String,
|
||||
display: ToolDisplay,
|
||||
},
|
||||
/// A tool is being executed (informational, after approval or auto-approve).
|
||||
/// A tool is being executed.
|
||||
ToolExecuting {
|
||||
tool_use_id: String,
|
||||
tool_name: String,
|
||||
|
|
@ -121,7 +113,9 @@ pub enum Role {
|
|||
///
|
||||
/// This is provider-agnostic -- the provider serializes it into the
|
||||
/// format required by its API (e.g. the `tools` array for Anthropic).
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
/// `Deserialize` is required so that `ToolDefinition` can cross the tarpc
|
||||
/// channel as a return value from `ExecutorServer::list_tools`.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolDefinition {
|
||||
/// The tool name the model will use in `tool_use` blocks.
|
||||
pub name: String,
|
||||
|
|
|
|||
176
src/executor/mod.rs
Normal file
176
src/executor/mod.rs
Normal file
|
|
@ -0,0 +1,176 @@
|
|||
//! tarpc-based tool executor: owns the [`ToolRegistry`] and [`Sandbox`],
|
||||
//! exposes them to the orchestrator through an in-process RPC channel.
|
||||
//!
|
||||
//! # Architecture
|
||||
//!
|
||||
//! ```text
|
||||
//! Orchestrator --call_tool(name, input)--> ExecutorServer --execute()--> Tool
|
||||
//! |
|
||||
//! v
|
||||
//! Sandbox
|
||||
//! ```
|
||||
//!
|
||||
//! The split exists so that each sub-agent (Phase 7) can get its own
|
||||
//! `ExecutorClient` backed by an independent sandbox policy. For now there is
|
||||
//! one executor shared by the single orchestrator.
|
||||
//!
|
||||
//! Transport: [`tarpc::transport::channel::unbounded`] (in-process, zero
|
||||
//! serialization overhead for benchmarks, no network required).
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::StreamExt as _;
|
||||
use tarpc::server::Channel as _;
|
||||
|
||||
use crate::core::types::ToolDefinition;
|
||||
use crate::sandbox::Sandbox;
|
||||
use crate::tools::{ToolOutput, ToolRegistry};
|
||||
|
||||
/// tarpc service definition for the tool executor.
|
||||
///
|
||||
/// The two methods map cleanly to the two operations the orchestrator needs:
|
||||
/// list_tools -> populate the provider's `tools` array once at startup
|
||||
/// call_tool -> dispatch a single tool-use block during a turn
|
||||
#[tarpc::service]
|
||||
pub trait ToolExecutor {
|
||||
/// Return the definitions of all registered tools.
|
||||
async fn list_tools() -> Vec<ToolDefinition>;
|
||||
/// Execute a tool by name with the given JSON input.
|
||||
///
|
||||
/// Returns `Ok(output)` on success or a human-readable `Err(msg)` when
|
||||
/// the tool name is unknown or execution fails.
|
||||
async fn call_tool(name: String, input: serde_json::Value) -> Result<ToolOutput, String>;
|
||||
}
|
||||
|
||||
/// Server-side state: the registry of tools and the sandbox they run in.
|
||||
///
|
||||
/// `Clone` is derived because tarpc clones the server struct for each
|
||||
/// in-flight request. Both fields are `Arc`-wrapped so clones are cheap.
|
||||
#[derive(Clone)]
|
||||
pub struct ExecutorServer {
|
||||
registry: Arc<ToolRegistry>,
|
||||
sandbox: Arc<Sandbox>,
|
||||
}
|
||||
|
||||
impl ExecutorServer {
|
||||
/// Wrap `registry` and `sandbox` in an `ExecutorServer` ready to serve.
|
||||
pub fn new(registry: ToolRegistry, sandbox: Sandbox) -> Self {
|
||||
Self {
|
||||
registry: Arc::new(registry),
|
||||
sandbox: Arc::new(sandbox),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ToolExecutor for ExecutorServer {
|
||||
async fn list_tools(self, _ctx: tarpc::context::Context) -> Vec<ToolDefinition> {
|
||||
self.registry.definitions()
|
||||
}
|
||||
|
||||
async fn call_tool(
|
||||
self,
|
||||
_ctx: tarpc::context::Context,
|
||||
name: String,
|
||||
input: serde_json::Value,
|
||||
) -> Result<ToolOutput, String> {
|
||||
match self.registry.get(&name) {
|
||||
Some(tool) => tool
|
||||
.execute(&input, self.sandbox.as_ref())
|
||||
.await
|
||||
.map_err(|e| e.to_string()),
|
||||
None => Err(format!("unknown tool: {name}")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Spawn an `ExecutorServer` on an in-process tarpc channel and return the
|
||||
/// client end.
|
||||
///
|
||||
/// The server task runs as a detached tokio task and lives as long as the
|
||||
/// returned client holds the channel open. No network sockets are involved.
|
||||
pub fn spawn_local(server: ExecutorServer) -> ToolExecutorClient {
|
||||
let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
|
||||
let server_channel = tarpc::server::BaseChannel::with_defaults(server_transport);
|
||||
tokio::spawn(
|
||||
server_channel
|
||||
.execute(server.serve())
|
||||
.for_each(|response| async move {
|
||||
tokio::spawn(response);
|
||||
}),
|
||||
);
|
||||
ToolExecutorClient::new(tarpc::client::Config::default(), client_transport).spawn()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::sandbox::test_sandbox;
|
||||
use tempfile::TempDir;
|
||||
|
||||
// -- direct server (no tarpc) -------------------------------------------------
|
||||
|
||||
/// `list_tools` returns a definition for every tool in the registry.
|
||||
#[tokio::test]
|
||||
async fn list_tools_returns_all_definitions() {
|
||||
let dir = TempDir::new().unwrap();
|
||||
let server = ExecutorServer::new(ToolRegistry::default_tools(), test_sandbox(dir.path()));
|
||||
let ctx = tarpc::context::current();
|
||||
let defs = server.list_tools(ctx).await;
|
||||
assert_eq!(defs.len(), 4);
|
||||
assert!(defs.iter().any(|d| d.name == "read_file"));
|
||||
assert!(defs.iter().any(|d| d.name == "write_file"));
|
||||
}
|
||||
|
||||
/// `call_tool` for `read_file` returns the file contents.
|
||||
#[tokio::test]
|
||||
async fn call_tool_read_file_returns_content() {
|
||||
let dir = TempDir::new().unwrap();
|
||||
std::fs::write(dir.path().join("hello.txt"), "world").unwrap();
|
||||
let server = ExecutorServer::new(ToolRegistry::default_tools(), test_sandbox(dir.path()));
|
||||
let ctx = tarpc::context::current();
|
||||
let result = server
|
||||
.call_tool(
|
||||
ctx,
|
||||
"read_file".to_string(),
|
||||
serde_json::json!({"path": "hello.txt"}),
|
||||
)
|
||||
.await;
|
||||
let output = result.unwrap();
|
||||
assert_eq!(output.content, "world");
|
||||
assert!(!output.is_error);
|
||||
}
|
||||
|
||||
/// `call_tool` for an unknown tool returns an `Err` with a descriptive message.
|
||||
#[tokio::test]
|
||||
async fn call_tool_unknown_tool_returns_err() {
|
||||
let dir = TempDir::new().unwrap();
|
||||
let server = ExecutorServer::new(ToolRegistry::default_tools(), test_sandbox(dir.path()));
|
||||
let ctx = tarpc::context::current();
|
||||
let result = server
|
||||
.call_tool(ctx, "nonexistent".to_string(), serde_json::json!({}))
|
||||
.await;
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().contains("unknown tool"));
|
||||
}
|
||||
|
||||
// -- in-process tarpc channel (spawn_local) ------------------------------------
|
||||
|
||||
/// `spawn_local` + `call_tool` over the in-process channel reads a real file.
|
||||
#[tokio::test]
|
||||
async fn spawn_local_call_tool_read_file() {
|
||||
let dir = TempDir::new().unwrap();
|
||||
std::fs::write(dir.path().join("data.txt"), "from channel").unwrap();
|
||||
let server = ExecutorServer::new(ToolRegistry::default_tools(), test_sandbox(dir.path()));
|
||||
let client = spawn_local(server);
|
||||
let result = client
|
||||
.call_tool(
|
||||
tarpc::context::current(),
|
||||
"read_file".to_string(),
|
||||
serde_json::json!({"path": "data.txt"}),
|
||||
)
|
||||
.await
|
||||
.expect("RPC call failed");
|
||||
let output = result.unwrap();
|
||||
assert_eq!(output.content, "from channel");
|
||||
}
|
||||
}
|
||||
|
|
@ -1,5 +1,6 @@
|
|||
mod app;
|
||||
mod core;
|
||||
mod executor;
|
||||
mod provider;
|
||||
mod sandbox;
|
||||
mod tools;
|
||||
|
|
|
|||
|
|
@ -119,6 +119,10 @@ impl Sandbox {
|
|||
}
|
||||
|
||||
/// Set the network access policy.
|
||||
///
|
||||
/// Currently not wired to a runtime RPC -- will be exposed via the executor
|
||||
/// in a future phase.
|
||||
#[allow(dead_code)]
|
||||
pub fn set_network_allowed(&mut self, allowed: bool) {
|
||||
tracing::info!(network_allowed = allowed, "sandbox network policy updated");
|
||||
self.policy.network_allowed = allowed;
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ use crate::sandbox::Sandbox;
|
|||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use super::{RiskLevel, Tool, ToolError, ToolOutput};
|
||||
use super::{Tool, ToolError, ToolOutput};
|
||||
|
||||
/// Lists directory contents. Auto-approved (read-only).
|
||||
pub struct ListDirectory;
|
||||
|
|
@ -32,10 +32,6 @@ impl Tool for ListDirectory {
|
|||
})
|
||||
}
|
||||
|
||||
fn risk_level(&self) -> RiskLevel {
|
||||
RiskLevel::AutoApprove
|
||||
}
|
||||
|
||||
async fn execute(
|
||||
&self,
|
||||
input: &serde_json::Value,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
//! Tool system: trait, registry, risk classification, and built-in tools.
|
||||
//! Tool system: trait, registry, and built-in tools.
|
||||
//!
|
||||
//! All tools implement the [`Tool`] trait. The [`ToolRegistry`] collects them
|
||||
//! and provides lookup by name plus generation of [`ToolDefinition`]s for the
|
||||
|
|
@ -13,12 +13,16 @@ mod shell_exec;
|
|||
mod write_file;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::core::types::ToolDefinition;
|
||||
use crate::sandbox::Sandbox;
|
||||
|
||||
/// The output of a tool execution.
|
||||
#[derive(Debug)]
|
||||
///
|
||||
/// Derives `Serialize + Deserialize` so it can cross the tarpc channel
|
||||
/// between [`crate::executor::ExecutorServer`] and the orchestrator.
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ToolOutput {
|
||||
/// The text content returned to the model.
|
||||
pub content: String,
|
||||
|
|
@ -26,15 +30,6 @@ pub struct ToolOutput {
|
|||
pub is_error: bool,
|
||||
}
|
||||
|
||||
/// Risk classification for tool approval gating.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum RiskLevel {
|
||||
/// Safe to execute without user confirmation (e.g. read-only operations).
|
||||
AutoApprove,
|
||||
/// Requires explicit user approval before execution (e.g. writes, shell).
|
||||
RequiresApproval,
|
||||
}
|
||||
|
||||
/// A tool that the model can invoke.
|
||||
///
|
||||
/// All file I/O and process spawning must go through the [`Sandbox`] passed
|
||||
|
|
@ -53,8 +48,6 @@ pub trait Tool: Send + Sync {
|
|||
fn description(&self) -> &str;
|
||||
/// JSON Schema for the tool's input parameters.
|
||||
fn input_schema(&self) -> serde_json::Value;
|
||||
/// The risk level of this tool.
|
||||
fn risk_level(&self) -> RiskLevel;
|
||||
/// Execute the tool with the given input, confined by `sandbox`.
|
||||
async fn execute(
|
||||
&self,
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ use crate::sandbox::Sandbox;
|
|||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use super::{RiskLevel, Tool, ToolError, ToolOutput};
|
||||
use super::{Tool, ToolError, ToolOutput};
|
||||
|
||||
/// Reads file contents. Auto-approved (read-only).
|
||||
pub struct ReadFile;
|
||||
|
|
@ -32,10 +32,6 @@ impl Tool for ReadFile {
|
|||
})
|
||||
}
|
||||
|
||||
fn risk_level(&self) -> RiskLevel {
|
||||
RiskLevel::AutoApprove
|
||||
}
|
||||
|
||||
async fn execute(
|
||||
&self,
|
||||
input: &serde_json::Value,
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ use crate::sandbox::Sandbox;
|
|||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use super::{RiskLevel, Tool, ToolError, ToolOutput};
|
||||
use super::{Tool, ToolError, ToolOutput};
|
||||
|
||||
/// Executes a shell command. Requires user approval.
|
||||
pub struct ShellExec;
|
||||
|
|
@ -32,10 +32,6 @@ impl Tool for ShellExec {
|
|||
})
|
||||
}
|
||||
|
||||
fn risk_level(&self) -> RiskLevel {
|
||||
RiskLevel::RequiresApproval
|
||||
}
|
||||
|
||||
async fn execute(
|
||||
&self,
|
||||
input: &serde_json::Value,
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ use crate::sandbox::Sandbox;
|
|||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use super::{RiskLevel, Tool, ToolError, ToolOutput};
|
||||
use super::{Tool, ToolError, ToolOutput};
|
||||
|
||||
/// Writes content to a file. Requires user approval.
|
||||
pub struct WriteFile;
|
||||
|
|
@ -36,10 +36,6 @@ impl Tool for WriteFile {
|
|||
})
|
||||
}
|
||||
|
||||
fn risk_level(&self) -> RiskLevel {
|
||||
RiskLevel::RequiresApproval
|
||||
}
|
||||
|
||||
async fn execute(
|
||||
&self,
|
||||
input: &serde_json::Value,
|
||||
|
|
|
|||
|
|
@ -16,14 +16,13 @@ use crate::tui::tool_display;
|
|||
/// arrives, the handler searches `state.messages` for an existing entry with the
|
||||
/// same `tool_use_id` and replaces its content rather than appending a new row.
|
||||
///
|
||||
/// | Event | Effect |
|
||||
/// |------------------------|------------------------------------------------------------|
|
||||
/// | `StreamDelta(s)` | Append `s` to last message if it's `Assistant`; else push |
|
||||
/// | `ToolApprovalRequest` | Push inline message with approval prompt, set pending |
|
||||
/// | `ToolExecuting` | Replace approval message in-place (or push new) |
|
||||
/// | `ToolResult` | Replace executing message in-place (or push new) |
|
||||
/// | `TurnComplete` | No structural change; logged at debug level |
|
||||
/// | `Error(msg)` | Push `(Assistant, "[error] {msg}")` |
|
||||
/// | Event | Effect |
|
||||
/// |-------------------|------------------------------------------------------------|
|
||||
/// | `StreamDelta(s)` | Append `s` to last message if it's `Assistant`; else push |
|
||||
/// | `ToolExecuting` | Push new executing message (or replace in-place) |
|
||||
/// | `ToolResult` | Replace executing message in-place (or push new) |
|
||||
/// | `TurnComplete` | No structural change; logged at debug level |
|
||||
/// | `Error(msg)` | Push `(Assistant, "[error] {msg}")` |
|
||||
pub(super) fn drain_ui_events(event_rx: &mut mpsc::Receiver<StampedEvent>, state: &mut AppState) {
|
||||
while let Ok(stamped) = event_rx.try_recv() {
|
||||
// Discard events from before the most recent :clear.
|
||||
|
|
@ -56,21 +55,6 @@ pub(super) fn drain_ui_events(event_rx: &mut mpsc::Receiver<StampedEvent>, state
|
|||
}
|
||||
state.content_changed = true;
|
||||
}
|
||||
UIEvent::ToolApprovalRequest {
|
||||
tool_use_id,
|
||||
tool_name,
|
||||
display,
|
||||
} => {
|
||||
let mut content = tool_display::format_executing(&tool_name, &display);
|
||||
content.push_str("\n[y] approve [n] deny");
|
||||
state.messages.push(DisplayMessage {
|
||||
role: Role::Assistant,
|
||||
content,
|
||||
tool_use_id: Some(tool_use_id.clone()),
|
||||
});
|
||||
state.pending_approval = Some(PendingApproval { tool_use_id });
|
||||
state.content_changed = true;
|
||||
}
|
||||
UIEvent::ToolExecuting {
|
||||
tool_use_id,
|
||||
tool_name,
|
||||
|
|
@ -127,12 +111,6 @@ fn replace_or_push(state: &mut AppState, tool_use_id: &str, content: String) {
|
|||
}
|
||||
}
|
||||
|
||||
/// A pending tool approval request waiting for user input.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PendingApproval {
|
||||
pub tool_use_id: String,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
@ -179,29 +157,6 @@ mod tests {
|
|||
assert_eq!(state.messages[1].content, "hello");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn drain_tool_approval_sets_pending_and_adds_message() {
|
||||
let (tx, mut rx) = tokio::sync::mpsc::channel(8);
|
||||
let mut state = AppState::new();
|
||||
tx.send(stamp(UIEvent::ToolApprovalRequest {
|
||||
tool_use_id: "t1".to_string(),
|
||||
tool_name: "shell_exec".to_string(),
|
||||
display: ToolDisplay::ShellExec {
|
||||
command: "cargo test".to_string(),
|
||||
},
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
drop(tx);
|
||||
drain_ui_events(&mut rx, &mut state);
|
||||
assert!(state.pending_approval.is_some());
|
||||
assert_eq!(state.pending_approval.as_ref().unwrap().tool_use_id, "t1");
|
||||
// Message should be inline with approval prompt.
|
||||
assert_eq!(state.messages.len(), 1);
|
||||
assert!(state.messages[0].content.contains("[y] approve"));
|
||||
assert_eq!(state.messages[0].tool_use_id.as_deref(), Some("t1"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn drain_tool_result_replaces_existing_message() {
|
||||
let (tx, mut rx) = tokio::sync::mpsc::channel(8);
|
||||
|
|
|
|||
|
|
@ -13,8 +13,6 @@ pub(super) enum LoopControl {
|
|||
Quit,
|
||||
/// The user ran `:clear`; wipe the conversation.
|
||||
ClearHistory,
|
||||
/// The user responded to a tool approval prompt.
|
||||
ToolApproval { tool_use_id: String, approved: bool },
|
||||
/// The user ran `:net on` or `:net off`.
|
||||
SetNetworkPolicy(bool),
|
||||
}
|
||||
|
|
@ -28,60 +26,6 @@ pub(super) fn handle_key(key: Option<KeyEvent>, state: &mut AppState) -> Option<
|
|||
// Clear any transient status error on the next keypress.
|
||||
state.status_error = None;
|
||||
|
||||
// If a tool approval is pending, intercept y/n but allow scrolling.
|
||||
if let Some(approval) = &state.pending_approval {
|
||||
let tool_use_id = approval.tool_use_id.clone();
|
||||
let is_ctrl = key.modifiers.contains(KeyModifiers::CONTROL);
|
||||
match key.code {
|
||||
KeyCode::Char('y') | KeyCode::Char('Y') if !is_ctrl => {
|
||||
state.pending_approval = None;
|
||||
return Some(LoopControl::ToolApproval {
|
||||
tool_use_id,
|
||||
approved: true,
|
||||
});
|
||||
}
|
||||
KeyCode::Char('n') | KeyCode::Char('N') if !is_ctrl => {
|
||||
state.pending_approval = None;
|
||||
return Some(LoopControl::ToolApproval {
|
||||
tool_use_id,
|
||||
approved: false,
|
||||
});
|
||||
}
|
||||
// Allow scrolling while approval is pending.
|
||||
KeyCode::Char('j') if !is_ctrl => {
|
||||
state.scroll = state.scroll.saturating_add(1);
|
||||
return None;
|
||||
}
|
||||
KeyCode::Char('k') if !is_ctrl => {
|
||||
state.scroll = state.scroll.saturating_sub(1);
|
||||
return None;
|
||||
}
|
||||
KeyCode::Char('G') if !is_ctrl => {
|
||||
state.scroll = u16::MAX;
|
||||
return None;
|
||||
}
|
||||
KeyCode::Char('g') if !is_ctrl => {
|
||||
state.pending_keys.push('g');
|
||||
if state.pending_keys.ends_with(&['g', 'g']) {
|
||||
state.scroll = 0;
|
||||
state.pending_keys.clear();
|
||||
}
|
||||
return None;
|
||||
}
|
||||
KeyCode::Char('d') if is_ctrl => {
|
||||
let half = (state.viewport_height / 2).max(1);
|
||||
state.scroll = state.scroll.saturating_add(half);
|
||||
return None;
|
||||
}
|
||||
KeyCode::Char('u') if is_ctrl => {
|
||||
let half = (state.viewport_height / 2).max(1);
|
||||
state.scroll = state.scroll.saturating_sub(half);
|
||||
return None;
|
||||
}
|
||||
_ => return None,
|
||||
}
|
||||
}
|
||||
|
||||
// Ctrl+C quits from any mode.
|
||||
if key.modifiers.contains(KeyModifiers::CONTROL) && key.code == KeyCode::Char('c') {
|
||||
return Some(LoopControl::Quit);
|
||||
|
|
|
|||
|
|
@ -90,8 +90,6 @@ pub struct AppState {
|
|||
pub viewport_height: u16,
|
||||
/// Transient error message shown in the status bar, cleared on next keypress.
|
||||
pub status_error: Option<String>,
|
||||
/// A tool approval request waiting for user input (y/n).
|
||||
pub pending_approval: Option<events::PendingApproval>,
|
||||
/// Whether the sandbox is in yolo (unsandboxed) mode.
|
||||
pub sandbox_yolo: bool,
|
||||
/// Whether network access is currently allowed.
|
||||
|
|
@ -115,7 +113,6 @@ impl AppState {
|
|||
pending_keys: Vec::new(),
|
||||
viewport_height: 0,
|
||||
status_error: None,
|
||||
pending_approval: None,
|
||||
sandbox_yolo: false,
|
||||
network_allowed: false,
|
||||
epoch: 0,
|
||||
|
|
@ -224,17 +221,6 @@ pub async fn run(
|
|||
.send(UserAction::ClearHistory { epoch: state.epoch })
|
||||
.await;
|
||||
}
|
||||
Some(input::LoopControl::ToolApproval {
|
||||
tool_use_id,
|
||||
approved,
|
||||
}) => {
|
||||
let _ = action_tx
|
||||
.send(UserAction::ToolApprovalResponse {
|
||||
tool_use_id,
|
||||
approved,
|
||||
})
|
||||
.await;
|
||||
}
|
||||
Some(input::LoopControl::SetNetworkPolicy(allowed)) => {
|
||||
let _ = action_tx.send(UserAction::SetNetworkPolicy(allowed)).await;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -425,37 +425,6 @@ mod tests {
|
|||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn render_approval_inline_visible() {
|
||||
let backend = TestBackend::new(80, 24);
|
||||
let mut terminal = Terminal::new(backend).unwrap();
|
||||
let mut state = AppState::new();
|
||||
// Inline approval message in the message stream.
|
||||
state.messages.push(DisplayMessage {
|
||||
role: Role::Assistant,
|
||||
content: "$ cargo test\n[y] approve [n] deny".to_string(),
|
||||
tool_use_id: Some("t1".to_string()),
|
||||
});
|
||||
state.pending_approval = Some(super::super::events::PendingApproval {
|
||||
tool_use_id: "t1".to_string(),
|
||||
});
|
||||
terminal.draw(|frame| render(frame, &state)).unwrap();
|
||||
let buf = terminal.backend().buffer().clone();
|
||||
let all_text: String = buf
|
||||
.content()
|
||||
.iter()
|
||||
.map(|c| c.symbol().to_string())
|
||||
.collect();
|
||||
assert!(
|
||||
all_text.contains("approve"),
|
||||
"expected approval prompt in buffer"
|
||||
);
|
||||
assert!(
|
||||
all_text.contains("cargo test"),
|
||||
"expected tool info in buffer"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn render_status_bar_shows_net_off() {
|
||||
let backend = TestBackend::new(80, 24);
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue