refactor: summarize memory (#1508)

This commit is contained in:
yetone
2025-03-07 00:12:57 +08:00
committed by GitHub
parent 2b0e7e09ae
commit 8620ea3e12
18 changed files with 434 additions and 217 deletions

View File

@@ -33,6 +33,7 @@ struct TemplateContext {
diagnostics: Option<String>, diagnostics: Option<String>,
system_info: Option<String>, system_info: Option<String>,
model_name: Option<String>, model_name: Option<String>,
memory: Option<String>,
} }
// Given the file name registered after add, the context table in Lua, resulted in a formatted // Given the file name registered after add, the context table in Lua, resulted in a formatted
@@ -58,6 +59,7 @@ fn render(state: &State, template: &str, context: TemplateContext) -> LuaResult<
diagnostics => context.diagnostics, diagnostics => context.diagnostics,
system_info => context.system_info, system_info => context.system_info,
model_name => context.model_name, model_name => context.model_name,
memory => context.memory,
}) })
.map_err(LuaError::external) .map_err(LuaError::external)
.unwrap()) .unwrap())

View File

@@ -27,6 +27,7 @@ M._defaults = {
-- Of course, you can reduce the request frequency by increasing `suggestion.debounce`. -- Of course, you can reduce the request frequency by increasing `suggestion.debounce`.
auto_suggestions_provider = "claude", auto_suggestions_provider = "claude",
cursor_applying_provider = nil, cursor_applying_provider = nil,
memory_summary_provider = nil,
---@alias Tokenizer "tiktoken" | "hf" ---@alias Tokenizer "tiktoken" | "hf"
-- Used for counting tokens and encoding text. -- Used for counting tokens and encoding text.
-- By default, we will use tiktoken. -- By default, we will use tiktoken.
@@ -273,6 +274,10 @@ M._defaults = {
temperature = 0, temperature = 0,
max_tokens = 8000, max_tokens = 8000,
}, },
["openai-gpt-4o-mini"] = {
__inherited_from = "openai",
model = "gpt-4o-mini",
},
}, },
---Specify the special dual_boost mode ---Specify the special dual_boost mode
---1. enabled: Whether to enable dual_boost mode. Default to false. ---1. enabled: Whether to enable dual_boost mode. Default to false.

View File

@@ -19,6 +19,71 @@ M.CANCEL_PATTERN = "AvanteLLMEscape"
local group = api.nvim_create_augroup("avante_llm", { clear = true }) local group = api.nvim_create_augroup("avante_llm", { clear = true })
--- Produces (or refreshes) a condensed "memory" of the chat so far and persists
--- it into the history file, so later prompts can carry context without
--- replaying every entry. Any previous summary is folded into the system
--- prompt, and only entries newer than it are re-summarized.
---@param bufnr integer
---@param history avante.ChatHistory
---@param cb fun(memory: avante.ChatMemory | nil): nil called with the new memory, the unchanged old memory, or nil when there is nothing to summarize; NOT called on request error
function M.summarize_memory(bufnr, history, cb)
local system_prompt =
[[Summarize the following conversation to extract the most critical information (such as languages used, conversation style, tech stack, considerations, user information, etc.) for memory in subsequent conversations. Since it is for memory purposes, be detailed and rigorous to ensure that no information from previous summaries is lost in the newly generated summary.]]
local entries = Utils.history.filter_active_entries(history.entries)
-- No active entries at all: nothing to summarize.
if #entries == 0 then
cb(nil)
return
end
if history.memory then
-- Chain the prior summary into the prompt and restrict the input to
-- entries created after the last summarization pass.
system_prompt = system_prompt .. "\n\nThe previous summary is:\n\n" .. history.memory.content
entries = vim
.iter(entries)
:filter(function(entry) return entry.timestamp > history.memory.last_summarized_timestamp end)
:totable()
end
-- Everything is already covered by the existing summary; reuse it as-is.
if #entries == 0 then
cb(history.memory)
return
end
local history_messages = Utils.history.entries_to_llm_messages(entries)
-- Only the first 4 LLM messages are sent for summarization.
-- NOTE(review): last_summarized_timestamp below is taken from the LAST entry,
-- which may not have contributed to this 4-message slice — confirm entries
-- beyond the slice are not silently excluded from all future summaries.
history_messages = vim.list_slice(history_messages, 1, 4)
if #history_messages == 0 then
cb(history.memory)
return
end
Utils.debug("summarize memory", #history_messages, history_messages[#history_messages].content)
local response_content = ""
-- A dedicated summarization provider may be configured; otherwise fall back
-- to the main configured provider.
local provider = Providers[Config.memory_summary_provider or Config.provider]
M.curl({
provider = provider,
prompt_opts = {
system_prompt = system_prompt,
messages = {
-- The selected history slice is passed as one JSON-encoded user message.
{ role = "user", content = vim.json.encode(history_messages) },
},
},
handler_opts = {
on_start = function(_) end,
on_chunk = function(chunk)
-- Accumulate the streamed summary text.
if not chunk then return end
response_content = response_content .. chunk
end,
on_stop = function(stop_opts)
if stop_opts.error ~= nil then
-- On failure only a log is emitted; cb is intentionally not invoked.
Utils.error(string.format("summarize failed: %s", vim.inspect(stop_opts.error)))
return
end
if stop_opts.reason == "complete" then
-- presumably strips "<think>"-style reasoning markup (helper named
-- trim_think_content) — TODO confirm against Utils.
response_content = Utils.trim_think_content(response_content)
local memory = {
content = response_content,
last_summarized_timestamp = entries[#entries].timestamp,
}
-- Persist the refreshed memory alongside the history before reporting it.
history.memory = memory
Path.history.save(bufnr, history)
cb(memory)
end
end,
},
})
end
---@param opts AvanteGeneratePromptsOptions ---@param opts AvanteGeneratePromptsOptions
---@return AvantePromptOptions ---@return AvantePromptOptions
function M.generate_prompts(opts) function M.generate_prompts(opts)
@@ -58,6 +123,7 @@ function M.generate_prompts(opts)
diagnostics = opts.diagnostics, diagnostics = opts.diagnostics,
system_info = system_info, system_info = system_info,
model_name = provider.model or "unknown", model_name = provider.model or "unknown",
memory = opts.memory,
} }
local system_prompt = Path.prompts.render_mode(mode, template_opts) local system_prompt = Path.prompts.render_mode(mode, template_opts)
@@ -80,6 +146,11 @@ function M.generate_prompts(opts)
if code_context ~= "" then table.insert(messages, { role = "user", content = code_context }) end if code_context ~= "" then table.insert(messages, { role = "user", content = code_context }) end
end end
if opts.memory ~= nil and opts.memory ~= "" and opts.memory ~= "null" then
local memory = Path.prompts.render_file("_memory.avanterules", template_opts)
if memory ~= "" then table.insert(messages, { role = "user", content = memory }) end
end
if instructions then if instructions then
if opts.use_xml_format then if opts.use_xml_format then
table.insert(messages, { role = "user", content = string.format("<question>%s</question>", instructions) }) table.insert(messages, { role = "user", content = string.format("<question>%s</question>", instructions) })
@@ -150,90 +221,17 @@ function M.calculate_tokens(opts)
return tokens return tokens
end end
---@param opts AvanteLLMStreamOptions ---@param opts avante.CurlOpts
function M._stream(opts) function M.curl(opts)
local provider = opts.provider or Providers[Config.provider] local provider = opts.provider
local prompt_opts = opts.prompt_opts
local handler_opts = opts.handler_opts
---@cast provider AvanteProviderFunctor ---@type AvanteCurlOutput
local spec = provider:parse_curl_args(prompt_opts)
local prompt_opts = M.generate_prompts(opts)
---@type string ---@type string
local current_event_state = nil local current_event_state = nil
---@type AvanteHandlerOptions
local handler_opts = {
on_start = opts.on_start,
on_chunk = opts.on_chunk,
on_stop = function(stop_opts)
---@param tool_use_list AvanteLLMToolUse[]
---@param tool_use_index integer
---@param tool_histories AvanteLLMToolHistory[]
local function handle_next_tool_use(tool_use_list, tool_use_index, tool_histories)
if tool_use_index > #tool_use_list then
local new_opts = vim.tbl_deep_extend("force", opts, {
tool_histories = tool_histories,
})
return M._stream(new_opts)
end
local tool_use = tool_use_list[tool_use_index]
---@param result string | nil
---@param error string | nil
local function handle_tool_result(result, error)
local tool_result = {
tool_use_id = tool_use.id,
content = error ~= nil and error or result,
is_error = error ~= nil,
}
table.insert(tool_histories, { tool_result = tool_result, tool_use = tool_use })
return handle_next_tool_use(tool_use_list, tool_use_index + 1, tool_histories)
end
-- Either on_complete handles the tool result asynchronously or we receive the result and error synchronously when either is not nil
local result, error = LLMTools.process_tool_use(opts.tools, tool_use, opts.on_tool_log, handle_tool_result)
if result ~= nil or error ~= nil then return handle_tool_result(result, error) end
end
if stop_opts.reason == "tool_use" and stop_opts.tool_use_list then
local old_tool_histories = vim.deepcopy(opts.tool_histories) or {}
local sorted_tool_use_list = {} ---@type AvanteLLMToolUse[]
for _, tool_use in vim.spairs(stop_opts.tool_use_list) do
table.insert(sorted_tool_use_list, tool_use)
end
return handle_next_tool_use(sorted_tool_use_list, 1, old_tool_histories)
end
if stop_opts.reason == "rate_limit" then
local msg = "Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds ..."
opts.on_chunk("\n*[" .. msg .. "]*\n")
local timer = vim.loop.new_timer()
if timer then
local retry_after = stop_opts.retry_after
local function countdown()
timer:start(
1000,
0,
vim.schedule_wrap(function()
if retry_after > 0 then retry_after = retry_after - 1 end
local msg_ = "Rate limit reached. Retrying in " .. retry_after .. " seconds ..."
opts.on_chunk([[\033[1A\033[K]] .. "\n*[" .. msg_ .. "]*\n")
countdown()
end)
)
end
countdown()
end
Utils.info("Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds", { title = "Avante" })
vim.defer_fn(function()
if timer then timer:stop() end
M._stream(opts)
end, stop_opts.retry_after * 1000)
return
end
return opts.on_stop(stop_opts)
end,
}
---@type AvanteCurlOutput
local spec = provider:parse_curl_args(provider, prompt_opts)
local resp_ctx = {} local resp_ctx = {}
---@param line string ---@param line string
@@ -383,6 +381,91 @@ function M._stream(opts)
return active_job return active_job
end end
--- Streams a chat completion for `opts`: resolves the provider, generates the
--- prompts, and delegates the HTTP streaming to M.curl. The on_stop handler
--- implements two re-entry paths: tool-use requests are executed sequentially
--- and the stream restarts with the accumulated tool histories; rate limiting
--- schedules an automatic retry with a live countdown.
---@param opts AvanteLLMStreamOptions
function M._stream(opts)
  -- An explicit provider in opts wins; otherwise use the configured one.
  local provider = opts.provider or Providers[Config.provider]
  ---@cast provider AvanteProviderFunctor
  local prompt_opts = M.generate_prompts(opts)
  ---@type AvanteHandlerOptions
  local handler_opts = {
    on_start = opts.on_start,
    on_chunk = opts.on_chunk,
    on_stop = function(stop_opts)
      -- Processes tool uses one at a time; once all are handled, re-enters
      -- M._stream with the accumulated tool histories so the model can
      -- continue with the tool results.
      ---@param tool_use_list AvanteLLMToolUse[]
      ---@param tool_use_index integer
      ---@param tool_histories AvanteLLMToolHistory[]
      local function handle_next_tool_use(tool_use_list, tool_use_index, tool_histories)
        if tool_use_index > #tool_use_list then
          local new_opts = vim.tbl_deep_extend("force", opts, {
            tool_histories = tool_histories,
          })
          return M._stream(new_opts)
        end
        local tool_use = tool_use_list[tool_use_index]
        ---@param result string | nil
        ---@param error string | nil
        local function handle_tool_result(result, error)
          local tool_result = {
            tool_use_id = tool_use.id,
            content = error ~= nil and error or result,
            is_error = error ~= nil,
          }
          table.insert(tool_histories, { tool_result = tool_result, tool_use = tool_use })
          return handle_next_tool_use(tool_use_list, tool_use_index + 1, tool_histories)
        end
        -- Either on_complete handles the tool result asynchronously or we receive the result and error synchronously when either is not nil
        local result, error = LLMTools.process_tool_use(opts.tools, tool_use, opts.on_tool_log, handle_tool_result)
        if result ~= nil or error ~= nil then return handle_tool_result(result, error) end
      end
      if stop_opts.reason == "tool_use" and stop_opts.tool_use_list then
        local old_tool_histories = vim.deepcopy(opts.tool_histories) or {}
        -- vim.spairs iterates keys in sorted order, giving a stable execution
        -- order for the requested tool uses.
        local sorted_tool_use_list = {} ---@type AvanteLLMToolUse[]
        for _, tool_use in vim.spairs(stop_opts.tool_use_list) do
          table.insert(sorted_tool_use_list, tool_use)
        end
        return handle_next_tool_use(sorted_tool_use_list, 1, old_tool_histories)
      end
      if stop_opts.reason == "rate_limit" then
        local msg = "Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds ..."
        opts.on_chunk("\n*[" .. msg .. "]*\n")
        local timer = vim.loop.new_timer()
        if timer then
          local retry_after = stop_opts.retry_after
          -- Re-arms the one-shot timer every second to redraw the countdown.
          local function countdown()
            timer:start(
              1000,
              0,
              vim.schedule_wrap(function()
                if retry_after > 0 then retry_after = retry_after - 1 end
                local msg_ = "Rate limit reached. Retrying in " .. retry_after .. " seconds ..."
                -- BUGFIX: this ANSI cursor-up/erase-line sequence was written
                -- as a long-bracket string ([[\033[1A\033[K]]); Lua long
                -- brackets do no escape processing (and Lua escapes are
                -- decimal anyway), so the literal text "\033[1A..." was
                -- emitted. \027 (decimal 27) is ESC.
                opts.on_chunk("\027[1A\027[K" .. "\n*[" .. msg_ .. "]*\n")
                countdown()
              end)
            )
          end
          countdown()
        end
        Utils.info("Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds", { title = "Avante" })
        vim.defer_fn(function()
          -- BUGFIX: stop AND close the libuv timer; a stopped-but-unclosed
          -- uv handle is never released back to the loop.
          if timer then
            timer:stop()
            timer:close()
          end
          M._stream(opts)
        end, stop_opts.retry_after * 1000)
        return
      end
      return opts.on_stop(stop_opts)
    end,
  }
  return M.curl({
    provider = provider,
    prompt_opts = prompt_opts,
    handler_opts = handler_opts,
  })
end
local function _merge_response(first_response, second_response, opts) local function _merge_response(first_response, second_response, opts)
local prompt = "\n" .. Config.dual_boost.prompt local prompt = "\n" .. Config.dual_boost.prompt
prompt = prompt prompt = prompt

View File

@@ -5,58 +5,75 @@ local Path = require("plenary.path")
local Scan = require("plenary.scandir") local Scan = require("plenary.scandir")
local Config = require("avante.config") local Config = require("avante.config")
---@class avante.ChatHistoryEntry
---@field timestamp string
---@field provider string
---@field model string
---@field request string
---@field response string
---@field original_response string
---@field selected_file {filepath: string}?
---@field selected_code {filetype: string, content: string}?
---@field reset_memory boolean?
---@field selected_filepaths string[] | nil
---@class avante.Path ---@class avante.Path
---@field history_path Path
---@field cache_path Path
local P = {} local P = {}
local history_file_cache = LRUCache:new(12) local history_file_cache = LRUCache:new(12)
-- History path ---@param bufnr integer | nil
local History = {} ---@return string dirname
local function generate_project_dirname_in_storage(bufnr)
-- Get a chat history file name given a buffer
---@param bufnr integer
---@return string
function History.filename(bufnr)
local project_root = Utils.root.get({ local project_root = Utils.root.get({
buf = bufnr, buf = bufnr,
}) })
-- Replace path separators with double underscores -- Replace path separators with double underscores
local path_with_separators = fn.substitute(project_root, "/", "__", "g") local path_with_separators = fn.substitute(project_root, "/", "__", "g")
-- Replace other non-alphanumeric characters with single underscores -- Replace other non-alphanumeric characters with single underscores
return fn.substitute(path_with_separators, "[^A-Za-z0-9._]", "_", "g") .. ".json" local dirname = fn.substitute(path_with_separators, "[^A-Za-z0-9._]", "_", "g")
return tostring(Path:new("projects"):joinpath(dirname))
end
-- History path
local History = {}
-- Get a chat history file name given a buffer
---@param bufnr integer
---@param new boolean
---@return Path
function History.filepath(bufnr, new)
local dirname = generate_project_dirname_in_storage(bufnr)
local history_dir = Path:new(Config.history.storage_path):joinpath(dirname):joinpath("history")
if not history_dir:exists() then history_dir:mkdir({ parents = true }) end
local pattern = tostring(history_dir:joinpath("*.json"))
local files = vim.fn.glob(pattern, true, true)
local filename = #files .. ".json"
if #files > 0 and not new then filename = (#files - 1) .. ".json" end
return history_dir:joinpath(filename)
end end
-- Returns the Path to the chat history file for the given buffer. -- Returns the Path to the chat history file for the given buffer.
---@param bufnr integer ---@param bufnr integer
---@return Path ---@return Path
function History.get(bufnr) return Path:new(Config.history.storage_path):joinpath(History.filename(bufnr)) end function History.get(bufnr) return History.filepath(bufnr, false) end
--- Creates and persists a fresh, empty chat history file for the given buffer.
---@param bufnr integer
function History.new(bufnr)
  -- Skeleton history record: no entries yet, stamped with the creation time.
  local blank_history = {
    title = "untitled",
    timestamp = tostring(os.date("%Y-%m-%d %H:%M:%S")),
    entries = {},
  }
  local target = History.filepath(bufnr, true)
  target:write(vim.json.encode(blank_history), "w")
end
-- Loads the chat history for the given buffer. -- Loads the chat history for the given buffer.
---@param bufnr integer ---@param bufnr integer
---@return avante.ChatHistoryEntry[] ---@return avante.ChatHistory
function History.load(bufnr) function History.load(bufnr)
local history_file = History.get(bufnr) local history_file = History.get(bufnr)
local cached_key = tostring(history_file:absolute()) local cached_key = tostring(history_file:absolute())
local cached_value = history_file_cache:get(cached_key) local cached_value = history_file_cache:get(cached_key)
if cached_value ~= nil then return cached_value end if cached_value ~= nil then return cached_value end
local value = {} ---@type avante.ChatHistory
local value = {
title = "untitled",
timestamp = tostring(os.date("%Y-%m-%d %H:%M:%S")),
entries = {},
}
if history_file:exists() then if history_file:exists() then
local content = history_file:read() local content = history_file:read()
value = content ~= nil and vim.json.decode(content) or {} value = content ~= nil and vim.json.decode(content) or value
end end
history_file_cache:set(cached_key, value) history_file_cache:set(cached_key, value)
return value return value
@@ -64,7 +81,7 @@ end
-- Saves the chat history for the given buffer. -- Saves the chat history for the given buffer.
---@param bufnr integer ---@param bufnr integer
---@param history avante.ChatHistoryEntry[] ---@param history avante.ChatHistory
History.save = vim.schedule_wrap(function(bufnr, history) History.save = vim.schedule_wrap(function(bufnr, history)
local history_file = History.get(bufnr) local history_file = History.get(bufnr)
local cached_key = tostring(history_file:absolute()) local cached_key = tostring(history_file:absolute())

View File

@@ -18,8 +18,8 @@ M.parse_response = O.parse_response
M.parse_response_without_stream = O.parse_response_without_stream M.parse_response_without_stream = O.parse_response_without_stream
M.is_disable_stream = O.is_disable_stream M.is_disable_stream = O.is_disable_stream
function M:parse_curl_args(provider, prompt_opts) function M:parse_curl_args(prompt_opts)
local provider_conf, request_body = P.parse_config(provider) local provider_conf, request_body = P.parse_config(self)
local headers = { local headers = {
["Content-Type"] = "application/json", ["Content-Type"] = "application/json",
@@ -27,9 +27,9 @@ function M:parse_curl_args(provider, prompt_opts)
if P.env.require_api_key(provider_conf) then if P.env.require_api_key(provider_conf) then
if provider_conf.entra then if provider_conf.entra then
headers["Authorization"] = "Bearer " .. provider.parse_api_key() headers["Authorization"] = "Bearer " .. self.parse_api_key()
else else
headers["api-key"] = provider.parse_api_key() headers["api-key"] = self.parse_api_key()
end end
end end

View File

@@ -51,13 +51,12 @@ function M:parse_response_without_stream(data, event_state, opts)
vim.schedule(function() opts.on_stop({ reason = "complete" }) end) vim.schedule(function() opts.on_stop({ reason = "complete" }) end)
end end
---@param provider AvanteBedrockProviderFunctor
---@param prompt_opts AvantePromptOptions ---@param prompt_opts AvantePromptOptions
---@return table ---@return table
function M:parse_curl_args(provider, prompt_opts) function M:parse_curl_args(prompt_opts)
local base, body_opts = P.parse_config(provider) local base, body_opts = P.parse_config(self)
local api_key = provider.parse_api_key() local api_key = self.parse_api_key()
if api_key == nil then error("Cannot get the bedrock api key!") end if api_key == nil then error("Cannot get the bedrock api key!") end
local parts = vim.split(api_key, ",") local parts = vim.split(api_key, ",")
local aws_access_key_id = parts[1] local aws_access_key_id = parts[1]

View File

@@ -262,11 +262,10 @@ function M:parse_response(ctx, data_stream, event_state, opts)
end end
end end
---@param provider AvanteProviderFunctor
---@param prompt_opts AvantePromptOptions ---@param prompt_opts AvantePromptOptions
---@return table ---@return table
function M:parse_curl_args(provider, prompt_opts) function M:parse_curl_args(prompt_opts)
local provider_conf, request_body = P.parse_config(provider) local provider_conf, request_body = P.parse_config(self)
local disable_tools = provider_conf.disable_tools or false local disable_tools = provider_conf.disable_tools or false
local headers = { local headers = {
@@ -275,7 +274,7 @@ function M:parse_curl_args(provider, prompt_opts)
["anthropic-beta"] = "prompt-caching-2024-07-31", ["anthropic-beta"] = "prompt-caching-2024-07-31",
} }
if P.env.require_api_key(provider_conf) then headers["x-api-key"] = provider.parse_api_key() end if P.env.require_api_key(provider_conf) then headers["x-api-key"] = self.parse_api_key() end
local messages = self:parse_messages(prompt_opts) local messages = self:parse_messages(prompt_opts)

View File

@@ -71,8 +71,8 @@ function M:parse_stream_data(ctx, data, opts)
end end
end end
function M:parse_curl_args(provider, prompt_opts) function M:parse_curl_args(prompt_opts)
local provider_conf, request_body = P.parse_config(provider) local provider_conf, request_body = P.parse_config(self)
local headers = { local headers = {
["Accept"] = "application/json", ["Accept"] = "application/json",
@@ -84,7 +84,7 @@ function M:parse_curl_args(provider, prompt_opts)
.. "." .. "."
.. vim.version().patch, .. vim.version().patch,
} }
if P.env.require_api_key(provider_conf) then headers["Authorization"] = "Bearer " .. provider.parse_api_key() end if P.env.require_api_key(provider_conf) then headers["Authorization"] = "Bearer " .. self.parse_api_key() end
return { return {
url = Utils.url_join(provider_conf.endpoint, "/chat"), url = Utils.url_join(provider_conf.endpoint, "/chat"),

View File

@@ -215,12 +215,12 @@ M.parse_messages = OpenAI.parse_messages
M.parse_response = OpenAI.parse_response M.parse_response = OpenAI.parse_response
function M:parse_curl_args(provider, prompt_opts) function M:parse_curl_args(prompt_opts)
-- refresh token synchronously, only if it has expired -- refresh token synchronously, only if it has expired
-- (this should rarely happen, as we refresh the token in the background) -- (this should rarely happen, as we refresh the token in the background)
H.refresh_token(false, false) H.refresh_token(false, false)
local provider_conf, request_body = P.parse_config(provider) local provider_conf, request_body = P.parse_config(self)
local disable_tools = provider_conf.disable_tools or false local disable_tools = provider_conf.disable_tools or false
local tools = {} local tools = {}

View File

@@ -83,8 +83,8 @@ function M:parse_response(ctx, data_stream, _, opts)
end end
end end
function M:parse_curl_args(provider, prompt_opts) function M:parse_curl_args(prompt_opts)
local provider_conf, request_body = P.parse_config(provider) local provider_conf, request_body = P.parse_config(self)
request_body = vim.tbl_deep_extend("force", request_body, { request_body = vim.tbl_deep_extend("force", request_body, {
generationConfig = { generationConfig = {
@@ -95,7 +95,7 @@ function M:parse_curl_args(provider, prompt_opts)
request_body.temperature = nil request_body.temperature = nil
request_body.max_tokens = nil request_body.max_tokens = nil
local api_key = provider.parse_api_key() local api_key = self.parse_api_key()
if api_key == nil then error("Cannot get the gemini api key!") end if api_key == nil then error("Cannot get the gemini api key!") end
return { return {

View File

@@ -261,6 +261,13 @@ function M.setup()
E.setup({ provider = cursor_applying_provider }) E.setup({ provider = cursor_applying_provider })
end end
end end
if Config.memory_summary_provider then
local memory_summary_provider = M[Config.memory_summary_provider]
if memory_summary_provider and memory_summary_provider ~= provider then
E.setup({ provider = memory_summary_provider })
end
end
end end
---@param provider Provider ---@param provider Provider

View File

@@ -219,8 +219,8 @@ function M:parse_response_without_stream(data, _, opts)
end end
end end
function M:parse_curl_args(provider, prompt_opts) function M:parse_curl_args(prompt_opts)
local provider_conf, request_body = P.parse_config(provider) local provider_conf, request_body = P.parse_config(self)
local disable_tools = provider_conf.disable_tools or false local disable_tools = provider_conf.disable_tools or false
local headers = { local headers = {
@@ -228,7 +228,7 @@ function M:parse_curl_args(provider, prompt_opts)
} }
if P.env.require_api_key(provider_conf) then if P.env.require_api_key(provider_conf) then
local api_key = provider.parse_api_key() local api_key = self.parse_api_key()
if api_key == nil then if api_key == nil then
error(Config.provider .. " API key is not set, please set it in your environment variable or config file") error(Config.provider .. " API key is not set, please set it in your environment variable or config file")
end end

View File

@@ -32,8 +32,8 @@ function M.parse_api_key()
return direct_output return direct_output
end end
function M:parse_curl_args(provider, prompt_opts) function M:parse_curl_args(prompt_opts)
local provider_conf, request_body = P.parse_config(provider) local provider_conf, request_body = P.parse_config(self)
local location = vim.fn.getenv("LOCATION") local location = vim.fn.getenv("LOCATION")
local project_id = vim.fn.getenv("PROJECT_ID") local project_id = vim.fn.getenv("PROJECT_ID")
local model_id = provider_conf.model or "default-model-id" local model_id = provider_conf.model or "default-model-id"

View File

@@ -2073,11 +2073,11 @@ function Sidebar:get_layout()
return vim.tbl_contains({ "left", "right" }, calculate_config_window_position()) and "vertical" or "horizontal" return vim.tbl_contains({ "left", "right" }, calculate_config_window_position()) and "vertical" or "horizontal"
end end
---@param history avante.ChatHistoryEntry[] ---@param history avante.ChatHistory
---@return string ---@return string
function Sidebar:render_history_content(history) function Sidebar:render_history_content(history)
local content = "" local content = ""
for idx, entry in ipairs(history) do for idx, entry in ipairs(history.entries) do
if entry.reset_memory then if entry.reset_memory then
content = content .. "***MEMORY RESET***\n\n" content = content .. "***MEMORY RESET***\n\n"
if idx < #history then content = content .. "-------\n\n" end if idx < #history then content = content .. "-------\n\n" end
@@ -2146,11 +2146,11 @@ end
function Sidebar:clear_history(args, cb) function Sidebar:clear_history(args, cb)
local chat_history = Path.history.load(self.code.bufnr) local chat_history = Path.history.load(self.code.bufnr)
if next(chat_history) ~= nil then if next(chat_history) ~= nil then
chat_history = {} chat_history.entries = {}
Path.history.save(self.code.bufnr, chat_history) Path.history.save(self.code.bufnr, chat_history)
self:update_content( self:update_content(
"Chat history cleared", "Chat history cleared",
{ focus = false, scroll = false, callback = function() self:focus_input() end } { ignore_history = true, focus = false, scroll = false, callback = function() self:focus_input() end }
) )
if cb then cb(args) end if cb then cb(args) end
else else
@@ -2161,6 +2161,16 @@ function Sidebar:clear_history(args, cb)
end end
end end
--- Starts a brand-new chat: creates a fresh history file for the code buffer,
--- reloads the cached chat history, refreshes the sidebar content, and finally
--- invokes the optional continuation callback.
function Sidebar:new_chat(args, cb)
  Path.history.new(self.code.bufnr)
  Sidebar.reload_chat_history()
  local update_opts = {
    ignore_history = true,
    focus = false,
    scroll = false,
    callback = function() self:focus_input() end,
  }
  self:update_content("New chat", update_opts)
  if cb ~= nil then cb(args) end
end
function Sidebar:reset_memory(args, cb) function Sidebar:reset_memory(args, cb)
local chat_history = Path.history.load(self.code.bufnr) local chat_history = Path.history.load(self.code.bufnr)
if next(chat_history) ~= nil then if next(chat_history) ~= nil then
@@ -2176,6 +2186,7 @@ function Sidebar:reset_memory(args, cb)
reset_memory = true, reset_memory = true,
}) })
Path.history.save(self.code.bufnr, chat_history) Path.history.save(self.code.bufnr, chat_history)
Sidebar.reload_chat_history()
local history_content = self:render_history_content(chat_history) local history_content = self:render_history_content(chat_history)
self:update_content(history_content, { self:update_content(history_content, {
focus = false, focus = false,
@@ -2184,6 +2195,7 @@ function Sidebar:reset_memory(args, cb)
}) })
if cb then cb(args) end if cb then cb(args) end
else else
Sidebar.reload_chat_history()
self:update_content( self:update_content(
"Chat history is already empty", "Chat history is already empty",
{ focus = false, scroll = false, callback = function() self:focus_input() end } { focus = false, scroll = false, callback = function() self:focus_input() end }
@@ -2191,7 +2203,7 @@ function Sidebar:reset_memory(args, cb)
end end
end end
---@alias AvanteSlashCommandType "clear" | "help" | "lines" | "reset" | "commit" ---@alias AvanteSlashCommandType "clear" | "help" | "lines" | "reset" | "commit" | "new"
---@alias AvanteSlashCommandCallback fun(args: string, cb?: fun(args: string): nil): nil ---@alias AvanteSlashCommandCallback fun(args: string, cb?: fun(args: string): nil): nil
---@alias AvanteSlashCommand {description: string, command: AvanteSlashCommandType, details: string, shorthelp?: string, callback?: AvanteSlashCommandCallback} ---@alias AvanteSlashCommand {description: string, command: AvanteSlashCommandType, details: string, shorthelp?: string, callback?: AvanteSlashCommandCallback}
---@return AvanteSlashCommand[] ---@return AvanteSlashCommand[]
@@ -2211,6 +2223,7 @@ function Sidebar:get_commands()
{ description = "Show help message", command = "help" }, { description = "Show help message", command = "help" },
{ description = "Clear chat history", command = "clear" }, { description = "Clear chat history", command = "clear" },
{ description = "Reset memory", command = "reset" }, { description = "Reset memory", command = "reset" },
{ description = "New chat", command = "new" },
{ {
shorthelp = "Ask a question about specific lines", shorthelp = "Ask a question about specific lines",
description = "/lines <start>-<end> <question>", description = "/lines <start>-<end> <question>",
@@ -2228,6 +2241,7 @@ function Sidebar:get_commands()
end, end,
clear = function(args, cb) self:clear_history(args, cb) end, clear = function(args, cb) self:clear_history(args, cb) end,
reset = function(args, cb) self:reset_memory(args, cb) end, reset = function(args, cb) self:reset_memory(args, cb) end,
new = function(args, cb) self:new_chat(args, cb) end,
lines = function(args, cb) lines = function(args, cb)
if cb then cb(args) end if cb then cb(args) end
end, end,
@@ -2298,6 +2312,8 @@ function Sidebar:create_input_container(opts)
local chat_history = Path.history.load(self.code.bufnr) local chat_history = Path.history.load(self.code.bufnr)
Sidebar.reload_chat_history = function() chat_history = Path.history.load(self.code.bufnr) end
local tools = vim.deepcopy(LLMTools.get_tools()) local tools = vim.deepcopy(LLMTools.get_tools())
table.insert(tools, { table.insert(tools, {
name = "add_file_to_context", name = "add_file_to_context",
@@ -2330,8 +2346,9 @@ function Sidebar:create_input_container(opts)
}) })
---@param request string ---@param request string
---@return AvanteGeneratePromptsOptions ---@param summarize_memory boolean
local function get_generate_prompts_options(request) ---@param cb fun(opts: AvanteGeneratePromptsOptions): nil
local function get_generate_prompts_options(request, summarize_memory, cb)
local filetype = api.nvim_get_option_value("filetype", { buf = self.code.bufnr }) local filetype = api.nvim_get_option_value("filetype", { buf = self.code.bufnr })
local selected_code_content = nil local selected_code_content = nil
@@ -2355,40 +2372,19 @@ function Sidebar:create_input_container(opts)
end end
end end
local history_messages = {} local entries = Utils.history.filter_active_entries(chat_history.entries)
for i = #chat_history, 1, -1 do
local entry = chat_history[i] if chat_history.memory then
if entry.reset_memory then break end entries = vim
if .iter(entries)
entry.request == nil :filter(function(entry) return entry.timestamp > chat_history.memory.last_summarized_timestamp end)
or entry.original_response == nil :totable()
or entry.request == ""
or entry.original_response == ""
then
break
end
table.insert(
history_messages,
1,
{ role = "assistant", content = Utils.trim_think_content(entry.original_response) }
)
local user_content = ""
if entry.selected_file ~= nil then
user_content = user_content .. "SELECTED FILE: " .. entry.selected_file.filepath .. "\n\n"
end
if entry.selected_code ~= nil then
user_content = user_content
.. "SELECTED CODE:\n\n```"
.. entry.selected_code.filetype
.. "\n"
.. entry.selected_code.content
.. "\n```\n\n"
end
user_content = user_content .. "USER PROMPT:\n\n" .. entry.request
table.insert(history_messages, 1, { role = "user", content = user_content })
end end
return { local history_messages = Utils.history.entries_to_llm_messages(entries)
---@type AvanteGeneratePromptsOptions
local prompts_opts = {
ask = opts.ask or true, ask = opts.ask or true,
project_context = vim.json.encode(project_context), project_context = vim.json.encode(project_context),
selected_files = selected_files_contents, selected_files = selected_files_contents,
@@ -2400,6 +2396,20 @@ function Sidebar:create_input_container(opts)
mode = Config.behaviour.enable_cursor_planning_mode and "cursor-planning" or "planning", mode = Config.behaviour.enable_cursor_planning_mode and "cursor-planning" or "planning",
tools = tools, tools = tools,
} }
if chat_history.memory then prompts_opts.memory = chat_history.memory.content end
if not summarize_memory or #history_messages < 8 then
cb(prompts_opts)
return
end
prompts_opts.history_messages = vim.list_slice(prompts_opts.history_messages, 5)
Llm.summarize_memory(self.code.bufnr, chat_history, function(memory)
if memory then prompts_opts.memory = memory.content end
cb(prompts_opts)
end)
end end
---@param request string ---@param request string
@@ -2584,7 +2594,8 @@ function Sidebar:create_input_container(opts)
end, 0) end, 0)
-- Save chat history -- Save chat history
table.insert(chat_history or {}, { chat_history.entries = chat_history.entries or {}
table.insert(chat_history.entries, {
timestamp = timestamp, timestamp = timestamp,
provider = Config.provider, provider = Config.provider,
model = model, model = model,
@@ -2597,17 +2608,18 @@ function Sidebar:create_input_container(opts)
Path.history.save(self.code.bufnr, chat_history) Path.history.save(self.code.bufnr, chat_history)
end end
local generate_prompts_options = get_generate_prompts_options(request) get_generate_prompts_options(request, true, function(generate_prompts_options)
---@type AvanteLLMStreamOptions ---@type AvanteLLMStreamOptions
---@diagnostic disable-next-line: assign-type-mismatch ---@diagnostic disable-next-line: assign-type-mismatch
local stream_options = vim.tbl_deep_extend("force", generate_prompts_options, { local stream_options = vim.tbl_deep_extend("force", generate_prompts_options, {
on_start = on_start, on_start = on_start,
on_chunk = on_chunk, on_chunk = on_chunk,
on_stop = on_stop, on_stop = on_stop,
on_tool_log = on_tool_log, on_tool_log = on_tool_log,
}) })
Llm.stream(stream_options) Llm.stream(stream_options)
end)
end end
local function get_position() local function get_position()
@@ -2762,37 +2774,43 @@ function Sidebar:create_input_container(opts)
local hint_text = (fn.mode() ~= "i" and Config.mappings.submit.normal or Config.mappings.submit.insert) local hint_text = (fn.mode() ~= "i" and Config.mappings.submit.normal or Config.mappings.submit.insert)
.. ": submit" .. ": submit"
if Config.behaviour.enable_token_counting then local function show()
local input_value = table.concat(api.nvim_buf_get_lines(self.input_container.bufnr, 0, -1, false), "\n") local buf = api.nvim_create_buf(false, true)
local generate_prompts_options = get_generate_prompts_options(input_value) api.nvim_buf_set_lines(buf, 0, -1, false, { hint_text })
local tokens = Llm.calculate_tokens(generate_prompts_options) api.nvim_buf_add_highlight(buf, 0, "AvantePopupHint", 0, 0, -1)
hint_text = "Tokens: " .. tostring(tokens) .. "; " .. hint_text
-- Get the current window size
local win_width = api.nvim_win_get_width(self.input_container.winid)
local width = #hint_text
-- Set the floating window options
local win_opts = {
relative = "win",
win = self.input_container.winid,
width = width,
height = 1,
row = get_float_window_row(),
col = math.max(win_width - width, 0), -- Display in the bottom right corner
style = "minimal",
border = "none",
focusable = false,
zindex = 100,
}
-- Create the floating window
hint_window = api.nvim_open_win(buf, false, win_opts)
end end
local buf = api.nvim_create_buf(false, true) if Config.behaviour.enable_token_counting then
api.nvim_buf_set_lines(buf, 0, -1, false, { hint_text }) local input_value = table.concat(api.nvim_buf_get_lines(self.input_container.bufnr, 0, -1, false), "\n")
api.nvim_buf_add_highlight(buf, 0, "AvantePopupHint", 0, 0, -1) get_generate_prompts_options(input_value, false, function(generate_prompts_options)
local tokens = Llm.calculate_tokens(generate_prompts_options)
-- Get the current window size hint_text = "Tokens: " .. tostring(tokens) .. "; " .. hint_text
local win_width = api.nvim_win_get_width(self.input_container.winid) show()
local width = #hint_text end)
else
-- Set the floating window options show()
local win_opts = { end
relative = "win",
win = self.input_container.winid,
width = width,
height = 1,
row = get_float_window_row(),
col = math.max(win_width - width, 0), -- Display in the bottom right corner
style = "minimal",
border = "none",
focusable = false,
zindex = 100,
}
-- Create the floating window
hint_window = api.nvim_open_win(buf, false, win_opts)
end end
api.nvim_create_autocmd({ "TextChanged", "TextChangedI", "VimResized" }, { api.nvim_create_autocmd({ "TextChanged", "TextChangedI", "VimResized" }, {

View File

@@ -0,0 +1,5 @@
{%- if memory -%}
<memory>
{{memory}}
</memory>
{%- endif %}

View File

@@ -190,13 +190,9 @@ vim.g.avante_login = vim.g.avante_login
---@alias AvanteMessagesParser fun(self: AvanteProviderFunctor, opts: AvantePromptOptions): AvanteChatMessage[] ---@alias AvanteMessagesParser fun(self: AvanteProviderFunctor, opts: AvantePromptOptions): AvanteChatMessage[]
--- ---
---@class AvanteCurlOutput: {url: string, proxy: string, insecure: boolean, body: table<string, any> | string, headers: table<string, string>, rawArgs: string[] | nil} ---@class AvanteCurlOutput: {url: string, proxy: string, insecure: boolean, body: table<string, any> | string, headers: table<string, string>, rawArgs: string[] | nil}
---@alias AvanteCurlArgsParser fun(self: AvanteProviderFunctor, provider: AvanteProvider | AvanteProviderFunctor | AvanteBedrockProviderFunctor, prompt_opts: AvantePromptOptions): AvanteCurlOutput ---@alias AvanteCurlArgsParser fun(self: AvanteProviderFunctor, prompt_opts: AvantePromptOptions): AvanteCurlOutput
--- ---
---@class AvanteResponseParserOptions ---@alias AvanteResponseParser fun(self: AvanteProviderFunctor, ctx: any, data_stream: string, event_state: string, opts: AvanteHandlerOptions): nil
---@field on_start AvanteLLMStartCallback
---@field on_chunk AvanteLLMChunkCallback
---@field on_stop AvanteLLMStopCallback
---@alias AvanteResponseParser fun(self: AvanteProviderFunctor, ctx: any, data_stream: string, event_state: string, opts: AvanteResponseParserOptions): nil
--- ---
---@class AvanteDefaultBaseProvider: table<string, any> ---@class AvanteDefaultBaseProvider: table<string, any>
---@field endpoint? string ---@field endpoint? string
@@ -305,6 +301,7 @@ vim.g.avante_login = vim.g.avante_login
---@field selected_files AvanteSelectedFiles[] | nil ---@field selected_files AvanteSelectedFiles[] | nil
---@field diagnostics string | nil ---@field diagnostics string | nil
---@field history_messages AvanteLLMMessage[] | nil ---@field history_messages AvanteLLMMessage[] | nil
---@field memory string | nil
--- ---
---@class AvanteGeneratePromptsOptions: AvanteTemplateOptions ---@class AvanteGeneratePromptsOptions: AvanteTemplateOptions
---@field ask boolean ---@field ask boolean
@@ -358,3 +355,35 @@ vim.g.avante_login = vim.g.avante_login
---@field description string ---@field description string
---@field type 'string' | 'string[]' | 'boolean' ---@field type 'string' | 'string[]' | 'boolean'
---@field optional? boolean ---@field optional? boolean
---
---@class avante.ChatHistoryEntry
---@field timestamp string
---@field provider string
---@field model string
---@field request string
---@field response string
---@field original_response string
---@field selected_file {filepath: string}?
---@field selected_code {filetype: string, content: string}?
---@field reset_memory boolean?
---@field selected_filepaths string[] | nil
---
---@class avante.ChatHistory
---@field title string
---@field timestamp string
---@field entries avante.ChatHistoryEntry[]
---@field memory avante.ChatMemory | nil
---
---@class avante.ChatMemory
---@field content string
---@field last_summarized_timestamp string
---
---@class avante.Path
---@field history_path Path
---@field cache_path Path
---
---@class avante.CurlOpts
---@field provider AvanteProviderFunctor
---@field prompt_opts AvantePromptOptions
---@field handler_opts AvanteHandlerOptions
---

View File

@@ -0,0 +1,52 @@
local Utils = require("avante.utils")
---@class avante.utils.history
local M = {}
---Returns the trailing run of "active" history entries: scanning from the
---newest entry backwards, collection stops at the most recent entry flagged
---with `reset_memory`, or at the first entry whose request/original_response
---is missing or empty. The result is in chronological (oldest-first) order.
---@param entries avante.ChatHistoryEntry[]
---@return avante.ChatHistoryEntry[]
function M.filter_active_entries(entries)
  -- Append during the backward scan and reverse once at the end: O(n),
  -- instead of the O(n^2) cost of repeated table.insert(t, 1, ...).
  local active = {}
  for i = #entries, 1, -1 do
    local entry = entries[i]
    if entry.reset_memory then break end
    if
      entry.request == nil
      or entry.original_response == nil
      or entry.request == ""
      or entry.original_response == ""
    then
      break
    end
    active[#active + 1] = entry
  end
  -- Restore chronological order (entries were collected newest-first).
  local n = #active
  for i = 1, math.floor(n / 2) do
    active[i], active[n - i + 1] = active[n - i + 1], active[i]
  end
  return active
end
---Converts chat-history entries into alternating user/assistant LLM
---messages. Each entry yields one user message (selected file, selected
---code, then the prompt) followed by one assistant message holding the
---original response with any think-content stripped.
---@param entries avante.ChatHistoryEntry[]
---@return AvanteLLMMessage[]
function M.entries_to_llm_messages(entries)
  local messages = {}
  for _, entry in ipairs(entries) do
    -- Build the user content from parts and join once at the end.
    local parts = {}
    if entry.selected_file ~= nil then
      parts[#parts + 1] = "SELECTED FILE: " .. entry.selected_file.filepath .. "\n\n"
    end
    if entry.selected_code ~= nil then
      parts[#parts + 1] = "SELECTED CODE:\n\n```"
        .. entry.selected_code.filetype
        .. "\n"
        .. entry.selected_code.content
        .. "\n```\n\n"
    end
    parts[#parts + 1] = "USER PROMPT:\n\n" .. entry.request
    messages[#messages + 1] = { role = "user", content = table.concat(parts) }
    messages[#messages + 1] =
      { role = "assistant", content = Utils.trim_think_content(entry.original_response) }
  end
  return messages
end
return M

View File

@@ -8,6 +8,7 @@ local lsp = vim.lsp
---@field tokens avante.utils.tokens ---@field tokens avante.utils.tokens
---@field root avante.utils.root ---@field root avante.utils.root
---@field file avante.utils.file ---@field file avante.utils.file
---@field history avante.utils.history
local M = {} local M = {}
setmetatable(M, { setmetatable(M, {