feat: memory 🧠 (#793)
This commit is contained in:
@@ -28,7 +28,7 @@ local group = api.nvim_create_augroup("avante_llm", { clear = true })
|
||||
---@field file_content string
|
||||
---@field selected_code string | nil
|
||||
---@field project_context string | nil
|
||||
---@field memory_context string | nil
|
||||
---@field history_messages AvanteLLMMessage[]
|
||||
---
|
||||
---@class StreamOptions: TemplateOptions
|
||||
---@field ask boolean
|
||||
@@ -44,10 +44,12 @@ M.stream = function(opts)
|
||||
local mode = opts.mode or "planning"
|
||||
---@type AvanteProviderFunctor
|
||||
local Provider = opts.provider or P[Config.provider]
|
||||
local _, body_opts = P.parse_config(Provider)
|
||||
local max_tokens = body_opts.max_tokens or 4096
|
||||
|
||||
-- Check if the instructions contains an image path
|
||||
local image_paths = {}
|
||||
local original_instructions = opts.instructions
|
||||
local instructions = opts.instructions
|
||||
if opts.instructions:match("image: ") then
|
||||
local lines = vim.split(opts.instructions, "\n")
|
||||
for i, line in ipairs(lines) do
|
||||
@@ -57,7 +59,7 @@ M.stream = function(opts)
|
||||
table.remove(lines, i)
|
||||
end
|
||||
end
|
||||
original_instructions = table.concat(lines, "\n")
|
||||
instructions = table.concat(lines, "\n")
|
||||
end
|
||||
|
||||
Path.prompts.initialize(Path.prompts.get(opts.bufnr))
|
||||
@@ -67,29 +69,61 @@ M.stream = function(opts)
|
||||
local template_opts = {
|
||||
use_xml_format = Provider.use_xml_format,
|
||||
ask = opts.ask, -- TODO: add mode without ask instruction
|
||||
question = original_instructions,
|
||||
code_lang = opts.code_lang,
|
||||
filepath = filepath,
|
||||
file_content = opts.file_content,
|
||||
selected_code = opts.selected_code,
|
||||
project_context = opts.project_context,
|
||||
memory_context = opts.memory_context,
|
||||
}
|
||||
|
||||
local user_prompts = vim
|
||||
.iter({
|
||||
Path.prompts.render_file("_project.avanterules", template_opts),
|
||||
Path.prompts.render_file("_memory.avanterules", template_opts),
|
||||
Path.prompts.render_file("_context.avanterules", template_opts),
|
||||
Path.prompts.render_mode(mode, template_opts),
|
||||
})
|
||||
:filter(function(k) return k ~= "" end)
|
||||
:totable()
|
||||
local system_prompt = Path.prompts.render_mode(mode, template_opts)
|
||||
|
||||
---@type AvanteLLMMessage[]
|
||||
local messages = {}
|
||||
|
||||
if opts.project_context ~= nil and opts.project_context ~= "" and opts.project_context ~= "null" then
|
||||
local project_context = Path.prompts.render_file("_project.avanterules", template_opts)
|
||||
if project_context ~= "" then table.insert(messages, { role = "user", content = project_context }) end
|
||||
end
|
||||
|
||||
local code_context = Path.prompts.render_file("_context.avanterules", template_opts)
|
||||
if code_context ~= "" then table.insert(messages, { role = "user", content = code_context }) end
|
||||
|
||||
if opts.use_xml_format then
|
||||
table.insert(messages, { role = "user", content = string.format("<question>%s</question>", instructions) })
|
||||
else
|
||||
table.insert(messages, { role = "user", content = string.format("QUESTION:\n%s", instructions) })
|
||||
end
|
||||
|
||||
local remaining_tokens = max_tokens - Utils.tokens.calculate_tokens(system_prompt)
|
||||
|
||||
for _, message in ipairs(messages) do
|
||||
remaining_tokens = remaining_tokens - Utils.tokens.calculate_tokens(message.content)
|
||||
end
|
||||
|
||||
if opts.history_messages then
|
||||
if Config.history.max_tokens > 0 then remaining_tokens = math.min(Config.history.max_tokens, remaining_tokens) end
|
||||
-- Traverse the history in reverse, keeping only the latest history until the remaining tokens are exhausted and the first message role is "user"
|
||||
local history_messages = {}
|
||||
for i = #opts.history_messages, 1, -1 do
|
||||
local message = opts.history_messages[i]
|
||||
local tokens = Utils.tokens.calculate_tokens(message.content)
|
||||
remaining_tokens = remaining_tokens - tokens
|
||||
if remaining_tokens > 0 then
|
||||
table.insert(history_messages, message)
|
||||
else
|
||||
break
|
||||
end
|
||||
end
|
||||
if #history_messages > 0 and history_messages[1].role == "assistant" then table.remove(history_messages, 1) end
|
||||
-- prepend the history messages to the messages table
|
||||
vim.iter(history_messages):each(function(msg) table.insert(messages, 1, msg) end)
|
||||
end
|
||||
|
||||
---@type AvantePromptOptions
|
||||
local code_opts = {
|
||||
system_prompt = Config.system_prompt,
|
||||
user_prompts = user_prompts,
|
||||
system_prompt = system_prompt,
|
||||
messages = messages,
|
||||
image_paths = image_paths,
|
||||
}
|
||||
|
||||
@@ -164,7 +198,7 @@ M.stream = function(opts)
|
||||
on_error = function(err)
|
||||
if err.exit == 23 then
|
||||
local xdg_runtime_dir = os.getenv("XDG_RUNTIME_DIR")
|
||||
if fn.isdirectory(xdg_runtime_dir) == 0 then
|
||||
if not xdg_runtime_dir or fn.isdirectory(xdg_runtime_dir) == 0 then
|
||||
Utils.error(
|
||||
"$XDG_RUNTIME_DIR="
|
||||
.. xdg_runtime_dir
|
||||
|
||||
Reference in New Issue
Block a user