Tool Use.
This commit is contained in:
parent
6b85ff3cb8
commit
0c1c928498
20 changed files with 1822 additions and 129 deletions
87
src/tools/list_directory.rs
Normal file
87
src/tools/list_directory.rs
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
//! `list_directory` tool: lists entries in a directory within the working directory.
|
||||
|
||||
use std::path::Path;
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use super::{RiskLevel, Tool, ToolError, ToolOutput, validate_path};
|
||||
|
||||
/// Lists directory contents. Auto-approved (read-only).
|
||||
pub struct ListDirectory;
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for ListDirectory {
|
||||
fn name(&self) -> &str {
|
||||
"list_directory"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"List the files and subdirectories in a directory. The path is relative to the working directory."
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> serde_json::Value {
|
||||
serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "The directory path to list, relative to the working directory. Use '.' for the working directory itself."
|
||||
}
|
||||
},
|
||||
"required": ["path"]
|
||||
})
|
||||
}
|
||||
|
||||
fn risk_level(&self) -> RiskLevel {
|
||||
RiskLevel::AutoApprove
|
||||
}
|
||||
|
||||
async fn execute(
|
||||
&self,
|
||||
input: &serde_json::Value,
|
||||
working_dir: &Path,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let path_str = input["path"]
|
||||
.as_str()
|
||||
.ok_or_else(|| ToolError::InvalidInput("missing 'path' string".to_string()))?;
|
||||
|
||||
let canonical = validate_path(working_dir, path_str)?;
|
||||
|
||||
let mut entries: Vec<String> = Vec::new();
|
||||
let mut dir = tokio::fs::read_dir(&canonical).await?;
|
||||
while let Some(entry) = dir.next_entry().await? {
|
||||
let name = entry.file_name().to_string_lossy().to_string();
|
||||
let suffix = if entry.file_type().await?.is_dir() {
|
||||
"/"
|
||||
} else {
|
||||
""
|
||||
};
|
||||
entries.push(format!("{name}{suffix}"));
|
||||
}
|
||||
entries.sort();
|
||||
|
||||
Ok(ToolOutput {
|
||||
content: entries.join("\n"),
|
||||
is_error: false,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::fs;
|
||||
use tempfile::TempDir;
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_directory_contents() {
|
||||
let dir = TempDir::new().unwrap();
|
||||
fs::write(dir.path().join("a.txt"), "").unwrap();
|
||||
fs::create_dir(dir.path().join("subdir")).unwrap();
|
||||
let tool = ListDirectory;
|
||||
let input = serde_json::json!({"path": "."});
|
||||
let out = tool.execute(&input, dir.path()).await.unwrap();
|
||||
assert!(out.content.contains("a.txt"));
|
||||
assert!(out.content.contains("subdir/"));
|
||||
}
|
||||
}
|
||||
220
src/tools/mod.rs
Normal file
220
src/tools/mod.rs
Normal file
|
|
@ -0,0 +1,220 @@
|
|||
//! Tool system: trait, registry, risk classification, and built-in tools.
|
||||
//!
|
||||
//! All tools implement the [`Tool`] trait. The [`ToolRegistry`] collects them
|
||||
//! and provides lookup by name plus generation of [`ToolDefinition`]s for the
|
||||
//! model provider.
|
||||
|
||||
mod list_directory;
|
||||
mod read_file;
|
||||
mod shell_exec;
|
||||
mod write_file;
|
||||
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::core::types::ToolDefinition;
|
||||
|
||||
/// The output of a tool execution.
|
||||
#[derive(Debug)]
|
||||
pub struct ToolOutput {
|
||||
/// The text content returned to the model.
|
||||
pub content: String,
|
||||
/// Whether the tool encountered an error.
|
||||
pub is_error: bool,
|
||||
}
|
||||
|
||||
/// Risk classification for tool approval gating.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum RiskLevel {
|
||||
/// Safe to execute without user confirmation (e.g. read-only operations).
|
||||
AutoApprove,
|
||||
/// Requires explicit user approval before execution (e.g. writes, shell).
|
||||
RequiresApproval,
|
||||
}
|
||||
|
||||
/// A tool that the model can invoke.
|
||||
///
|
||||
/// The `execute` method is async so that tool implementations can use
|
||||
/// `tokio::fs` and `tokio::process` without blocking a Tokio worker thread.
|
||||
/// `#[async_trait]` desugars the async fn to a boxed future, which is required
|
||||
/// for `dyn Tool` to remain object-safe.
|
||||
#[async_trait]
|
||||
pub trait Tool: Send + Sync {
|
||||
/// The name the model uses to invoke this tool.
|
||||
fn name(&self) -> &str;
|
||||
/// Human-readable description for the model.
|
||||
fn description(&self) -> &str;
|
||||
/// JSON Schema for the tool's input parameters.
|
||||
fn input_schema(&self) -> serde_json::Value;
|
||||
/// The risk level of this tool.
|
||||
fn risk_level(&self) -> RiskLevel;
|
||||
/// Execute the tool with the given input, confined to `working_dir`.
|
||||
async fn execute(
|
||||
&self,
|
||||
input: &serde_json::Value,
|
||||
working_dir: &Path,
|
||||
) -> Result<ToolOutput, ToolError>;
|
||||
}
|
||||
|
||||
/// Errors from tool execution.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ToolError {
|
||||
/// The requested path escapes the working directory.
|
||||
#[error("path escapes working directory: {0}")]
|
||||
PathEscape(PathBuf),
|
||||
/// An I/O error during tool execution.
|
||||
#[error("I/O error: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
/// A required input field is missing or has the wrong type.
|
||||
#[error("invalid input: {0}")]
|
||||
InvalidInput(String),
|
||||
}
|
||||
|
||||
/// Validate that `requested` resolves to a path inside `working_dir`.
|
||||
///
|
||||
/// Joins `working_dir` with `requested`, canonicalizes the result (resolving
|
||||
/// symlinks and `..` components), and checks that the canonical path starts
|
||||
/// with the canonical working directory.
|
||||
///
|
||||
/// Returns the canonical path on success, or [`ToolError::PathEscape`] if the
|
||||
/// path would escape the working directory.
|
||||
pub fn validate_path(working_dir: &Path, requested: &str) -> Result<PathBuf, ToolError> {
|
||||
let candidate = if Path::new(requested).is_absolute() {
|
||||
PathBuf::from(requested)
|
||||
} else {
|
||||
working_dir.join(requested)
|
||||
};
|
||||
|
||||
// For paths that don't exist yet (e.g. write_file creating a new file),
|
||||
// canonicalize the parent directory and append the filename.
|
||||
let canonical = if candidate.exists() {
|
||||
candidate
|
||||
.canonicalize()
|
||||
.map_err(|_| ToolError::PathEscape(candidate.clone()))?
|
||||
} else {
|
||||
let parent = candidate
|
||||
.parent()
|
||||
.ok_or_else(|| ToolError::PathEscape(candidate.clone()))?;
|
||||
let file_name = candidate
|
||||
.file_name()
|
||||
.ok_or_else(|| ToolError::PathEscape(candidate.clone()))?;
|
||||
let canonical_parent = parent
|
||||
.canonicalize()
|
||||
.map_err(|_| ToolError::PathEscape(candidate.clone()))?;
|
||||
canonical_parent.join(file_name)
|
||||
};
|
||||
|
||||
let canonical_root = working_dir
|
||||
.canonicalize()
|
||||
.map_err(|_| ToolError::PathEscape(candidate.clone()))?;
|
||||
|
||||
if canonical.starts_with(&canonical_root) {
|
||||
Ok(canonical)
|
||||
} else {
|
||||
Err(ToolError::PathEscape(candidate))
|
||||
}
|
||||
}
|
||||
|
||||
/// Collection of available tools with name-based lookup.
|
||||
pub struct ToolRegistry {
|
||||
tools: Vec<Box<dyn Tool>>,
|
||||
}
|
||||
|
||||
impl ToolRegistry {
|
||||
/// Create an empty registry with no tools.
|
||||
#[allow(dead_code)]
|
||||
pub fn empty() -> Self {
|
||||
Self { tools: Vec::new() }
|
||||
}
|
||||
|
||||
/// Create a registry with the default built-in tools.
|
||||
pub fn default_tools() -> Self {
|
||||
Self {
|
||||
tools: vec![
|
||||
Box::new(read_file::ReadFile),
|
||||
Box::new(list_directory::ListDirectory),
|
||||
Box::new(write_file::WriteFile),
|
||||
Box::new(shell_exec::ShellExec),
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
/// Look up a tool by name.
|
||||
pub fn get(&self, name: &str) -> Option<&dyn Tool> {
|
||||
self.tools.iter().find(|t| t.name() == name).map(|t| &**t)
|
||||
}
|
||||
|
||||
/// Generate [`ToolDefinition`]s for the model provider.
|
||||
pub fn definitions(&self) -> Vec<ToolDefinition> {
|
||||
self.tools
|
||||
.iter()
|
||||
.map(|t| ToolDefinition {
|
||||
name: t.name().to_string(),
|
||||
description: t.description().to_string(),
|
||||
input_schema: t.input_schema(),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::fs;
|
||||
use tempfile::TempDir;
|
||||
|
||||
#[test]
|
||||
fn validate_path_allows_subpath() {
|
||||
let dir = TempDir::new().unwrap();
|
||||
let sub = dir.path().join("sub");
|
||||
fs::create_dir(&sub).unwrap();
|
||||
let result = validate_path(dir.path(), "sub");
|
||||
assert!(result.is_ok());
|
||||
assert!(
|
||||
result
|
||||
.unwrap()
|
||||
.starts_with(dir.path().canonicalize().unwrap())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_path_rejects_traversal() {
|
||||
let dir = TempDir::new().unwrap();
|
||||
let result = validate_path(dir.path(), "../../../etc/passwd");
|
||||
assert!(result.is_err());
|
||||
assert!(matches!(result, Err(ToolError::PathEscape(_))));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_path_rejects_absolute_outside() {
|
||||
let dir = TempDir::new().unwrap();
|
||||
let result = validate_path(dir.path(), "/etc/passwd");
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_path_allows_new_file_in_working_dir() {
|
||||
let dir = TempDir::new().unwrap();
|
||||
let result = validate_path(dir.path(), "new_file.txt");
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn registry_default_has_all_tools() {
|
||||
let reg = ToolRegistry::default_tools();
|
||||
assert!(reg.get("read_file").is_some());
|
||||
assert!(reg.get("list_directory").is_some());
|
||||
assert!(reg.get("write_file").is_some());
|
||||
assert!(reg.get("shell_exec").is_some());
|
||||
assert!(reg.get("nonexistent").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn registry_definitions_match_tools() {
|
||||
let reg = ToolRegistry::default_tools();
|
||||
let defs = reg.definitions();
|
||||
assert_eq!(defs.len(), 4);
|
||||
assert!(defs.iter().any(|d| d.name == "read_file"));
|
||||
}
|
||||
}
|
||||
92
src/tools/read_file.rs
Normal file
92
src/tools/read_file.rs
Normal file
|
|
@ -0,0 +1,92 @@
|
|||
//! `read_file` tool: reads the contents of a file within the working directory.
|
||||
|
||||
use std::path::Path;
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use super::{RiskLevel, Tool, ToolError, ToolOutput, validate_path};
|
||||
|
||||
/// Reads file contents. Auto-approved (read-only).
|
||||
pub struct ReadFile;
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for ReadFile {
|
||||
fn name(&self) -> &str {
|
||||
"read_file"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Read the contents of a file. The path is relative to the working directory."
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> serde_json::Value {
|
||||
serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "The file path to read, relative to the working directory."
|
||||
}
|
||||
},
|
||||
"required": ["path"]
|
||||
})
|
||||
}
|
||||
|
||||
fn risk_level(&self) -> RiskLevel {
|
||||
RiskLevel::AutoApprove
|
||||
}
|
||||
|
||||
async fn execute(
|
||||
&self,
|
||||
input: &serde_json::Value,
|
||||
working_dir: &Path,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let path_str = input["path"]
|
||||
.as_str()
|
||||
.ok_or_else(|| ToolError::InvalidInput("missing 'path' string".to_string()))?;
|
||||
|
||||
let canonical = validate_path(working_dir, path_str)?;
|
||||
let content = tokio::fs::read_to_string(&canonical).await?;
|
||||
|
||||
Ok(ToolOutput {
|
||||
content,
|
||||
is_error: false,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::fs;
|
||||
use tempfile::TempDir;
|
||||
|
||||
#[tokio::test]
|
||||
async fn read_existing_file() {
|
||||
let dir = TempDir::new().unwrap();
|
||||
fs::write(dir.path().join("hello.txt"), "world").unwrap();
|
||||
let tool = ReadFile;
|
||||
let input = serde_json::json!({"path": "hello.txt"});
|
||||
let out = tool.execute(&input, dir.path()).await.unwrap();
|
||||
assert_eq!(out.content, "world");
|
||||
assert!(!out.is_error);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn read_nonexistent_file_errors() {
|
||||
let dir = TempDir::new().unwrap();
|
||||
let tool = ReadFile;
|
||||
let input = serde_json::json!({"path": "nope.txt"});
|
||||
let result = tool.execute(&input, dir.path()).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn read_path_traversal_rejected() {
|
||||
let dir = TempDir::new().unwrap();
|
||||
let tool = ReadFile;
|
||||
let input = serde_json::json!({"path": "../../../etc/passwd"});
|
||||
let result = tool.execute(&input, dir.path()).await;
|
||||
assert!(matches!(result, Err(ToolError::PathEscape(_))));
|
||||
}
|
||||
}
|
||||
108
src/tools/shell_exec.rs
Normal file
108
src/tools/shell_exec.rs
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
//! `shell_exec` tool: runs a shell command within the working directory.
|
||||
|
||||
use std::path::Path;
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use super::{RiskLevel, Tool, ToolError, ToolOutput};
|
||||
|
||||
/// Executes a shell command. Requires user approval.
|
||||
pub struct ShellExec;
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for ShellExec {
|
||||
fn name(&self) -> &str {
|
||||
"shell_exec"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Execute a shell command in the working directory. Returns stdout and stderr."
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> serde_json::Value {
|
||||
serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "The shell command to execute."
|
||||
}
|
||||
},
|
||||
"required": ["command"]
|
||||
})
|
||||
}
|
||||
|
||||
fn risk_level(&self) -> RiskLevel {
|
||||
RiskLevel::RequiresApproval
|
||||
}
|
||||
|
||||
async fn execute(
|
||||
&self,
|
||||
input: &serde_json::Value,
|
||||
working_dir: &Path,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let command = input["command"]
|
||||
.as_str()
|
||||
.ok_or_else(|| ToolError::InvalidInput("missing 'command' string".to_string()))?;
|
||||
|
||||
let output = tokio::process::Command::new("sh")
|
||||
.arg("-c")
|
||||
.arg(command)
|
||||
.current_dir(working_dir)
|
||||
.output()
|
||||
.await?;
|
||||
|
||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
|
||||
let mut content = String::new();
|
||||
if !stdout.is_empty() {
|
||||
content.push_str(&stdout);
|
||||
}
|
||||
if !stderr.is_empty() {
|
||||
if !content.is_empty() {
|
||||
content.push('\n');
|
||||
}
|
||||
content.push_str("[stderr]\n");
|
||||
content.push_str(&stderr);
|
||||
}
|
||||
if content.is_empty() {
|
||||
content.push_str("(no output)");
|
||||
}
|
||||
|
||||
let is_error = !output.status.success();
|
||||
if is_error {
|
||||
content.push_str(&format!(
|
||||
"\n[exit code: {}]",
|
||||
output.status.code().unwrap_or(-1)
|
||||
));
|
||||
}
|
||||
|
||||
Ok(ToolOutput { content, is_error })
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::TempDir;
|
||||
|
||||
#[tokio::test]
|
||||
async fn shell_exec_echo() {
|
||||
let dir = TempDir::new().unwrap();
|
||||
let tool = ShellExec;
|
||||
let input = serde_json::json!({"command": "echo hello"});
|
||||
let out = tool.execute(&input, dir.path()).await.unwrap();
|
||||
assert!(out.content.contains("hello"));
|
||||
assert!(!out.is_error);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn shell_exec_failing_command() {
|
||||
let dir = TempDir::new().unwrap();
|
||||
let tool = ShellExec;
|
||||
let input = serde_json::json!({"command": "false"});
|
||||
let out = tool.execute(&input, dir.path()).await.unwrap();
|
||||
assert!(out.is_error);
|
||||
}
|
||||
}
|
||||
98
src/tools/write_file.rs
Normal file
98
src/tools/write_file.rs
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
//! `write_file` tool: writes content to a file within the working directory.
|
||||
|
||||
use std::path::Path;
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use super::{RiskLevel, Tool, ToolError, ToolOutput, validate_path};
|
||||
|
||||
/// Writes content to a file. Requires user approval.
|
||||
pub struct WriteFile;
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for WriteFile {
|
||||
fn name(&self) -> &str {
|
||||
"write_file"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Write content to a file. Creates the file if it doesn't exist, overwrites if it does. The path is relative to the working directory."
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> serde_json::Value {
|
||||
serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "The file path to write to, relative to the working directory."
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "The content to write to the file."
|
||||
}
|
||||
},
|
||||
"required": ["path", "content"]
|
||||
})
|
||||
}
|
||||
|
||||
fn risk_level(&self) -> RiskLevel {
|
||||
RiskLevel::RequiresApproval
|
||||
}
|
||||
|
||||
async fn execute(
|
||||
&self,
|
||||
input: &serde_json::Value,
|
||||
working_dir: &Path,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let path_str = input["path"]
|
||||
.as_str()
|
||||
.ok_or_else(|| ToolError::InvalidInput("missing 'path' string".to_string()))?;
|
||||
let content = input["content"]
|
||||
.as_str()
|
||||
.ok_or_else(|| ToolError::InvalidInput("missing 'content' string".to_string()))?;
|
||||
|
||||
let canonical = validate_path(working_dir, path_str)?;
|
||||
|
||||
// Create parent directories if needed.
|
||||
if let Some(parent) = canonical.parent() {
|
||||
tokio::fs::create_dir_all(parent).await?;
|
||||
}
|
||||
|
||||
tokio::fs::write(&canonical, content).await?;
|
||||
|
||||
Ok(ToolOutput {
|
||||
content: format!("Wrote {} bytes to {path_str}", content.len()),
|
||||
is_error: false,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::fs;
|
||||
use tempfile::TempDir;
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_creates_file() {
|
||||
let dir = TempDir::new().unwrap();
|
||||
let tool = WriteFile;
|
||||
let input = serde_json::json!({"path": "out.txt", "content": "hello"});
|
||||
let out = tool.execute(&input, dir.path()).await.unwrap();
|
||||
assert!(!out.is_error);
|
||||
assert_eq!(
|
||||
fs::read_to_string(dir.path().join("out.txt")).unwrap(),
|
||||
"hello"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_path_traversal_rejected() {
|
||||
let dir = TempDir::new().unwrap();
|
||||
let tool = WriteFile;
|
||||
let input = serde_json::json!({"path": "../../evil.txt", "content": "bad"});
|
||||
let result = tool.execute(&input, dir.path()).await;
|
||||
assert!(matches!(result, Err(ToolError::PathEscape(_))));
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue