feat: implement a more flexible custom prompts solution (#1390)

This commit is contained in:
yetone
2025-02-25 16:08:16 +08:00
committed by GitHub
parent 481a44f041
commit 2b3a41e811
5 changed files with 59 additions and 33 deletions

View File

@@ -44,7 +44,7 @@ function M.generate_prompts(opts)
end
local project_root = Utils.root.get()
Path.prompts.initialize(Path.prompts.get(project_root))
Path.prompts.initialize(Path.prompts.get_templates_dir(project_root))
local system_info = Utils.get_system_info()

View File

@@ -79,20 +79,29 @@ local Prompt = {}
-- Given a mode, return the file name for the custom prompt.
---@param mode AvanteLlmMode
function Prompt.get_mode_file(mode) return string.format("custom.%s.avanterules", mode) end
---@return string
function Prompt.get_custom_prompts_filepath(mode) return string.format("custom.%s.avanterules", mode) end
function Prompt.get_builtin_prompts_filepath(mode) return string.format("%s.avanterules", mode) end
---@class AvanteTemplates
---@field initialize fun(directory: string): nil
---@field render fun(template: string, context: AvanteTemplateOptions): string
local templates = nil
local _templates_lib = nil
Prompt.templates = { planning = nil, editing = nil, suggesting = nil }
Prompt.custom_modes = {
planning = true,
editing = true,
suggesting = true,
["cursor-planning"] = true,
["cursor-applying"] = true,
}
Prompt.custom_prompts_contents = {}
-- We need to do this beacuse the prompt template engine requires a given directory to load all required files.
-- PERF: Hmm instead of copy to cache, we can also load in globals context, but it requires some work on bindings. (eh maybe?)
---@param project_root string
---@return string the resulted cache_directory to be loaded with avante_templates
function Prompt.get(project_root)
---@return string templates_dir
function Prompt.get_templates_dir(project_root)
if not P.available() then error("Make sure to build avante (missing avante_templates)", 2) end
-- get root directory of given bufnr
@@ -107,16 +116,13 @@ function Prompt.get(project_root)
for _, entry in ipairs(scanner) do
local file = Path:new(entry)
if file:is_file() then
if not entry:match("%.avanterules$") then goto continue end
if entry:find("planning") and Prompt.templates.planning == nil then
Utils.info(string.format("Using %s as planning system prompt", entry))
Prompt.templates.planning = file:read()
elseif entry:find("editing") and Prompt.templates.editing == nil then
Utils.info(string.format("Using %s as editing system prompt", entry))
Prompt.templates.editing = file:read()
elseif entry:find("suggesting") and Prompt.templates.suggesting == nil then
Utils.info(string.format("Using %s as suggesting system prompt", entry))
Prompt.templates.suggesting = file:read()
local pieces = vim.split(entry, "/")
local piece = pieces[#pieces]
local mode = piece:match("([^.]+)%.avanterules$")
if not mode or not Prompt.custom_modes[mode] then goto continue end
if Prompt.custom_prompts_contents[mode] == nil then
Utils.info(string.format("Using %s as %s system prompt", entry, mode))
Prompt.custom_prompts_contents[mode] = file:read()
end
end
::continue::
@@ -125,29 +131,41 @@ function Prompt.get(project_root)
Path:new(debug.getinfo(1).source:match("@?(.*/)"):gsub("/lua/avante/path.lua$", "") .. "templates")
:copy({ destination = cache_prompt_dir, recursive = true })
vim.iter(Prompt.templates):filter(function(_, v) return v ~= nil end):each(function(k, v)
local f = cache_prompt_dir:joinpath(Prompt.get_mode_file(k))
f:write(v, "w")
vim.iter(Prompt.custom_prompts_contents):filter(function(_, v) return v ~= nil end):each(function(k, v)
local orig_file = cache_prompt_dir:joinpath(Prompt.get_builtin_prompts_filepath(k))
local orig_content = orig_file:read()
local f = cache_prompt_dir:joinpath(Prompt.get_custom_prompts_filepath(k))
f:write(orig_content, "w")
f:write("{% block custom_prompt -%}\n", "a")
f:write(v, "a")
f:write("\n{%- endblock %}", "a")
end)
return cache_prompt_dir:absolute()
local dir = cache_prompt_dir:absolute()
Utils.debug("Prompt cache directory:", dir)
return dir
end
---@param mode AvanteLlmMode
function Prompt.get_file(mode)
if Prompt.templates[mode] ~= nil then return Prompt.get_mode_file(mode) end
return string.format("%s.avanterules", mode)
---@return string
function Prompt.get_filepath(mode)
if Prompt.custom_prompts_contents[mode] ~= nil then return Prompt.get_custom_prompts_filepath(mode) end
return Prompt.get_builtin_prompts_filepath(mode)
end
---@param path string
---@param opts AvanteTemplateOptions
function Prompt.render_file(path, opts) return templates.render(path, opts) end
function Prompt.render_file(path, opts) return _templates_lib.render(path, opts) end
---@param mode AvanteLlmMode
---@param opts AvanteTemplateOptions
function Prompt.render_mode(mode, opts) return templates.render(Prompt.get_file(mode), opts) end
function Prompt.render_mode(mode, opts)
local filepath = Prompt.get_filepath(mode)
Utils.debug("Prompt filepath:", filepath)
return _templates_lib.render(filepath, opts)
end
function Prompt.initialize(directory) templates.initialize(directory) end
function Prompt.initialize(directory) _templates_lib.initialize(directory) end
P.prompts = Prompt
@@ -184,14 +202,14 @@ P.repo_map = RepoMap
---@return AvanteTemplates|nil
function P._init_templates_lib()
if templates ~= nil then return templates end
if _templates_lib ~= nil then return _templates_lib end
local ok, module = pcall(require, "avante_templates")
---@cast module AvanteTemplates
---@cast ok boolean
if not ok then return nil end
templates = module
_templates_lib = module
return templates
return _templates_lib
end
function P.setup()

View File

@@ -20,3 +20,6 @@ Use the appropriate shell based on the user's system info:
{% block extra_prompt %}
{% endblock %}
{% block custom_prompt %}
{% endblock %}

View File

@@ -1 +1,4 @@
You are a coding assistant that helps merge code updates, ensuring every modification is fully integrated.
{% block custom_prompt %}
{% endblock %}