skate/src/tools/write_file.rs
2026-03-01 18:53:02 -08:00

98 lines
2.9 KiB
Rust

//! `write_file` tool: writes content to a file within the working directory.
use std::path::Path;
use async_trait::async_trait;
use super::{RiskLevel, Tool, ToolError, ToolOutput, validate_path};
/// Tool that writes UTF-8 text to a file inside the working directory.
///
/// It overwrites existing files and creates missing parent directories. For that
/// reason its [`Tool::risk_level`] is [`RiskLevel::RequiresApproval`]: it
/// requires user approval.
#[derive(Debug, Clone, Copy, Default)]
pub struct WriteFile;
#[async_trait]
impl Tool for WriteFile {
    /// Identifier under which this tool is registered and invoked.
    fn name(&self) -> &str {
        "write_file"
    }

    /// Summary of the tool's behaviour, shown to whoever selects tools.
    fn description(&self) -> &str {
        "Write content to a file. Creates the file if it doesn't exist, overwrites if it does. The path is relative to the working directory."
    }

    /// JSON Schema for the arguments: an object with two required strings,
    /// `path` and `content`.
    fn input_schema(&self) -> serde_json::Value {
        let path_property = serde_json::json!({
            "type": "string",
            "description": "The file path to write to, relative to the working directory."
        });
        let content_property = serde_json::json!({
            "type": "string",
            "description": "The content to write to the file."
        });

        serde_json::json!({
            "type": "object",
            "properties": {
                "path": path_property,
                "content": content_property
            },
            "required": ["path", "content"]
        })
    }

    /// Writing mutates the workspace, so every call is gated on approval.
    fn risk_level(&self) -> RiskLevel {
        RiskLevel::RequiresApproval
    }

    /// Writes `content` to `path`, resolved via [`validate_path`] against
    /// `working_dir`, and reports how many bytes were written.
    ///
    /// Missing parent directories are created first. An existing file has
    /// its contents replaced entirely.
    ///
    /// # Errors
    ///
    /// - [`ToolError::InvalidInput`] when `path` or `content` is absent or is
    ///   not a JSON string.
    /// - Any error returned by [`validate_path`]. The module tests expect a
    ///   `../..` path to produce [`ToolError::PathEscape`].
    /// - I/O failures from creating directories or writing the file. These are
    ///   converted into [`ToolError`] by `?`.
    async fn execute(
        &self,
        input: &serde_json::Value,
        working_dir: &Path,
    ) -> Result<ToolOutput, ToolError> {
        // `Value::get` yields `None` both for a missing key and for a
        // non-object `input`; `as_str` then rejects non-string values.
        let Some(path_str) = input.get("path").and_then(serde_json::Value::as_str) else {
            return Err(ToolError::InvalidInput("missing 'path' string".to_string()));
        };
        let Some(text) = input.get("content").and_then(serde_json::Value::as_str) else {
            return Err(ToolError::InvalidInput("missing 'content' string".to_string()));
        };

        let target = validate_path(working_dir, path_str)?;

        // Make sure the directory chain exists so the write below can succeed.
        if let Some(dir) = target.parent() {
            tokio::fs::create_dir_all(dir).await?;
        }
        tokio::fs::write(&target, text).await?;

        // `str::len` is the UTF-8 byte length, i.e. exactly what was written.
        let byte_count = text.len();
        Ok(ToolOutput {
            is_error: false,
            content: format!("Wrote {byte_count} bytes to {path_str}"),
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    #[tokio::test]
    async fn write_creates_file() {
        let dir = TempDir::new().unwrap();
        let tool = WriteFile;
        let input = serde_json::json!({"path": "out.txt", "content": "hello"});
        let out = tool.execute(&input, dir.path()).await.unwrap();
        assert!(!out.is_error);
        assert_eq!(
            fs::read_to_string(dir.path().join("out.txt")).unwrap(),
            "hello"
        );
    }

    /// The description promises "overwrites if it does"; pin that down.
    #[tokio::test]
    async fn write_overwrites_existing_file() {
        let dir = TempDir::new().unwrap();
        fs::write(dir.path().join("out.txt"), "old contents that are longer").unwrap();
        let tool = WriteFile;
        let input = serde_json::json!({"path": "out.txt", "content": "new"});
        let out = tool.execute(&input, dir.path()).await.unwrap();
        assert!(!out.is_error);
        assert_eq!(
            fs::read_to_string(dir.path().join("out.txt")).unwrap(),
            "new"
        );
    }

    /// The summary counts UTF-8 bytes, not characters.
    #[tokio::test]
    async fn write_reports_utf8_byte_count() {
        let dir = TempDir::new().unwrap();
        let tool = WriteFile;
        // "héllo" is 5 chars but 6 bytes in UTF-8.
        let input = serde_json::json!({"path": "out.txt", "content": "héllo"});
        let out = tool.execute(&input, dir.path()).await.unwrap();
        assert!(!out.is_error);
        assert_eq!(out.content, "Wrote 6 bytes to out.txt");
    }

    #[tokio::test]
    async fn write_missing_path_rejected() {
        let dir = TempDir::new().unwrap();
        let tool = WriteFile;
        let input = serde_json::json!({"content": "hello"});
        let result = tool.execute(&input, dir.path()).await;
        assert!(matches!(result, Err(ToolError::InvalidInput(_))));
    }

    /// A non-string `content` is rejected before anything touches disk.
    #[tokio::test]
    async fn write_non_string_content_rejected() {
        let dir = TempDir::new().unwrap();
        let tool = WriteFile;
        let input = serde_json::json!({"path": "out.txt", "content": 42});
        let result = tool.execute(&input, dir.path()).await;
        assert!(matches!(result, Err(ToolError::InvalidInput(_))));
        assert!(!dir.path().join("out.txt").exists());
    }

    #[tokio::test]
    async fn write_path_traversal_rejected() {
        let dir = TempDir::new().unwrap();
        let tool = WriteFile;
        let input = serde_json::json!({"path": "../../evil.txt", "content": "bad"});
        let result = tool.execute(&input, dir.path()).await;
        assert!(matches!(result, Err(ToolError::PathEscape(_))));
    }

    /// Metadata that callers rely on: the approval gate and required args.
    #[test]
    fn metadata_requires_approval_and_both_fields() {
        let tool = WriteFile;
        assert_eq!(tool.name(), "write_file");
        assert!(matches!(tool.risk_level(), RiskLevel::RequiresApproval));
        assert_eq!(
            tool.input_schema()["required"],
            serde_json::json!(["path", "content"])
        );
    }
}