feat: memory 🧠 (#793)

This commit is contained in:
yetone
2024-11-04 16:20:28 +08:00
committed by GitHub
parent 579ef12f76
commit 1e8abbf798
13 changed files with 197 additions and 154 deletions

View File

@@ -28,7 +28,7 @@ local group = api.nvim_create_augroup("avante_llm", { clear = true })
---@field file_content string
---@field selected_code string | nil
---@field project_context string | nil
---@field memory_context string | nil
---@field history_messages AvanteLLMMessage[]
---
---@class StreamOptions: TemplateOptions
---@field ask boolean
@@ -44,10 +44,12 @@ M.stream = function(opts)
local mode = opts.mode or "planning"
---@type AvanteProviderFunctor
local Provider = opts.provider or P[Config.provider]
local _, body_opts = P.parse_config(Provider)
local max_tokens = body_opts.max_tokens or 4096
-- Check if the instructions contains an image path
local image_paths = {}
local original_instructions = opts.instructions
local instructions = opts.instructions
if opts.instructions:match("image: ") then
local lines = vim.split(opts.instructions, "\n")
for i, line in ipairs(lines) do
@@ -57,7 +59,7 @@ M.stream = function(opts)
table.remove(lines, i)
end
end
original_instructions = table.concat(lines, "\n")
instructions = table.concat(lines, "\n")
end
Path.prompts.initialize(Path.prompts.get(opts.bufnr))
@@ -67,29 +69,61 @@ M.stream = function(opts)
local template_opts = {
use_xml_format = Provider.use_xml_format,
ask = opts.ask, -- TODO: add mode without ask instruction
question = original_instructions,
code_lang = opts.code_lang,
filepath = filepath,
file_content = opts.file_content,
selected_code = opts.selected_code,
project_context = opts.project_context,
memory_context = opts.memory_context,
}
local user_prompts = vim
.iter({
Path.prompts.render_file("_project.avanterules", template_opts),
Path.prompts.render_file("_memory.avanterules", template_opts),
Path.prompts.render_file("_context.avanterules", template_opts),
Path.prompts.render_mode(mode, template_opts),
})
:filter(function(k) return k ~= "" end)
:totable()
local system_prompt = Path.prompts.render_mode(mode, template_opts)
---@type AvanteLLMMessage[]
local messages = {}
if opts.project_context ~= nil and opts.project_context ~= "" and opts.project_context ~= "null" then
local project_context = Path.prompts.render_file("_project.avanterules", template_opts)
if project_context ~= "" then table.insert(messages, { role = "user", content = project_context }) end
end
local code_context = Path.prompts.render_file("_context.avanterules", template_opts)
if code_context ~= "" then table.insert(messages, { role = "user", content = code_context }) end
if opts.use_xml_format then
table.insert(messages, { role = "user", content = string.format("<question>%s</question>", instructions) })
else
table.insert(messages, { role = "user", content = string.format("QUESTION:\n%s", instructions) })
end
local remaining_tokens = max_tokens - Utils.tokens.calculate_tokens(system_prompt)
for _, message in ipairs(messages) do
remaining_tokens = remaining_tokens - Utils.tokens.calculate_tokens(message.content)
end
if opts.history_messages then
if Config.history.max_tokens > 0 then remaining_tokens = math.min(Config.history.max_tokens, remaining_tokens) end
-- Traverse the history in reverse, keeping only the latest history until the remaining tokens are exhausted and the first message role is "user"
local history_messages = {}
for i = #opts.history_messages, 1, -1 do
local message = opts.history_messages[i]
local tokens = Utils.tokens.calculate_tokens(message.content)
remaining_tokens = remaining_tokens - tokens
if remaining_tokens > 0 then
table.insert(history_messages, message)
else
break
end
end
if #history_messages > 0 and history_messages[1].role == "assistant" then table.remove(history_messages, 1) end
-- prepend the history messages to the messages table
vim.iter(history_messages):each(function(msg) table.insert(messages, 1, msg) end)
end
---@type AvantePromptOptions
local code_opts = {
system_prompt = Config.system_prompt,
user_prompts = user_prompts,
system_prompt = system_prompt,
messages = messages,
image_paths = image_paths,
}
@@ -164,7 +198,7 @@ M.stream = function(opts)
on_error = function(err)
if err.exit == 23 then
local xdg_runtime_dir = os.getenv("XDG_RUNTIME_DIR")
if fn.isdirectory(xdg_runtime_dir) == 0 then
if not xdg_runtime_dir or fn.isdirectory(xdg_runtime_dir) == 0 then
Utils.error(
"$XDG_RUNTIME_DIR="
.. xdg_runtime_dir