refactor: summarize memory (#1508)

This commit is contained in:
yetone
2025-03-07 00:12:57 +08:00
committed by GitHub
parent 2b0e7e09ae
commit 8620ea3e12
18 changed files with 434 additions and 217 deletions

View File

@@ -33,6 +33,7 @@ struct TemplateContext {
diagnostics: Option<String>,
system_info: Option<String>,
model_name: Option<String>,
memory: Option<String>,
}
// Given the file name registered after add, the context table in Lua, resulted in a formatted
@@ -58,6 +59,7 @@ fn render(state: &State, template: &str, context: TemplateContext) -> LuaResult<
diagnostics => context.diagnostics,
system_info => context.system_info,
model_name => context.model_name,
memory => context.memory,
})
.map_err(LuaError::external)
.unwrap())

View File

@@ -27,6 +27,7 @@ M._defaults = {
-- Of course, you can reduce the request frequency by increasing `suggestion.debounce`.
auto_suggestions_provider = "claude",
cursor_applying_provider = nil,
memory_summary_provider = nil,
---@alias Tokenizer "tiktoken" | "hf"
-- Used for counting tokens and encoding text.
-- By default, we will use tiktoken.
@@ -273,6 +274,10 @@ M._defaults = {
temperature = 0,
max_tokens = 8000,
},
["openai-gpt-4o-mini"] = {
__inherited_from = "openai",
model = "gpt-4o-mini",
},
},
---Specify the special dual_boost mode
---1. enabled: Whether to enable dual_boost mode. Default to false.

View File

@@ -19,6 +19,71 @@ M.CANCEL_PATTERN = "AvanteLLMEscape"
local group = api.nvim_create_augroup("avante_llm", { clear = true })
---@param bufnr integer
---@param history avante.ChatHistory
---@param cb fun(memory: avante.ChatMemory | nil): nil
function M.summarize_memory(bufnr, history, cb)
local system_prompt =
[[Summarize the following conversation to extract the most critical information (such as languages used, conversation style, tech stack, considerations, user information, etc.) for memory in subsequent conversations. Since it is for memory purposes, be detailed and rigorous to ensure that no information from previous summaries is lost in the newly generated summary.]]
local entries = Utils.history.filter_active_entries(history.entries)
if #entries == 0 then
cb(nil)
return
end
if history.memory then
system_prompt = system_prompt .. "\n\nThe previous summary is:\n\n" .. history.memory.content
entries = vim
.iter(entries)
:filter(function(entry) return entry.timestamp > history.memory.last_summarized_timestamp end)
:totable()
end
if #entries == 0 then
cb(history.memory)
return
end
local history_messages = Utils.history.entries_to_llm_messages(entries)
history_messages = vim.list_slice(history_messages, 1, 4)
if #history_messages == 0 then
cb(history.memory)
return
end
Utils.debug("summarize memory", #history_messages, history_messages[#history_messages].content)
local response_content = ""
local provider = Providers[Config.memory_summary_provider or Config.provider]
M.curl({
provider = provider,
prompt_opts = {
system_prompt = system_prompt,
messages = {
{ role = "user", content = vim.json.encode(history_messages) },
},
},
handler_opts = {
on_start = function(_) end,
on_chunk = function(chunk)
if not chunk then return end
response_content = response_content .. chunk
end,
on_stop = function(stop_opts)
if stop_opts.error ~= nil then
Utils.error(string.format("summarize failed: %s", vim.inspect(stop_opts.error)))
return
end
if stop_opts.reason == "complete" then
response_content = Utils.trim_think_content(response_content)
local memory = {
content = response_content,
last_summarized_timestamp = entries[#entries].timestamp,
}
history.memory = memory
Path.history.save(bufnr, history)
cb(memory)
end
end,
},
})
end
---@param opts AvanteGeneratePromptsOptions
---@return AvantePromptOptions
function M.generate_prompts(opts)
@@ -58,6 +123,7 @@ function M.generate_prompts(opts)
diagnostics = opts.diagnostics,
system_info = system_info,
model_name = provider.model or "unknown",
memory = opts.memory,
}
local system_prompt = Path.prompts.render_mode(mode, template_opts)
@@ -80,6 +146,11 @@ function M.generate_prompts(opts)
if code_context ~= "" then table.insert(messages, { role = "user", content = code_context }) end
end
if opts.memory ~= nil and opts.memory ~= "" and opts.memory ~= "null" then
local memory = Path.prompts.render_file("_memory.avanterules", template_opts)
if memory ~= "" then table.insert(messages, { role = "user", content = memory }) end
end
if instructions then
if opts.use_xml_format then
table.insert(messages, { role = "user", content = string.format("<question>%s</question>", instructions) })
@@ -150,90 +221,17 @@ function M.calculate_tokens(opts)
return tokens
end
---@param opts AvanteLLMStreamOptions
function M._stream(opts)
local provider = opts.provider or Providers[Config.provider]
---@param opts avante.CurlOpts
function M.curl(opts)
local provider = opts.provider
local prompt_opts = opts.prompt_opts
local handler_opts = opts.handler_opts
---@cast provider AvanteProviderFunctor
local prompt_opts = M.generate_prompts(opts)
---@type AvanteCurlOutput
local spec = provider:parse_curl_args(prompt_opts)
---@type string
local current_event_state = nil
---@type AvanteHandlerOptions
local handler_opts = {
on_start = opts.on_start,
on_chunk = opts.on_chunk,
on_stop = function(stop_opts)
---@param tool_use_list AvanteLLMToolUse[]
---@param tool_use_index integer
---@param tool_histories AvanteLLMToolHistory[]
local function handle_next_tool_use(tool_use_list, tool_use_index, tool_histories)
if tool_use_index > #tool_use_list then
local new_opts = vim.tbl_deep_extend("force", opts, {
tool_histories = tool_histories,
})
return M._stream(new_opts)
end
local tool_use = tool_use_list[tool_use_index]
---@param result string | nil
---@param error string | nil
local function handle_tool_result(result, error)
local tool_result = {
tool_use_id = tool_use.id,
content = error ~= nil and error or result,
is_error = error ~= nil,
}
table.insert(tool_histories, { tool_result = tool_result, tool_use = tool_use })
return handle_next_tool_use(tool_use_list, tool_use_index + 1, tool_histories)
end
-- Either on_complete handles the tool result asynchronously or we receive the result and error synchronously when either is not nil
local result, error = LLMTools.process_tool_use(opts.tools, tool_use, opts.on_tool_log, handle_tool_result)
if result ~= nil or error ~= nil then return handle_tool_result(result, error) end
end
if stop_opts.reason == "tool_use" and stop_opts.tool_use_list then
local old_tool_histories = vim.deepcopy(opts.tool_histories) or {}
local sorted_tool_use_list = {} ---@type AvanteLLMToolUse[]
for _, tool_use in vim.spairs(stop_opts.tool_use_list) do
table.insert(sorted_tool_use_list, tool_use)
end
return handle_next_tool_use(sorted_tool_use_list, 1, old_tool_histories)
end
if stop_opts.reason == "rate_limit" then
local msg = "Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds ..."
opts.on_chunk("\n*[" .. msg .. "]*\n")
local timer = vim.loop.new_timer()
if timer then
local retry_after = stop_opts.retry_after
local function countdown()
timer:start(
1000,
0,
vim.schedule_wrap(function()
if retry_after > 0 then retry_after = retry_after - 1 end
local msg_ = "Rate limit reached. Retrying in " .. retry_after .. " seconds ..."
opts.on_chunk([[\033[1A\033[K]] .. "\n*[" .. msg_ .. "]*\n")
countdown()
end)
)
end
countdown()
end
Utils.info("Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds", { title = "Avante" })
vim.defer_fn(function()
if timer then timer:stop() end
M._stream(opts)
end, stop_opts.retry_after * 1000)
return
end
return opts.on_stop(stop_opts)
end,
}
---@type AvanteCurlOutput
local spec = provider:parse_curl_args(provider, prompt_opts)
local resp_ctx = {}
---@param line string
@@ -383,6 +381,91 @@ function M._stream(opts)
return active_job
end
---@param opts AvanteLLMStreamOptions
function M._stream(opts)
local provider = opts.provider or Providers[Config.provider]
---@cast provider AvanteProviderFunctor
local prompt_opts = M.generate_prompts(opts)
---@type AvanteHandlerOptions
local handler_opts = {
on_start = opts.on_start,
on_chunk = opts.on_chunk,
on_stop = function(stop_opts)
---@param tool_use_list AvanteLLMToolUse[]
---@param tool_use_index integer
---@param tool_histories AvanteLLMToolHistory[]
local function handle_next_tool_use(tool_use_list, tool_use_index, tool_histories)
if tool_use_index > #tool_use_list then
local new_opts = vim.tbl_deep_extend("force", opts, {
tool_histories = tool_histories,
})
return M._stream(new_opts)
end
local tool_use = tool_use_list[tool_use_index]
---@param result string | nil
---@param error string | nil
local function handle_tool_result(result, error)
local tool_result = {
tool_use_id = tool_use.id,
content = error ~= nil and error or result,
is_error = error ~= nil,
}
table.insert(tool_histories, { tool_result = tool_result, tool_use = tool_use })
return handle_next_tool_use(tool_use_list, tool_use_index + 1, tool_histories)
end
-- Either on_complete handles the tool result asynchronously or we receive the result and error synchronously when either is not nil
local result, error = LLMTools.process_tool_use(opts.tools, tool_use, opts.on_tool_log, handle_tool_result)
if result ~= nil or error ~= nil then return handle_tool_result(result, error) end
end
if stop_opts.reason == "tool_use" and stop_opts.tool_use_list then
local old_tool_histories = vim.deepcopy(opts.tool_histories) or {}
local sorted_tool_use_list = {} ---@type AvanteLLMToolUse[]
for _, tool_use in vim.spairs(stop_opts.tool_use_list) do
table.insert(sorted_tool_use_list, tool_use)
end
return handle_next_tool_use(sorted_tool_use_list, 1, old_tool_histories)
end
if stop_opts.reason == "rate_limit" then
local msg = "Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds ..."
opts.on_chunk("\n*[" .. msg .. "]*\n")
local timer = vim.loop.new_timer()
if timer then
local retry_after = stop_opts.retry_after
local function countdown()
timer:start(
1000,
0,
vim.schedule_wrap(function()
if retry_after > 0 then retry_after = retry_after - 1 end
local msg_ = "Rate limit reached. Retrying in " .. retry_after .. " seconds ..."
opts.on_chunk([[\033[1A\033[K]] .. "\n*[" .. msg_ .. "]*\n")
countdown()
end)
)
end
countdown()
end
Utils.info("Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds", { title = "Avante" })
vim.defer_fn(function()
if timer then timer:stop() end
M._stream(opts)
end, stop_opts.retry_after * 1000)
return
end
return opts.on_stop(stop_opts)
end,
}
return M.curl({
provider = provider,
prompt_opts = prompt_opts,
handler_opts = handler_opts,
})
end
local function _merge_response(first_response, second_response, opts)
local prompt = "\n" .. Config.dual_boost.prompt
prompt = prompt

View File

@@ -5,58 +5,75 @@ local Path = require("plenary.path")
local Scan = require("plenary.scandir")
local Config = require("avante.config")
---@class avante.ChatHistoryEntry
---@field timestamp string
---@field provider string
---@field model string
---@field request string
---@field response string
---@field original_response string
---@field selected_file {filepath: string}?
---@field selected_code {filetype: string, content: string}?
---@field reset_memory boolean?
---@field selected_filepaths string[] | nil
---@class avante.Path
---@field history_path Path
---@field cache_path Path
local P = {}
local history_file_cache = LRUCache:new(12)
-- History path
local History = {}
-- Get a chat history file name given a buffer
---@param bufnr integer
---@return string
function History.filename(bufnr)
---@param bufnr integer | nil
---@return string dirname
local function generate_project_dirname_in_storage(bufnr)
local project_root = Utils.root.get({
buf = bufnr,
})
-- Replace path separators with double underscores
local path_with_separators = fn.substitute(project_root, "/", "__", "g")
-- Replace other non-alphanumeric characters with single underscores
return fn.substitute(path_with_separators, "[^A-Za-z0-9._]", "_", "g") .. ".json"
local dirname = fn.substitute(path_with_separators, "[^A-Za-z0-9._]", "_", "g")
return tostring(Path:new("projects"):joinpath(dirname))
end
-- History path
local History = {}
-- Get a chat history file name given a buffer
---@param bufnr integer
---@param new boolean
---@return Path
function History.filepath(bufnr, new)
local dirname = generate_project_dirname_in_storage(bufnr)
local history_dir = Path:new(Config.history.storage_path):joinpath(dirname):joinpath("history")
if not history_dir:exists() then history_dir:mkdir({ parents = true }) end
local pattern = tostring(history_dir:joinpath("*.json"))
local files = vim.fn.glob(pattern, true, true)
local filename = #files .. ".json"
if #files > 0 and not new then filename = (#files - 1) .. ".json" end
return history_dir:joinpath(filename)
end
-- Returns the Path to the chat history file for the given buffer.
---@param bufnr integer
---@return Path
function History.get(bufnr) return Path:new(Config.history.storage_path):joinpath(History.filename(bufnr)) end
function History.get(bufnr) return History.filepath(bufnr, false) end
---@param bufnr integer
function History.new(bufnr)
local filepath = History.filepath(bufnr, true)
local history = {
title = "untitled",
timestamp = tostring(os.date("%Y-%m-%d %H:%M:%S")),
entries = {},
}
filepath:write(vim.json.encode(history), "w")
end
-- Loads the chat history for the given buffer.
---@param bufnr integer
---@return avante.ChatHistoryEntry[]
---@return avante.ChatHistory
function History.load(bufnr)
local history_file = History.get(bufnr)
local cached_key = tostring(history_file:absolute())
local cached_value = history_file_cache:get(cached_key)
if cached_value ~= nil then return cached_value end
local value = {}
---@type avante.ChatHistory
local value = {
title = "untitled",
timestamp = tostring(os.date("%Y-%m-%d %H:%M:%S")),
entries = {},
}
if history_file:exists() then
local content = history_file:read()
value = content ~= nil and vim.json.decode(content) or {}
value = content ~= nil and vim.json.decode(content) or value
end
history_file_cache:set(cached_key, value)
return value
@@ -64,7 +81,7 @@ end
-- Saves the chat history for the given buffer.
---@param bufnr integer
---@param history avante.ChatHistoryEntry[]
---@param history avante.ChatHistory
History.save = vim.schedule_wrap(function(bufnr, history)
local history_file = History.get(bufnr)
local cached_key = tostring(history_file:absolute())

View File

@@ -18,8 +18,8 @@ M.parse_response = O.parse_response
M.parse_response_without_stream = O.parse_response_without_stream
M.is_disable_stream = O.is_disable_stream
function M:parse_curl_args(provider, prompt_opts)
local provider_conf, request_body = P.parse_config(provider)
function M:parse_curl_args(prompt_opts)
local provider_conf, request_body = P.parse_config(self)
local headers = {
["Content-Type"] = "application/json",
@@ -27,9 +27,9 @@ function M:parse_curl_args(provider, prompt_opts)
if P.env.require_api_key(provider_conf) then
if provider_conf.entra then
headers["Authorization"] = "Bearer " .. provider.parse_api_key()
headers["Authorization"] = "Bearer " .. self.parse_api_key()
else
headers["api-key"] = provider.parse_api_key()
headers["api-key"] = self.parse_api_key()
end
end

View File

@@ -51,13 +51,12 @@ function M:parse_response_without_stream(data, event_state, opts)
vim.schedule(function() opts.on_stop({ reason = "complete" }) end)
end
---@param provider AvanteBedrockProviderFunctor
---@param prompt_opts AvantePromptOptions
---@return table
function M:parse_curl_args(provider, prompt_opts)
local base, body_opts = P.parse_config(provider)
function M:parse_curl_args(prompt_opts)
local base, body_opts = P.parse_config(self)
local api_key = provider.parse_api_key()
local api_key = self.parse_api_key()
if api_key == nil then error("Cannot get the bedrock api key!") end
local parts = vim.split(api_key, ",")
local aws_access_key_id = parts[1]

View File

@@ -262,11 +262,10 @@ function M:parse_response(ctx, data_stream, event_state, opts)
end
end
---@param provider AvanteProviderFunctor
---@param prompt_opts AvantePromptOptions
---@return table
function M:parse_curl_args(provider, prompt_opts)
local provider_conf, request_body = P.parse_config(provider)
function M:parse_curl_args(prompt_opts)
local provider_conf, request_body = P.parse_config(self)
local disable_tools = provider_conf.disable_tools or false
local headers = {
@@ -275,7 +274,7 @@ function M:parse_curl_args(provider, prompt_opts)
["anthropic-beta"] = "prompt-caching-2024-07-31",
}
if P.env.require_api_key(provider_conf) then headers["x-api-key"] = provider.parse_api_key() end
if P.env.require_api_key(provider_conf) then headers["x-api-key"] = self.parse_api_key() end
local messages = self:parse_messages(prompt_opts)

View File

@@ -71,8 +71,8 @@ function M:parse_stream_data(ctx, data, opts)
end
end
function M:parse_curl_args(provider, prompt_opts)
local provider_conf, request_body = P.parse_config(provider)
function M:parse_curl_args(prompt_opts)
local provider_conf, request_body = P.parse_config(self)
local headers = {
["Accept"] = "application/json",
@@ -84,7 +84,7 @@ function M:parse_curl_args(provider, prompt_opts)
.. "."
.. vim.version().patch,
}
if P.env.require_api_key(provider_conf) then headers["Authorization"] = "Bearer " .. provider.parse_api_key() end
if P.env.require_api_key(provider_conf) then headers["Authorization"] = "Bearer " .. self.parse_api_key() end
return {
url = Utils.url_join(provider_conf.endpoint, "/chat"),

View File

@@ -215,12 +215,12 @@ M.parse_messages = OpenAI.parse_messages
M.parse_response = OpenAI.parse_response
function M:parse_curl_args(provider, prompt_opts)
function M:parse_curl_args(prompt_opts)
-- refresh token synchronously, only if it has expired
-- (this should rarely happen, as we refresh the token in the background)
H.refresh_token(false, false)
local provider_conf, request_body = P.parse_config(provider)
local provider_conf, request_body = P.parse_config(self)
local disable_tools = provider_conf.disable_tools or false
local tools = {}

View File

@@ -83,8 +83,8 @@ function M:parse_response(ctx, data_stream, _, opts)
end
end
function M:parse_curl_args(provider, prompt_opts)
local provider_conf, request_body = P.parse_config(provider)
function M:parse_curl_args(prompt_opts)
local provider_conf, request_body = P.parse_config(self)
request_body = vim.tbl_deep_extend("force", request_body, {
generationConfig = {
@@ -95,7 +95,7 @@ function M:parse_curl_args(provider, prompt_opts)
request_body.temperature = nil
request_body.max_tokens = nil
local api_key = provider.parse_api_key()
local api_key = self.parse_api_key()
if api_key == nil then error("Cannot get the gemini api key!") end
return {

View File

@@ -261,6 +261,13 @@ function M.setup()
E.setup({ provider = cursor_applying_provider })
end
end
if Config.memory_summary_provider then
local memory_summary_provider = M[Config.memory_summary_provider]
if memory_summary_provider and memory_summary_provider ~= provider then
E.setup({ provider = memory_summary_provider })
end
end
end
---@param provider Provider

View File

@@ -219,8 +219,8 @@ function M:parse_response_without_stream(data, _, opts)
end
end
function M:parse_curl_args(provider, prompt_opts)
local provider_conf, request_body = P.parse_config(provider)
function M:parse_curl_args(prompt_opts)
local provider_conf, request_body = P.parse_config(self)
local disable_tools = provider_conf.disable_tools or false
local headers = {
@@ -228,7 +228,7 @@ function M:parse_curl_args(provider, prompt_opts)
}
if P.env.require_api_key(provider_conf) then
local api_key = provider.parse_api_key()
local api_key = self.parse_api_key()
if api_key == nil then
error(Config.provider .. " API key is not set, please set it in your environment variable or config file")
end

View File

@@ -32,8 +32,8 @@ function M.parse_api_key()
return direct_output
end
function M:parse_curl_args(provider, prompt_opts)
local provider_conf, request_body = P.parse_config(provider)
function M:parse_curl_args(prompt_opts)
local provider_conf, request_body = P.parse_config(self)
local location = vim.fn.getenv("LOCATION")
local project_id = vim.fn.getenv("PROJECT_ID")
local model_id = provider_conf.model or "default-model-id"

View File

@@ -2073,11 +2073,11 @@ function Sidebar:get_layout()
return vim.tbl_contains({ "left", "right" }, calculate_config_window_position()) and "vertical" or "horizontal"
end
---@param history avante.ChatHistoryEntry[]
---@param history avante.ChatHistory
---@return string
function Sidebar:render_history_content(history)
local content = ""
for idx, entry in ipairs(history) do
for idx, entry in ipairs(history.entries) do
if entry.reset_memory then
content = content .. "***MEMORY RESET***\n\n"
if idx < #history then content = content .. "-------\n\n" end
@@ -2146,11 +2146,11 @@ end
function Sidebar:clear_history(args, cb)
local chat_history = Path.history.load(self.code.bufnr)
if next(chat_history) ~= nil then
chat_history = {}
chat_history.entries = {}
Path.history.save(self.code.bufnr, chat_history)
self:update_content(
"Chat history cleared",
{ focus = false, scroll = false, callback = function() self:focus_input() end }
{ ignore_history = true, focus = false, scroll = false, callback = function() self:focus_input() end }
)
if cb then cb(args) end
else
@@ -2161,6 +2161,16 @@ function Sidebar:clear_history(args, cb)
end
end
function Sidebar:new_chat(args, cb)
Path.history.new(self.code.bufnr)
Sidebar.reload_chat_history()
self:update_content(
"New chat",
{ ignore_history = true, focus = false, scroll = false, callback = function() self:focus_input() end }
)
if cb then cb(args) end
end
function Sidebar:reset_memory(args, cb)
local chat_history = Path.history.load(self.code.bufnr)
if next(chat_history) ~= nil then
@@ -2176,6 +2186,7 @@ function Sidebar:reset_memory(args, cb)
reset_memory = true,
})
Path.history.save(self.code.bufnr, chat_history)
Sidebar.reload_chat_history()
local history_content = self:render_history_content(chat_history)
self:update_content(history_content, {
focus = false,
@@ -2184,6 +2195,7 @@ function Sidebar:reset_memory(args, cb)
})
if cb then cb(args) end
else
Sidebar.reload_chat_history()
self:update_content(
"Chat history is already empty",
{ focus = false, scroll = false, callback = function() self:focus_input() end }
@@ -2191,7 +2203,7 @@ function Sidebar:reset_memory(args, cb)
end
end
---@alias AvanteSlashCommandType "clear" | "help" | "lines" | "reset" | "commit"
---@alias AvanteSlashCommandType "clear" | "help" | "lines" | "reset" | "commit" | "new"
---@alias AvanteSlashCommandCallback fun(args: string, cb?: fun(args: string): nil): nil
---@alias AvanteSlashCommand {description: string, command: AvanteSlashCommandType, details: string, shorthelp?: string, callback?: AvanteSlashCommandCallback}
---@return AvanteSlashCommand[]
@@ -2211,6 +2223,7 @@ function Sidebar:get_commands()
{ description = "Show help message", command = "help" },
{ description = "Clear chat history", command = "clear" },
{ description = "Reset memory", command = "reset" },
{ description = "New chat", command = "new" },
{
shorthelp = "Ask a question about specific lines",
description = "/lines <start>-<end> <question>",
@@ -2228,6 +2241,7 @@ function Sidebar:get_commands()
end,
clear = function(args, cb) self:clear_history(args, cb) end,
reset = function(args, cb) self:reset_memory(args, cb) end,
new = function(args, cb) self:new_chat(args, cb) end,
lines = function(args, cb)
if cb then cb(args) end
end,
@@ -2298,6 +2312,8 @@ function Sidebar:create_input_container(opts)
local chat_history = Path.history.load(self.code.bufnr)
Sidebar.reload_chat_history = function() chat_history = Path.history.load(self.code.bufnr) end
local tools = vim.deepcopy(LLMTools.get_tools())
table.insert(tools, {
name = "add_file_to_context",
@@ -2330,8 +2346,9 @@ function Sidebar:create_input_container(opts)
})
---@param request string
---@return AvanteGeneratePromptsOptions
local function get_generate_prompts_options(request)
---@param summarize_memory boolean
---@param cb fun(opts: AvanteGeneratePromptsOptions): nil
local function get_generate_prompts_options(request, summarize_memory, cb)
local filetype = api.nvim_get_option_value("filetype", { buf = self.code.bufnr })
local selected_code_content = nil
@@ -2355,40 +2372,19 @@ function Sidebar:create_input_container(opts)
end
end
local history_messages = {}
for i = #chat_history, 1, -1 do
local entry = chat_history[i]
if entry.reset_memory then break end
if
entry.request == nil
or entry.original_response == nil
or entry.request == ""
or entry.original_response == ""
then
break
end
table.insert(
history_messages,
1,
{ role = "assistant", content = Utils.trim_think_content(entry.original_response) }
)
local user_content = ""
if entry.selected_file ~= nil then
user_content = user_content .. "SELECTED FILE: " .. entry.selected_file.filepath .. "\n\n"
end
if entry.selected_code ~= nil then
user_content = user_content
.. "SELECTED CODE:\n\n```"
.. entry.selected_code.filetype
.. "\n"
.. entry.selected_code.content
.. "\n```\n\n"
end
user_content = user_content .. "USER PROMPT:\n\n" .. entry.request
table.insert(history_messages, 1, { role = "user", content = user_content })
local entries = Utils.history.filter_active_entries(chat_history.entries)
if chat_history.memory then
entries = vim
.iter(entries)
:filter(function(entry) return entry.timestamp > chat_history.memory.last_summarized_timestamp end)
:totable()
end
return {
local history_messages = Utils.history.entries_to_llm_messages(entries)
---@type AvanteGeneratePromptsOptions
local prompts_opts = {
ask = opts.ask or true,
project_context = vim.json.encode(project_context),
selected_files = selected_files_contents,
@@ -2400,6 +2396,20 @@ function Sidebar:create_input_container(opts)
mode = Config.behaviour.enable_cursor_planning_mode and "cursor-planning" or "planning",
tools = tools,
}
if chat_history.memory then prompts_opts.memory = chat_history.memory.content end
if not summarize_memory or #history_messages < 8 then
cb(prompts_opts)
return
end
prompts_opts.history_messages = vim.list_slice(prompts_opts.history_messages, 5)
Llm.summarize_memory(self.code.bufnr, chat_history, function(memory)
if memory then prompts_opts.memory = memory.content end
cb(prompts_opts)
end)
end
---@param request string
@@ -2584,7 +2594,8 @@ function Sidebar:create_input_container(opts)
end, 0)
-- Save chat history
table.insert(chat_history or {}, {
chat_history.entries = chat_history.entries or {}
table.insert(chat_history.entries, {
timestamp = timestamp,
provider = Config.provider,
model = model,
@@ -2597,17 +2608,18 @@ function Sidebar:create_input_container(opts)
Path.history.save(self.code.bufnr, chat_history)
end
local generate_prompts_options = get_generate_prompts_options(request)
---@type AvanteLLMStreamOptions
---@diagnostic disable-next-line: assign-type-mismatch
local stream_options = vim.tbl_deep_extend("force", generate_prompts_options, {
on_start = on_start,
on_chunk = on_chunk,
on_stop = on_stop,
on_tool_log = on_tool_log,
})
get_generate_prompts_options(request, true, function(generate_prompts_options)
---@type AvanteLLMStreamOptions
---@diagnostic disable-next-line: assign-type-mismatch
local stream_options = vim.tbl_deep_extend("force", generate_prompts_options, {
on_start = on_start,
on_chunk = on_chunk,
on_stop = on_stop,
on_tool_log = on_tool_log,
})
Llm.stream(stream_options)
Llm.stream(stream_options)
end)
end
local function get_position()
@@ -2762,37 +2774,43 @@ function Sidebar:create_input_container(opts)
local hint_text = (fn.mode() ~= "i" and Config.mappings.submit.normal or Config.mappings.submit.insert)
.. ": submit"
if Config.behaviour.enable_token_counting then
local input_value = table.concat(api.nvim_buf_get_lines(self.input_container.bufnr, 0, -1, false), "\n")
local generate_prompts_options = get_generate_prompts_options(input_value)
local tokens = Llm.calculate_tokens(generate_prompts_options)
hint_text = "Tokens: " .. tostring(tokens) .. "; " .. hint_text
local function show()
local buf = api.nvim_create_buf(false, true)
api.nvim_buf_set_lines(buf, 0, -1, false, { hint_text })
api.nvim_buf_add_highlight(buf, 0, "AvantePopupHint", 0, 0, -1)
-- Get the current window size
local win_width = api.nvim_win_get_width(self.input_container.winid)
local width = #hint_text
-- Set the floating window options
local win_opts = {
relative = "win",
win = self.input_container.winid,
width = width,
height = 1,
row = get_float_window_row(),
col = math.max(win_width - width, 0), -- Display in the bottom right corner
style = "minimal",
border = "none",
focusable = false,
zindex = 100,
}
-- Create the floating window
hint_window = api.nvim_open_win(buf, false, win_opts)
end
local buf = api.nvim_create_buf(false, true)
api.nvim_buf_set_lines(buf, 0, -1, false, { hint_text })
api.nvim_buf_add_highlight(buf, 0, "AvantePopupHint", 0, 0, -1)
-- Get the current window size
local win_width = api.nvim_win_get_width(self.input_container.winid)
local width = #hint_text
-- Set the floating window options
local win_opts = {
relative = "win",
win = self.input_container.winid,
width = width,
height = 1,
row = get_float_window_row(),
col = math.max(win_width - width, 0), -- Display in the bottom right corner
style = "minimal",
border = "none",
focusable = false,
zindex = 100,
}
-- Create the floating window
hint_window = api.nvim_open_win(buf, false, win_opts)
if Config.behaviour.enable_token_counting then
local input_value = table.concat(api.nvim_buf_get_lines(self.input_container.bufnr, 0, -1, false), "\n")
get_generate_prompts_options(input_value, false, function(generate_prompts_options)
local tokens = Llm.calculate_tokens(generate_prompts_options)
hint_text = "Tokens: " .. tostring(tokens) .. "; " .. hint_text
show()
end)
else
show()
end
end
api.nvim_create_autocmd({ "TextChanged", "TextChangedI", "VimResized" }, {

View File

@@ -0,0 +1,5 @@
{%- if memory -%}
<memory>
{{memory}}
</memory>
{%- endif %}

View File

@@ -190,13 +190,9 @@ vim.g.avante_login = vim.g.avante_login
---@alias AvanteMessagesParser fun(self: AvanteProviderFunctor, opts: AvantePromptOptions): AvanteChatMessage[]
---
---@class AvanteCurlOutput: {url: string, proxy: string, insecure: boolean, body: table<string, any> | string, headers: table<string, string>, rawArgs: string[] | nil}
---@alias AvanteCurlArgsParser fun(self: AvanteProviderFunctor, provider: AvanteProvider | AvanteProviderFunctor | AvanteBedrockProviderFunctor, prompt_opts: AvantePromptOptions): AvanteCurlOutput
---@alias AvanteCurlArgsParser fun(self: AvanteProviderFunctor, prompt_opts: AvantePromptOptions): AvanteCurlOutput
---
---@class AvanteResponseParserOptions
---@field on_start AvanteLLMStartCallback
---@field on_chunk AvanteLLMChunkCallback
---@field on_stop AvanteLLMStopCallback
---@alias AvanteResponseParser fun(self: AvanteProviderFunctor, ctx: any, data_stream: string, event_state: string, opts: AvanteResponseParserOptions): nil
---@alias AvanteResponseParser fun(self: AvanteProviderFunctor, ctx: any, data_stream: string, event_state: string, opts: AvanteHandlerOptions): nil
---
---@class AvanteDefaultBaseProvider: table<string, any>
---@field endpoint? string
@@ -305,6 +301,7 @@ vim.g.avante_login = vim.g.avante_login
---@field selected_files AvanteSelectedFiles[] | nil
---@field diagnostics string | nil
---@field history_messages AvanteLLMMessage[] | nil
---@field memory string | nil
---
---@class AvanteGeneratePromptsOptions: AvanteTemplateOptions
---@field ask boolean
@@ -358,3 +355,35 @@ vim.g.avante_login = vim.g.avante_login
---@field description string
---@field type 'string' | 'string[]' | 'boolean'
---@field optional? boolean
---
---@class avante.ChatHistoryEntry
---@field timestamp string
---@field provider string
---@field model string
---@field request string
---@field response string
---@field original_response string
---@field selected_file {filepath: string}?
---@field selected_code {filetype: string, content: string}?
---@field reset_memory boolean?
---@field selected_filepaths string[] | nil
---
---@class avante.ChatHistory
---@field title string
---@field timestamp string
---@field entries avante.ChatHistoryEntry[]
---@field memory avante.ChatMemory | nil
---
---@class avante.ChatMemory
---@field content string
---@field last_summarized_timestamp string
---@class avante.Path
---@field history_path Path
---@field cache_path Path
---
---@class avante.CurlOpts
---@field provider AvanteProviderFunctor
---@field prompt_opts AvantePromptOptions
---@field handler_opts AvanteHandlerOptions
---

View File

@@ -0,0 +1,52 @@
local Utils = require("avante.utils")
---@class avante.utils.history
local M = {}
---@param entries avante.ChatHistoryEntry[]
---@return avante.ChatHistoryEntry[]
function M.filter_active_entries(entries)
local entries_ = {}
for i = #entries, 1, -1 do
local entry = entries[i]
if entry.reset_memory then break end
if
entry.request == nil
or entry.original_response == nil
or entry.request == ""
or entry.original_response == ""
then
break
end
table.insert(entries_, 1, entry)
end
return entries_
end
---@param entries avante.ChatHistoryEntry[]
---@return AvanteLLMMessage[]
function M.entries_to_llm_messages(entries)
local messages = {}
for _, entry in ipairs(entries) do
local user_content = ""
if entry.selected_file ~= nil then
user_content = user_content .. "SELECTED FILE: " .. entry.selected_file.filepath .. "\n\n"
end
if entry.selected_code ~= nil then
user_content = user_content
.. "SELECTED CODE:\n\n```"
.. entry.selected_code.filetype
.. "\n"
.. entry.selected_code.content
.. "\n```\n\n"
end
user_content = user_content .. "USER PROMPT:\n\n" .. entry.request
table.insert(messages, { role = "user", content = user_content })
table.insert(messages, { role = "assistant", content = Utils.trim_think_content(entry.original_response) })
end
return messages
end
return M

View File

@@ -8,6 +8,7 @@ local lsp = vim.lsp
---@field tokens avante.utils.tokens
---@field root avante.utils.root
---@field file avante.utils.file
---@field history avante.utils.history
local M = {}
setmetatable(M, {