From 8620ea3e12cfdb90aef2e8ce6f7d5e864758ab71 Mon Sep 17 00:00:00 2001 From: yetone Date: Fri, 7 Mar 2025 00:12:57 +0800 Subject: [PATCH] refactor: summarize memory (#1508) --- crates/avante-templates/src/lib.rs | 2 + lua/avante/config.lua | 5 + lua/avante/llm.lua | 243 +++++++++++++++-------- lua/avante/path.lua | 71 ++++--- lua/avante/providers/azure.lua | 8 +- lua/avante/providers/bedrock.lua | 7 +- lua/avante/providers/claude.lua | 7 +- lua/avante/providers/cohere.lua | 6 +- lua/avante/providers/copilot.lua | 4 +- lua/avante/providers/gemini.lua | 6 +- lua/avante/providers/init.lua | 7 + lua/avante/providers/openai.lua | 6 +- lua/avante/providers/vertex.lua | 4 +- lua/avante/sidebar.lua | 176 ++++++++-------- lua/avante/templates/_memory.avanterules | 5 + lua/avante/types.lua | 41 +++- lua/avante/utils/history.lua | 52 +++++ lua/avante/utils/init.lua | 1 + 18 files changed, 434 insertions(+), 217 deletions(-) create mode 100644 lua/avante/templates/_memory.avanterules create mode 100644 lua/avante/utils/history.lua diff --git a/crates/avante-templates/src/lib.rs b/crates/avante-templates/src/lib.rs index b4d275b..d04fa46 100644 --- a/crates/avante-templates/src/lib.rs +++ b/crates/avante-templates/src/lib.rs @@ -33,6 +33,7 @@ struct TemplateContext { diagnostics: Option, system_info: Option, model_name: Option, + memory: Option, } // Given the file name registered after add, the context table in Lua, resulted in a formatted @@ -58,6 +59,7 @@ fn render(state: &State, template: &str, context: TemplateContext) -> LuaResult< diagnostics => context.diagnostics, system_info => context.system_info, model_name => context.model_name, + memory => context.memory, }) .map_err(LuaError::external) .unwrap()) diff --git a/lua/avante/config.lua b/lua/avante/config.lua index eb59b93..0e6014f 100644 --- a/lua/avante/config.lua +++ b/lua/avante/config.lua @@ -27,6 +27,7 @@ M._defaults = { -- Of course, you can reduce the request frequency by increasing `suggestion.debounce`. auto_suggestions_provider = "claude", cursor_applying_provider = nil, + memory_summary_provider = nil, ---@alias Tokenizer "tiktoken" | "hf" -- Used for counting tokens and encoding text. -- By default, we will use tiktoken. @@ -273,6 +274,10 @@ M._defaults = { temperature = 0, max_tokens = 8000, }, + ["openai-gpt-4o-mini"] = { + __inherited_from = "openai", + model = "gpt-4o-mini", + }, }, ---Specify the special dual_boost mode ---1. enabled: Whether to enable dual_boost mode. Default to false. diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index 49ce67f..ff6636e 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -19,6 +19,71 @@ M.CANCEL_PATTERN = "AvanteLLMEscape" local group = api.nvim_create_augroup("avante_llm", { clear = true }) +---@param bufnr integer +---@param history avante.ChatHistory +---@param cb fun(memory: avante.ChatMemory | nil): nil +function M.summarize_memory(bufnr, history, cb) + local system_prompt = + [[Summarize the following conversation to extract the most critical information (such as languages used, conversation style, tech stack, considerations, user information, etc.) for memory in subsequent conversations. Since it is for memory purposes, be detailed and rigorous to ensure that no information from previous summaries is lost in the newly generated summary.]] + local entries = Utils.history.filter_active_entries(history.entries) + if #entries == 0 then + cb(nil) + return + end + if history.memory then + system_prompt = system_prompt .. "\n\nThe previous summary is:\n\n" .. history.memory.content + entries = vim + .iter(entries) + :filter(function(entry) return entry.timestamp > history.memory.last_summarized_timestamp end) + :totable() + end + if #entries == 0 then + cb(history.memory) + return + end + local history_messages = Utils.history.entries_to_llm_messages(entries) + history_messages = vim.list_slice(history_messages, 1, 4) + if #history_messages == 0 then + cb(history.memory) + return + end + Utils.debug("summarize memory", #history_messages, history_messages[#history_messages].content) + local response_content = "" + local provider = Providers[Config.memory_summary_provider or Config.provider] + M.curl({ + provider = provider, + prompt_opts = { + system_prompt = system_prompt, + messages = { + { role = "user", content = vim.json.encode(history_messages) }, + }, + }, + handler_opts = { + on_start = function(_) end, + on_chunk = function(chunk) + if not chunk then return end + response_content = response_content .. chunk + end, + on_stop = function(stop_opts) + if stop_opts.error ~= nil then + Utils.error(string.format("summarize failed: %s", vim.inspect(stop_opts.error))) + return + end + if stop_opts.reason == "complete" then + response_content = Utils.trim_think_content(response_content) + local memory = { + content = response_content, + last_summarized_timestamp = entries[#entries].timestamp, + } + history.memory = memory + Path.history.save(bufnr, history) + cb(memory) + end + end, + }, + }) +end + ---@param opts AvanteGeneratePromptsOptions ---@return AvantePromptOptions function M.generate_prompts(opts) @@ -58,6 +123,7 @@ function M.generate_prompts(opts) diagnostics = opts.diagnostics, system_info = system_info, model_name = provider.model or "unknown", + memory = opts.memory, } local system_prompt = Path.prompts.render_mode(mode, template_opts) @@ -80,6 +146,11 @@ function M.generate_prompts(opts) if code_context ~= "" then table.insert(messages, { role = "user", content = code_context }) end end + if opts.memory ~= nil and opts.memory ~= "" and opts.memory ~= "null" then + local memory = Path.prompts.render_file("_memory.avanterules", template_opts) + if memory ~= "" then table.insert(messages, { role = "user", content = memory }) end + end + if instructions then if opts.use_xml_format then table.insert(messages, { role = "user", content = string.format("%s", instructions) }) @@ -150,90 +221,17 @@ function M.calculate_tokens(opts) return tokens end ----@param opts AvanteLLMStreamOptions -function M._stream(opts) - local provider = opts.provider or Providers[Config.provider] +---@param opts avante.CurlOpts +function M.curl(opts) + local provider = opts.provider + local prompt_opts = opts.prompt_opts + local handler_opts = opts.handler_opts - ---@cast provider AvanteProviderFunctor - - local prompt_opts = M.generate_prompts(opts) + ---@type AvanteCurlOutput + local spec = provider:parse_curl_args(prompt_opts) ---@type string local current_event_state = nil - - ---@type AvanteHandlerOptions - local handler_opts = { - on_start = opts.on_start, - on_chunk = opts.on_chunk, - on_stop = function(stop_opts) - ---@param tool_use_list AvanteLLMToolUse[] - ---@param tool_use_index integer - ---@param tool_histories AvanteLLMToolHistory[] - local function handle_next_tool_use(tool_use_list, tool_use_index, tool_histories) - if tool_use_index > #tool_use_list then - local new_opts = vim.tbl_deep_extend("force", opts, { - tool_histories = tool_histories, - }) - return M._stream(new_opts) - end - local tool_use = tool_use_list[tool_use_index] - ---@param result string | nil - ---@param error string | nil - local function handle_tool_result(result, error) - local tool_result = { - tool_use_id = tool_use.id, - content = error ~= nil and error or result, - is_error = error ~= nil, - } - table.insert(tool_histories, { tool_result = tool_result, tool_use = tool_use }) - return handle_next_tool_use(tool_use_list, tool_use_index + 1, tool_histories) - end - -- Either on_complete handles the tool result asynchronously or we receive the result and error synchronously when either is not nil - local result, error = LLMTools.process_tool_use(opts.tools, tool_use, opts.on_tool_log, handle_tool_result) - if result ~= nil or error ~= nil then return handle_tool_result(result, error) end - end - if stop_opts.reason == "tool_use" and stop_opts.tool_use_list then - local old_tool_histories = vim.deepcopy(opts.tool_histories) or {} - local sorted_tool_use_list = {} ---@type AvanteLLMToolUse[] - for _, tool_use in vim.spairs(stop_opts.tool_use_list) do - table.insert(sorted_tool_use_list, tool_use) - end - return handle_next_tool_use(sorted_tool_use_list, 1, old_tool_histories) - end - if stop_opts.reason == "rate_limit" then - local msg = "Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds ..." - opts.on_chunk("\n*[" .. msg .. "]*\n") - local timer = vim.loop.new_timer() - if timer then - local retry_after = stop_opts.retry_after - local function countdown() - timer:start( - 1000, - 0, - vim.schedule_wrap(function() - if retry_after > 0 then retry_after = retry_after - 1 end - local msg_ = "Rate limit reached. Retrying in " .. retry_after .. " seconds ..." - opts.on_chunk([[\033[1A\033[K]] .. "\n*[" .. msg_ .. "]*\n") - countdown() - end) - ) - end - countdown() - end - Utils.info("Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds", { title = "Avante" }) - vim.defer_fn(function() - if timer then timer:stop() end - M._stream(opts) - end, stop_opts.retry_after * 1000) - return - end - return opts.on_stop(stop_opts) - end, - } - - ---@type AvanteCurlOutput - local spec = provider:parse_curl_args(provider, prompt_opts) - local resp_ctx = {} ---@param line string @@ -383,6 +381,91 @@ function M._stream(opts) return active_job end +---@param opts AvanteLLMStreamOptions +function M._stream(opts) + local provider = opts.provider or Providers[Config.provider] + + ---@cast provider AvanteProviderFunctor + + local prompt_opts = M.generate_prompts(opts) + + ---@type AvanteHandlerOptions + local handler_opts = { + on_start = opts.on_start, + on_chunk = opts.on_chunk, + on_stop = function(stop_opts) + ---@param tool_use_list AvanteLLMToolUse[] + ---@param tool_use_index integer + ---@param tool_histories AvanteLLMToolHistory[] + local function handle_next_tool_use(tool_use_list, tool_use_index, tool_histories) + if tool_use_index > #tool_use_list then + local new_opts = vim.tbl_deep_extend("force", opts, { + tool_histories = tool_histories, + }) + return M._stream(new_opts) + end + local tool_use = tool_use_list[tool_use_index] + ---@param result string | nil + ---@param error string | nil + local function handle_tool_result(result, error) + local tool_result = { + tool_use_id = tool_use.id, + content = error ~= nil and error or result, + is_error = error ~= nil, + } + table.insert(tool_histories, { tool_result = tool_result, tool_use = tool_use }) + return handle_next_tool_use(tool_use_list, tool_use_index + 1, tool_histories) + end + -- Either on_complete handles the tool result asynchronously or we receive the result and error synchronously when either is not nil + local result, error = LLMTools.process_tool_use(opts.tools, tool_use, opts.on_tool_log, handle_tool_result) + if result ~= nil or error ~= nil then return handle_tool_result(result, error) end + end + if stop_opts.reason == "tool_use" and stop_opts.tool_use_list then + local old_tool_histories = vim.deepcopy(opts.tool_histories) or {} + local sorted_tool_use_list = {} ---@type AvanteLLMToolUse[] + for _, tool_use in vim.spairs(stop_opts.tool_use_list) do + table.insert(sorted_tool_use_list, tool_use) + end + return handle_next_tool_use(sorted_tool_use_list, 1, old_tool_histories) + end + if stop_opts.reason == "rate_limit" then + local msg = "Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds ..." + opts.on_chunk("\n*[" .. msg .. "]*\n") + local timer = vim.loop.new_timer() + if timer then + local retry_after = stop_opts.retry_after + local function countdown() + timer:start( + 1000, + 0, + vim.schedule_wrap(function() + if retry_after > 0 then retry_after = retry_after - 1 end + local msg_ = "Rate limit reached. Retrying in " .. retry_after .. " seconds ..." + opts.on_chunk([[\033[1A\033[K]] .. "\n*[" .. msg_ .. "]*\n") + countdown() + end) + ) + end + countdown() + end + Utils.info("Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds", { title = "Avante" }) + vim.defer_fn(function() + if timer then timer:stop() end + M._stream(opts) + end, stop_opts.retry_after * 1000) + return + end + return opts.on_stop(stop_opts) + end, + } + + return M.curl({ + provider = provider, + prompt_opts = prompt_opts, + handler_opts = handler_opts, + }) +end + local function _merge_response(first_response, second_response, opts) local prompt = "\n" .. Config.dual_boost.prompt prompt = prompt diff --git a/lua/avante/path.lua b/lua/avante/path.lua index e404a45..e93d726 100644 --- a/lua/avante/path.lua +++ b/lua/avante/path.lua @@ -5,58 +5,75 @@ local Path = require("plenary.path") local Scan = require("plenary.scandir") local Config = require("avante.config") ----@class avante.ChatHistoryEntry ----@field timestamp string ----@field provider string ----@field model string ----@field request string ----@field response string ----@field original_response string ----@field selected_file {filepath: string}? ----@field selected_code {filetype: string, content: string}? ----@field reset_memory boolean? ----@field selected_filepaths string[] | nil - ---@class avante.Path ----@field history_path Path ----@field cache_path Path local P = {} local history_file_cache = LRUCache:new(12) --- History path -local History = {} - --- Get a chat history file name given a buffer ----@param bufnr integer ----@return string -function History.filename(bufnr) +---@param bufnr integer | nil +---@return string dirname +local function generate_project_dirname_in_storage(bufnr) local project_root = Utils.root.get({ buf = bufnr, }) -- Replace path separators with double underscores local path_with_separators = fn.substitute(project_root, "/", "__", "g") -- Replace other non-alphanumeric characters with single underscores - return fn.substitute(path_with_separators, "[^A-Za-z0-9._]", "_", "g") .. ".json" + local dirname = fn.substitute(path_with_separators, "[^A-Za-z0-9._]", "_", "g") + return tostring(Path:new("projects"):joinpath(dirname)) +end + +-- History path +local History = {} + +-- Get a chat history file name given a buffer +---@param bufnr integer +---@param new boolean +---@return Path +function History.filepath(bufnr, new) + local dirname = generate_project_dirname_in_storage(bufnr) + local history_dir = Path:new(Config.history.storage_path):joinpath(dirname):joinpath("history") + if not history_dir:exists() then history_dir:mkdir({ parents = true }) end + local pattern = tostring(history_dir:joinpath("*.json")) + local files = vim.fn.glob(pattern, true, true) + local filename = #files .. ".json" + if #files > 0 and not new then filename = (#files - 1) .. ".json" end + return history_dir:joinpath(filename) end -- Returns the Path to the chat history file for the given buffer. ---@param bufnr integer ---@return Path -function History.get(bufnr) return Path:new(Config.history.storage_path):joinpath(History.filename(bufnr)) end +function History.get(bufnr) return History.filepath(bufnr, false) end + +---@param bufnr integer +function History.new(bufnr) + local filepath = History.filepath(bufnr, true) + local history = { + title = "untitled", + timestamp = tostring(os.date("%Y-%m-%d %H:%M:%S")), + entries = {}, + } + filepath:write(vim.json.encode(history), "w") +end -- Loads the chat history for the given buffer. ---@param bufnr integer ----@return avante.ChatHistoryEntry[] +---@return avante.ChatHistory function History.load(bufnr) local history_file = History.get(bufnr) local cached_key = tostring(history_file:absolute()) local cached_value = history_file_cache:get(cached_key) if cached_value ~= nil then return cached_value end - local value = {} + ---@type avante.ChatHistory + local value = { + title = "untitled", + timestamp = tostring(os.date("%Y-%m-%d %H:%M:%S")), + entries = {}, + } if history_file:exists() then local content = history_file:read() - value = content ~= nil and vim.json.decode(content) or {} + value = content ~= nil and vim.json.decode(content) or value end history_file_cache:set(cached_key, value) return value @@ -64,7 +81,7 @@ end -- Saves the chat history for the given buffer. ---@param bufnr integer ----@param history avante.ChatHistoryEntry[] +---@param history avante.ChatHistory History.save = vim.schedule_wrap(function(bufnr, history) local history_file = History.get(bufnr) local cached_key = tostring(history_file:absolute()) diff --git a/lua/avante/providers/azure.lua b/lua/avante/providers/azure.lua index 04a7b14..dc7a5e0 100644 --- a/lua/avante/providers/azure.lua +++ b/lua/avante/providers/azure.lua @@ -18,8 +18,8 @@ M.parse_response = O.parse_response M.parse_response_without_stream = O.parse_response_without_stream M.is_disable_stream = O.is_disable_stream -function M:parse_curl_args(provider, prompt_opts) - local provider_conf, request_body = P.parse_config(provider) +function M:parse_curl_args(prompt_opts) + local provider_conf, request_body = P.parse_config(self) local headers = { ["Content-Type"] = "application/json", @@ -27,9 +27,9 @@ function M:parse_curl_args(provider, prompt_opts) if P.env.require_api_key(provider_conf) then if provider_conf.entra then - headers["Authorization"] = "Bearer " .. provider.parse_api_key() + headers["Authorization"] = "Bearer " .. self.parse_api_key() else - headers["api-key"] = provider.parse_api_key() + headers["api-key"] = self.parse_api_key() end end diff --git a/lua/avante/providers/bedrock.lua b/lua/avante/providers/bedrock.lua index 3bfa4ae..ba4d737 100644 --- a/lua/avante/providers/bedrock.lua +++ b/lua/avante/providers/bedrock.lua @@ -51,13 +51,12 @@ function M:parse_response_without_stream(data, event_state, opts) vim.schedule(function() opts.on_stop({ reason = "complete" }) end) end ----@param provider AvanteBedrockProviderFunctor ---@param prompt_opts AvantePromptOptions ---@return table -function M:parse_curl_args(provider, prompt_opts) - local base, body_opts = P.parse_config(provider) +function M:parse_curl_args(prompt_opts) + local base, body_opts = P.parse_config(self) - local api_key = provider.parse_api_key() + local api_key = self.parse_api_key() if api_key == nil then error("Cannot get the bedrock api key!") end local parts = vim.split(api_key, ",") local aws_access_key_id = parts[1] diff --git a/lua/avante/providers/claude.lua b/lua/avante/providers/claude.lua index dcfae6d..ce64a83 100644 --- a/lua/avante/providers/claude.lua +++ b/lua/avante/providers/claude.lua @@ -262,11 +262,10 @@ function M:parse_response(ctx, data_stream, event_state, opts) end end ----@param provider AvanteProviderFunctor ---@param prompt_opts AvantePromptOptions ---@return table -function M:parse_curl_args(provider, prompt_opts) - local provider_conf, request_body = P.parse_config(provider) +function M:parse_curl_args(prompt_opts) + local provider_conf, request_body = P.parse_config(self) local disable_tools = provider_conf.disable_tools or false local headers = { @@ -275,7 +274,7 @@ function M:parse_curl_args(provider, prompt_opts) ["anthropic-beta"] = "prompt-caching-2024-07-31", } - if P.env.require_api_key(provider_conf) then headers["x-api-key"] = provider.parse_api_key() end + if P.env.require_api_key(provider_conf) then headers["x-api-key"] = self.parse_api_key() end local messages = self:parse_messages(prompt_opts) diff --git a/lua/avante/providers/cohere.lua b/lua/avante/providers/cohere.lua index 8725a13..a67e10a 100644 --- a/lua/avante/providers/cohere.lua +++ b/lua/avante/providers/cohere.lua @@ -71,8 +71,8 @@ function M:parse_stream_data(ctx, data, opts) end end -function M:parse_curl_args(provider, prompt_opts) - local provider_conf, request_body = P.parse_config(provider) +function M:parse_curl_args(prompt_opts) + local provider_conf, request_body = P.parse_config(self) local headers = { ["Accept"] = "application/json", @@ -84,7 +84,7 @@ function M:parse_curl_args(provider, prompt_opts) .. "." .. vim.version().patch, } - if P.env.require_api_key(provider_conf) then headers["Authorization"] = "Bearer " .. provider.parse_api_key() end + if P.env.require_api_key(provider_conf) then headers["Authorization"] = "Bearer " .. self.parse_api_key() end return { url = Utils.url_join(provider_conf.endpoint, "/chat"), diff --git a/lua/avante/providers/copilot.lua b/lua/avante/providers/copilot.lua index ddd22e3..313b52b 100644 --- a/lua/avante/providers/copilot.lua +++ b/lua/avante/providers/copilot.lua @@ -215,12 +215,12 @@ M.parse_messages = OpenAI.parse_messages M.parse_response = OpenAI.parse_response -function M:parse_curl_args(provider, prompt_opts) +function M:parse_curl_args(prompt_opts) -- refresh token synchronously, only if it has expired -- (this should rarely happen, as we refresh the token in the background) H.refresh_token(false, false) - local provider_conf, request_body = P.parse_config(provider) + local provider_conf, request_body = P.parse_config(self) local disable_tools = provider_conf.disable_tools or false local tools = {} diff --git a/lua/avante/providers/gemini.lua b/lua/avante/providers/gemini.lua index 14d15da..eafc2b8 100644 --- a/lua/avante/providers/gemini.lua +++ b/lua/avante/providers/gemini.lua @@ -83,8 +83,8 @@ function M:parse_response(ctx, data_stream, _, opts) end end -function M:parse_curl_args(provider, prompt_opts) - local provider_conf, request_body = P.parse_config(provider) +function M:parse_curl_args(prompt_opts) + local provider_conf, request_body = P.parse_config(self) request_body = vim.tbl_deep_extend("force", request_body, { generationConfig = { @@ -95,7 +95,7 @@ function M:parse_curl_args(provider, prompt_opts) request_body.temperature = nil request_body.max_tokens = nil - local api_key = provider.parse_api_key() + local api_key = self.parse_api_key() if api_key == nil then error("Cannot get the gemini api key!") end return { diff --git a/lua/avante/providers/init.lua b/lua/avante/providers/init.lua index ff6927c..26692f4 100644 --- a/lua/avante/providers/init.lua +++ b/lua/avante/providers/init.lua @@ -261,6 +261,13 @@ function M.setup() E.setup({ provider = cursor_applying_provider }) end end + + if Config.memory_summary_provider then + local memory_summary_provider = M[Config.memory_summary_provider] + if memory_summary_provider and memory_summary_provider ~= provider then + E.setup({ provider = memory_summary_provider }) + end + end end ---@param provider Provider diff --git a/lua/avante/providers/openai.lua b/lua/avante/providers/openai.lua index 36f53d0..83594e4 100644 --- a/lua/avante/providers/openai.lua +++ b/lua/avante/providers/openai.lua @@ -219,8 +219,8 @@ function M:parse_response_without_stream(data, _, opts) end end -function M:parse_curl_args(provider, prompt_opts) - local provider_conf, request_body = P.parse_config(provider) +function M:parse_curl_args(prompt_opts) + local provider_conf, request_body = P.parse_config(self) local disable_tools = provider_conf.disable_tools or false local headers = { @@ -228,7 +228,7 @@ function M:parse_curl_args(provider, prompt_opts) } if P.env.require_api_key(provider_conf) then - local api_key = provider.parse_api_key() + local api_key = self.parse_api_key() if api_key == nil then error(Config.provider .. " API key is not set, please set it in your environment variable or config file") end diff --git a/lua/avante/providers/vertex.lua b/lua/avante/providers/vertex.lua index e1fb61e..ccd93af 100644 --- a/lua/avante/providers/vertex.lua +++ b/lua/avante/providers/vertex.lua @@ -32,8 +32,8 @@ function M.parse_api_key() return direct_output end -function M:parse_curl_args(provider, prompt_opts) - local provider_conf, request_body = P.parse_config(provider) +function M:parse_curl_args(prompt_opts) + local provider_conf, request_body = P.parse_config(self) local location = vim.fn.getenv("LOCATION") local project_id = vim.fn.getenv("PROJECT_ID") local model_id = provider_conf.model or "default-model-id" diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua index 2a881cd..3b99756 100644 --- a/lua/avante/sidebar.lua +++ b/lua/avante/sidebar.lua @@ -2073,11 +2073,11 @@ function Sidebar:get_layout() return vim.tbl_contains({ "left", "right" }, calculate_config_window_position()) and "vertical" or "horizontal" end ----@param history avante.ChatHistoryEntry[] +---@param history avante.ChatHistory ---@return string function Sidebar:render_history_content(history) local content = "" - for idx, entry in ipairs(history) do + for idx, entry in ipairs(history.entries) do if entry.reset_memory then content = content .. "***MEMORY RESET***\n\n" if idx < #history then content = content .. "-------\n\n" end @@ -2146,11 +2146,11 @@ end function Sidebar:clear_history(args, cb) local chat_history = Path.history.load(self.code.bufnr) if next(chat_history) ~= nil then - chat_history = {} + chat_history.entries = {} Path.history.save(self.code.bufnr, chat_history) self:update_content( "Chat history cleared", - { focus = false, scroll = false, callback = function() self:focus_input() end } + { ignore_history = true, focus = false, scroll = false, callback = function() self:focus_input() end } ) if cb then cb(args) end else @@ -2161,6 +2161,16 @@ function Sidebar:clear_history(args, cb) end end +function Sidebar:new_chat(args, cb) + Path.history.new(self.code.bufnr) + Sidebar.reload_chat_history() + self:update_content( + "New chat", + { ignore_history = true, focus = false, scroll = false, callback = function() self:focus_input() end } + ) + if cb then cb(args) end +end + function Sidebar:reset_memory(args, cb) local chat_history = Path.history.load(self.code.bufnr) if next(chat_history) ~= nil then @@ -2176,6 +2186,7 @@ function Sidebar:reset_memory(args, cb) reset_memory = true, }) Path.history.save(self.code.bufnr, chat_history) + Sidebar.reload_chat_history() local history_content = self:render_history_content(chat_history) self:update_content(history_content, { focus = false, @@ -2184,6 +2195,7 @@ function Sidebar:reset_memory(args, cb) }) if cb then cb(args) end else + Sidebar.reload_chat_history() self:update_content( "Chat history is already empty", { focus = false, scroll = false, callback = function() self:focus_input() end } @@ -2191,7 +2203,7 @@ function Sidebar:reset_memory(args, cb) end end ----@alias AvanteSlashCommandType "clear" | "help" | "lines" | "reset" | "commit" +---@alias AvanteSlashCommandType "clear" | "help" | "lines" | "reset" | "commit" | "new" ---@alias AvanteSlashCommandCallback fun(args: string, cb?: fun(args: string): nil): nil ---@alias AvanteSlashCommand {description: string, command: AvanteSlashCommandType, details: string, shorthelp?: string, callback?: AvanteSlashCommandCallback} ---@return AvanteSlashCommand[] @@ -2211,6 +2223,7 @@ function Sidebar:get_commands() { description = "Show help message", command = "help" }, { description = "Clear chat history", command = "clear" }, { description = "Reset memory", command = "reset" }, + { description = "New chat", command = "new" }, { shorthelp = "Ask a question about specific lines", description = "/lines - ", @@ -2228,6 +2241,7 @@ function Sidebar:get_commands() end, clear = function(args, cb) self:clear_history(args, cb) end, reset = function(args, cb) self:reset_memory(args, cb) end, + new = function(args, cb) self:new_chat(args, cb) end, lines = function(args, cb) if cb then cb(args) end end, @@ -2298,6 +2312,8 @@ function Sidebar:create_input_container(opts) local chat_history = Path.history.load(self.code.bufnr) + Sidebar.reload_chat_history = function() chat_history = Path.history.load(self.code.bufnr) end + local tools = vim.deepcopy(LLMTools.get_tools()) table.insert(tools, { name = "add_file_to_context", @@ -2330,8 +2346,9 @@ function Sidebar:create_input_container(opts) }) ---@param request string - ---@return AvanteGeneratePromptsOptions - local function get_generate_prompts_options(request) + ---@param summarize_memory boolean + ---@param cb fun(opts: AvanteGeneratePromptsOptions): nil + local function get_generate_prompts_options(request, summarize_memory, cb) local filetype = api.nvim_get_option_value("filetype", { buf = self.code.bufnr }) local selected_code_content = nil @@ -2355,40 +2372,19 @@ function Sidebar:create_input_container(opts) end end - local history_messages = {} - for i = #chat_history, 1, -1 do - local entry = chat_history[i] - if entry.reset_memory then break end - if - entry.request == nil - or entry.original_response == nil - or entry.request == "" - or entry.original_response == "" - then - break - end - table.insert( - history_messages, - 1, - { role = "assistant", content = Utils.trim_think_content(entry.original_response) } - ) - local user_content = "" - if entry.selected_file ~= nil then - user_content = user_content .. "SELECTED FILE: " .. entry.selected_file.filepath .. "\n\n" - end - if entry.selected_code ~= nil then - user_content = user_content - .. "SELECTED CODE:\n\n```" - .. entry.selected_code.filetype - .. "\n" - .. entry.selected_code.content - .. "\n```\n\n" - end - user_content = user_content .. "USER PROMPT:\n\n" .. entry.request - table.insert(history_messages, 1, { role = "user", content = user_content }) + local entries = Utils.history.filter_active_entries(chat_history.entries) + + if chat_history.memory then + entries = vim + .iter(entries) + :filter(function(entry) return entry.timestamp > chat_history.memory.last_summarized_timestamp end) + :totable() end - return { + local history_messages = Utils.history.entries_to_llm_messages(entries) + + ---@type AvanteGeneratePromptsOptions + local prompts_opts = { ask = opts.ask or true, project_context = vim.json.encode(project_context), selected_files = selected_files_contents, @@ -2400,6 +2396,20 @@ function Sidebar:create_input_container(opts) mode = Config.behaviour.enable_cursor_planning_mode and "cursor-planning" or "planning", tools = tools, } + + if chat_history.memory then prompts_opts.memory = chat_history.memory.content end + + if not summarize_memory or #history_messages < 8 then + cb(prompts_opts) + return + end + + prompts_opts.history_messages = vim.list_slice(prompts_opts.history_messages, 5) + + Llm.summarize_memory(self.code.bufnr, chat_history, function(memory) + if memory then prompts_opts.memory = memory.content end + cb(prompts_opts) + end) end ---@param request string @@ -2584,7 +2594,8 @@ function Sidebar:create_input_container(opts) end, 0) -- Save chat history - table.insert(chat_history or {}, { + chat_history.entries = chat_history.entries or {} + table.insert(chat_history.entries, { timestamp = timestamp, provider = Config.provider, model = model, @@ -2597,17 +2608,18 @@ function Sidebar:create_input_container(opts) Path.history.save(self.code.bufnr, chat_history) end - local generate_prompts_options = get_generate_prompts_options(request) - ---@type AvanteLLMStreamOptions - ---@diagnostic disable-next-line: assign-type-mismatch - local stream_options = vim.tbl_deep_extend("force", generate_prompts_options, { - on_start = on_start, - on_chunk = on_chunk, - on_stop = on_stop, - on_tool_log = on_tool_log, - }) + get_generate_prompts_options(request, true, function(generate_prompts_options) + ---@type AvanteLLMStreamOptions + ---@diagnostic disable-next-line: assign-type-mismatch + local stream_options = vim.tbl_deep_extend("force", generate_prompts_options, { + on_start = on_start, + on_chunk = on_chunk, + on_stop = on_stop, + on_tool_log = on_tool_log, + }) - Llm.stream(stream_options) + Llm.stream(stream_options) + end) end local function get_position() @@ -2762,37 +2774,43 @@ function Sidebar:create_input_container(opts) local hint_text = (fn.mode() ~= "i" and Config.mappings.submit.normal or Config.mappings.submit.insert) .. ": submit" - if Config.behaviour.enable_token_counting then - local input_value = table.concat(api.nvim_buf_get_lines(self.input_container.bufnr, 0, -1, false), "\n") - local generate_prompts_options = get_generate_prompts_options(input_value) - local tokens = Llm.calculate_tokens(generate_prompts_options) - hint_text = "Tokens: " .. tostring(tokens) .. "; " .. hint_text + local function show() + local buf = api.nvim_create_buf(false, true) + api.nvim_buf_set_lines(buf, 0, -1, false, { hint_text }) + api.nvim_buf_add_highlight(buf, 0, "AvantePopupHint", 0, 0, -1) + + -- Get the current window size + local win_width = api.nvim_win_get_width(self.input_container.winid) + local width = #hint_text + + -- Set the floating window options + local win_opts = { + relative = "win", + win = self.input_container.winid, + width = width, + height = 1, + row = get_float_window_row(), + col = math.max(win_width - width, 0), -- Display in the bottom right corner + style = "minimal", + border = "none", + focusable = false, + zindex = 100, + } + + -- Create the floating window + hint_window = api.nvim_open_win(buf, false, win_opts) end - local buf = api.nvim_create_buf(false, true) - api.nvim_buf_set_lines(buf, 0, -1, false, { hint_text }) - api.nvim_buf_add_highlight(buf, 0, "AvantePopupHint", 0, 0, -1) - - -- Get the current window size - local win_width = api.nvim_win_get_width(self.input_container.winid) - local width = #hint_text - - -- Set the floating window options - local win_opts = { - relative = "win", - win = self.input_container.winid, - width = width, - height = 1, - row = get_float_window_row(), - col = math.max(win_width - width, 0), -- Display in the bottom right corner - style = "minimal", - border = "none", - focusable = false, - zindex = 100, - } - - -- Create the floating window - hint_window = api.nvim_open_win(buf, false, win_opts) + if Config.behaviour.enable_token_counting then + local input_value = table.concat(api.nvim_buf_get_lines(self.input_container.bufnr, 0, -1, false), "\n") + get_generate_prompts_options(input_value, false, function(generate_prompts_options) + local tokens = Llm.calculate_tokens(generate_prompts_options) + hint_text = "Tokens: " .. tostring(tokens) .. "; " .. hint_text + show() + end) + else + show() + end end api.nvim_create_autocmd({ "TextChanged", "TextChangedI", "VimResized" }, { diff --git a/lua/avante/templates/_memory.avanterules b/lua/avante/templates/_memory.avanterules new file mode 100644 index 0000000..6a7a838 --- /dev/null +++ b/lua/avante/templates/_memory.avanterules @@ -0,0 +1,5 @@ +{%- if memory -%} + +{{memory}} + +{%- endif %} diff --git a/lua/avante/types.lua b/lua/avante/types.lua index 78591ba..e5b4bf3 100644 --- a/lua/avante/types.lua +++ b/lua/avante/types.lua @@ -190,13 +190,9 @@ vim.g.avante_login = vim.g.avante_login ---@alias AvanteMessagesParser fun(self: AvanteProviderFunctor, opts: AvantePromptOptions): AvanteChatMessage[] --- ---@class AvanteCurlOutput: {url: string, proxy: string, insecure: boolean, body: table | string, headers: table, rawArgs: string[] | nil} ----@alias AvanteCurlArgsParser fun(self: AvanteProviderFunctor, provider: AvanteProvider | AvanteProviderFunctor | AvanteBedrockProviderFunctor, prompt_opts: AvantePromptOptions): AvanteCurlOutput +---@alias AvanteCurlArgsParser fun(self: AvanteProviderFunctor, prompt_opts: AvantePromptOptions): AvanteCurlOutput --- ----@class AvanteResponseParserOptions ----@field on_start AvanteLLMStartCallback ----@field on_chunk AvanteLLMChunkCallback ----@field on_stop AvanteLLMStopCallback ----@alias AvanteResponseParser fun(self: AvanteProviderFunctor, ctx: any, data_stream: string, event_state: string, opts: AvanteResponseParserOptions): nil +---@alias AvanteResponseParser fun(self: AvanteProviderFunctor, ctx: any, data_stream: string, event_state: string, opts: AvanteHandlerOptions): nil --- ---@class AvanteDefaultBaseProvider: table ---@field endpoint? string @@ -305,6 +301,7 @@ vim.g.avante_login = vim.g.avante_login ---@field selected_files AvanteSelectedFiles[] | nil ---@field diagnostics string | nil ---@field history_messages AvanteLLMMessage[] | nil +---@field memory string | nil --- ---@class AvanteGeneratePromptsOptions: AvanteTemplateOptions ---@field ask boolean @@ -358,3 +355,35 @@ vim.g.avante_login = vim.g.avante_login ---@field description string ---@field type 'string' | 'string[]' | 'boolean' ---@field optional? boolean +--- +---@class avante.ChatHistoryEntry +---@field timestamp string +---@field provider string +---@field model string +---@field request string +---@field response string +---@field original_response string +---@field selected_file {filepath: string}? +---@field selected_code {filetype: string, content: string}? +---@field reset_memory boolean? +---@field selected_filepaths string[] | nil +--- +---@class avante.ChatHistory +---@field title string +---@field timestamp string +---@field entries avante.ChatHistoryEntry[] +---@field memory avante.ChatMemory | nil +--- +---@class avante.ChatMemory +---@field content string +---@field last_summarized_timestamp string + +---@class avante.Path +---@field history_path Path +---@field cache_path Path +--- +---@class avante.CurlOpts +---@field provider AvanteProviderFunctor +---@field prompt_opts AvantePromptOptions +---@field handler_opts AvanteHandlerOptions +--- diff --git a/lua/avante/utils/history.lua b/lua/avante/utils/history.lua new file mode 100644 index 0000000..89085fa --- /dev/null +++ b/lua/avante/utils/history.lua @@ -0,0 +1,52 @@ +local Utils = require("avante.utils") + +---@class avante.utils.history +local M = {} + +---@param entries avante.ChatHistoryEntry[] +---@return avante.ChatHistoryEntry[] +function M.filter_active_entries(entries) + local entries_ = {} + + for i = #entries, 1, -1 do + local entry = entries[i] + if entry.reset_memory then break end + if + entry.request == nil + or entry.original_response == nil + or entry.request == "" + or entry.original_response == "" + then + break + end + table.insert(entries_, 1, entry) + end + + return entries_ +end + +---@param entries avante.ChatHistoryEntry[] +---@return AvanteLLMMessage[] +function M.entries_to_llm_messages(entries) + local messages = {} + for _, entry in ipairs(entries) do + local user_content = "" + if entry.selected_file ~= nil then + user_content = user_content .. "SELECTED FILE: " .. entry.selected_file.filepath .. "\n\n" + end + if entry.selected_code ~= nil then + user_content = user_content + .. "SELECTED CODE:\n\n```" + .. entry.selected_code.filetype + .. "\n" + .. entry.selected_code.content + .. "\n```\n\n" + end + user_content = user_content .. "USER PROMPT:\n\n" .. entry.request + table.insert(messages, { role = "user", content = user_content }) + table.insert(messages, { role = "assistant", content = Utils.trim_think_content(entry.original_response) }) + end + return messages +end + +return M diff --git a/lua/avante/utils/init.lua b/lua/avante/utils/init.lua index 0b36502..93b34b4 100644 --- a/lua/avante/utils/init.lua +++ b/lua/avante/utils/init.lua @@ -8,6 +8,7 @@ local lsp = vim.lsp ---@field tokens avante.utils.tokens ---@field root avante.utils.root ---@field file avante.utils.file +---@field history avante.utils.history local M = {} setmetatable(M, {