diff --git a/crates/avante-templates/src/lib.rs b/crates/avante-templates/src/lib.rs index ee62381..5387362 100644 --- a/crates/avante-templates/src/lib.rs +++ b/crates/avante-templates/src/lib.rs @@ -1,6 +1,7 @@ -use minijinja::{context, path_loader, Environment}; +use minijinja::{context, Environment}; use mlua::prelude::*; use serde::{Deserialize, Serialize}; +use std::path::Path; use std::sync::{Arc, Mutex}; struct State<'a> { @@ -79,11 +80,39 @@ fn render(state: &State, template: &str, context: TemplateContext) -> LuaResult< } } -fn initialize(state: &State, directory: String) { +fn initialize(state: &State, cache_directory: String, project_directory: String) { let mut environment_mutex = state.environment.lock().unwrap(); - // add directory as a base path for base directory template path let mut env = Environment::new(); - env.set_loader(path_loader(directory)); + + // Create a custom loader that searches both cache and project directories + let cache_dir = cache_directory.clone(); + let project_dir = project_directory.clone(); + + env.set_loader( + move |name: &str| -> Result, minijinja::Error> { + // First try the cache directory (for built-in templates) + let cache_path = Path::new(&cache_dir).join(name); + if cache_path.exists() { + match std::fs::read_to_string(&cache_path) { + Ok(content) => return Ok(Some(content)), + Err(_) => {} // Continue to try project directory + } + } + + // Then try the project directory (for custom includes) + let project_path = Path::new(&project_dir).join(name); + if project_path.exists() { + match std::fs::read_to_string(&project_path) { + Ok(content) => return Ok(Some(content)), + Err(_) => {} // File not found or read error + } + } + + // Template not found in either directory + Ok(None) + }, + ); + *environment_mutex = Some(env); } @@ -96,10 +125,12 @@ fn avante_templates(lua: &Lua) -> LuaResult { let exports = lua.create_table()?; exports.set( "initialize", - lua.create_function(move |_, model: String| { - initialize(&state, model); - Ok(()) - })?, + lua.create_function( + move |_, (cache_directory, project_directory): (String, String)| { + initialize(&state, cache_directory, project_directory); + Ok(()) + }, + )?, )?; exports.set( "render", diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index 28e867e..bb16a09 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -145,7 +145,7 @@ function M.generate_prompts(opts) end local project_root = Utils.root.get() - Path.prompts.initialize(Path.prompts.get_templates_dir(project_root)) + Path.prompts.initialize(Path.prompts.get_templates_dir(project_root), project_root) local tool_id_to_tool_name = {} local tool_id_to_path = {} diff --git a/lua/avante/path.lua b/lua/avante/path.lua index 4e1d053..b573540 100644 --- a/lua/avante/path.lua +++ b/lua/avante/path.lua @@ -167,7 +167,7 @@ function Prompt.get_custom_prompts_filepath(mode) return string.format("custom.% function Prompt.get_builtin_prompts_filepath(mode) return string.format("%s.avanterules", mode) end ---@class AvanteTemplates ----@field initialize fun(directory: string): nil +---@field initialize fun(cache_directory: string, project_directory: string): nil ---@field render fun(template: string, context: AvanteTemplateOptions): string local _templates_lib = nil @@ -244,7 +244,9 @@ function Prompt.render_mode(mode, opts) return _templates_lib.render(filepath, opts) end -function Prompt.initialize(directory) _templates_lib.initialize(directory) end +function Prompt.initialize(cache_directory, project_directory) + _templates_lib.initialize(cache_directory, project_directory) +end P.prompts = Prompt