feat: load templates from cache and project directories (#2126)
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
use minijinja::{context, path_loader, Environment};
|
use minijinja::{context, Environment};
|
||||||
use mlua::prelude::*;
|
use mlua::prelude::*;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::path::Path;
|
||||||
use std::sync::{Arc, Mutex};
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
struct State<'a> {
|
struct State<'a> {
|
||||||
@@ -79,11 +80,39 @@ fn render(state: &State, template: &str, context: TemplateContext) -> LuaResult<
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn initialize(state: &State, directory: String) {
|
fn initialize(state: &State, cache_directory: String, project_directory: String) {
|
||||||
let mut environment_mutex = state.environment.lock().unwrap();
|
let mut environment_mutex = state.environment.lock().unwrap();
|
||||||
// add directory as a base path for base directory template path
|
|
||||||
let mut env = Environment::new();
|
let mut env = Environment::new();
|
||||||
env.set_loader(path_loader(directory));
|
|
||||||
|
// Create a custom loader that searches both cache and project directories
|
||||||
|
let cache_dir = cache_directory.clone();
|
||||||
|
let project_dir = project_directory.clone();
|
||||||
|
|
||||||
|
env.set_loader(
|
||||||
|
move |name: &str| -> Result<Option<String>, minijinja::Error> {
|
||||||
|
// First try the cache directory (for built-in templates)
|
||||||
|
let cache_path = Path::new(&cache_dir).join(name);
|
||||||
|
if cache_path.exists() {
|
||||||
|
match std::fs::read_to_string(&cache_path) {
|
||||||
|
Ok(content) => return Ok(Some(content)),
|
||||||
|
Err(_) => {} // Continue to try project directory
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Then try the project directory (for custom includes)
|
||||||
|
let project_path = Path::new(&project_dir).join(name);
|
||||||
|
if project_path.exists() {
|
||||||
|
match std::fs::read_to_string(&project_path) {
|
||||||
|
Ok(content) => return Ok(Some(content)),
|
||||||
|
Err(_) => {} // File not found or read error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Template not found in either directory
|
||||||
|
Ok(None)
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
*environment_mutex = Some(env);
|
*environment_mutex = Some(env);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -96,10 +125,12 @@ fn avante_templates(lua: &Lua) -> LuaResult<LuaTable> {
|
|||||||
let exports = lua.create_table()?;
|
let exports = lua.create_table()?;
|
||||||
exports.set(
|
exports.set(
|
||||||
"initialize",
|
"initialize",
|
||||||
lua.create_function(move |_, model: String| {
|
lua.create_function(
|
||||||
initialize(&state, model);
|
move |_, (cache_directory, project_directory): (String, String)| {
|
||||||
Ok(())
|
initialize(&state, cache_directory, project_directory);
|
||||||
})?,
|
Ok(())
|
||||||
|
},
|
||||||
|
)?,
|
||||||
)?;
|
)?;
|
||||||
exports.set(
|
exports.set(
|
||||||
"render",
|
"render",
|
||||||
|
|||||||
@@ -145,7 +145,7 @@ function M.generate_prompts(opts)
|
|||||||
end
|
end
|
||||||
|
|
||||||
local project_root = Utils.root.get()
|
local project_root = Utils.root.get()
|
||||||
Path.prompts.initialize(Path.prompts.get_templates_dir(project_root))
|
Path.prompts.initialize(Path.prompts.get_templates_dir(project_root), project_root)
|
||||||
|
|
||||||
local tool_id_to_tool_name = {}
|
local tool_id_to_tool_name = {}
|
||||||
local tool_id_to_path = {}
|
local tool_id_to_path = {}
|
||||||
|
|||||||
@@ -167,7 +167,7 @@ function Prompt.get_custom_prompts_filepath(mode) return string.format("custom.%
|
|||||||
function Prompt.get_builtin_prompts_filepath(mode) return string.format("%s.avanterules", mode) end
|
function Prompt.get_builtin_prompts_filepath(mode) return string.format("%s.avanterules", mode) end
|
||||||
|
|
||||||
---@class AvanteTemplates
|
---@class AvanteTemplates
|
||||||
---@field initialize fun(directory: string): nil
|
---@field initialize fun(cache_directory: string, project_directory: string): nil
|
||||||
---@field render fun(template: string, context: AvanteTemplateOptions): string
|
---@field render fun(template: string, context: AvanteTemplateOptions): string
|
||||||
local _templates_lib = nil
|
local _templates_lib = nil
|
||||||
|
|
||||||
@@ -244,7 +244,9 @@ function Prompt.render_mode(mode, opts)
|
|||||||
return _templates_lib.render(filepath, opts)
|
return _templates_lib.render(filepath, opts)
|
||||||
end
|
end
|
||||||
|
|
||||||
function Prompt.initialize(directory) _templates_lib.initialize(directory) end
|
function Prompt.initialize(cache_directory, project_directory)
|
||||||
|
_templates_lib.initialize(cache_directory, project_directory)
|
||||||
|
end
|
||||||
|
|
||||||
P.prompts = Prompt
|
P.prompts = Prompt
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user