feat: enable override prompt templates directory (#2387)

This commit is contained in:
doodleEsc
2025-07-02 14:43:17 +08:00
committed by GitHub
parent 6bbf3d2004
commit 4329ed79f4
4 changed files with 100 additions and 6 deletions

View File

@@ -38,8 +38,10 @@ M._defaults = {
-- For most providers that we support we will determine this automatically.
-- If you wish to use a given implementation, then you can override it here.
tokenizer = "tiktoken",
---@type string | (fun(): string) | nil
---@type string | fun(): string | nil
system_prompt = nil,
---@type string | fun(): string | nil
override_prompt_dir = nil,
rules = {
project_dir = nil, ---@type string | nil (could be relative dirpath)
global_dir = nil, ---@type string | nil (absolute dirpath)

View File

@@ -7,6 +7,7 @@ local Config = require("avante.config")
---@class avante.Path
---@field history_path Path
---@field cache_path Path
---@field data_path Path
local P = {}
---@param bufnr integer | nil
@@ -243,6 +244,45 @@ function Prompt.get_templates_dir(project_root)
end
end
-- Check for override prompt
local override_prompt_dir = Config.override_prompt_dir
if override_prompt_dir then
-- Handle the case where override_prompt_dir is a function
if type(override_prompt_dir) == "function" then
local ok, result = pcall(override_prompt_dir)
if ok and result then override_prompt_dir = result end
end
if override_prompt_dir then
local user_template_path = Path:new(override_prompt_dir)
if user_template_path:exists() then
local user_scanner = Scan.scan_dir(user_template_path:absolute(), { depth = 1, add_dirs = false })
for _, entry in ipairs(user_scanner) do
local file = Path:new(entry)
if file:is_file() then
local pieces = vim.split(entry, "/")
local piece = pieces[#pieces]
if piece == "base.avanterules" then
local content = file:read()
if not content:match("{%% block extra_prompt %%}[%s,\\n]*{%% endblock %%}") then
file:write("{% block extra_prompt %}\n", "a")
file:write("{% endblock %}\n", "a")
end
if not content:match("{%% block custom_prompt %%}[%s,\\n]*{%% endblock %%}") then
file:write("{% block custom_prompt %}\n", "a")
file:write("{% endblock %}", "a")
end
end
file:copy({ destination = cache_prompt_dir:joinpath(piece) })
end
end
end
end
end
if Config.rules.project_dir then
local project_rules_path = Path:new(Config.rules.project_dir)
if not project_rules_path:is_absolute() then project_rules_path = directory:joinpath(project_rules_path) end
@@ -251,8 +291,12 @@ function Prompt.get_templates_dir(project_root)
find_rules(Config.rules.global_dir)
find_rules(directory:absolute())
Path:new(debug.getinfo(1).source:match("@?(.*/)"):gsub("/lua/avante/path.lua$", "") .. "templates")
:copy({ destination = cache_prompt_dir, recursive = true })
-- Copy built-in templates to cache directory (only if not overridden by user templates)
Path:new(debug.getinfo(1).source:match("@?(.*/)"):gsub("/lua/avante/path.lua$", "") .. "templates"):copy({
destination = cache_prompt_dir,
recursive = true,
override = false,
})
vim.iter(Prompt.custom_prompts_contents):filter(function(_, v) return v ~= nil end):each(function(k, v)
local orig_file = cache_prompt_dir:joinpath(Prompt.get_builtin_prompts_filepath(k))