refactor: summarize memory (#1508)
This commit is contained in:
@@ -33,6 +33,7 @@ struct TemplateContext {
|
||||
diagnostics: Option<String>,
|
||||
system_info: Option<String>,
|
||||
model_name: Option<String>,
|
||||
memory: Option<String>,
|
||||
}
|
||||
|
||||
// Given the file name registered after add, the context table in Lua, resulted in a formatted
|
||||
@@ -58,6 +59,7 @@ fn render(state: &State, template: &str, context: TemplateContext) -> LuaResult<
|
||||
diagnostics => context.diagnostics,
|
||||
system_info => context.system_info,
|
||||
model_name => context.model_name,
|
||||
memory => context.memory,
|
||||
})
|
||||
.map_err(LuaError::external)
|
||||
.unwrap())
|
||||
|
||||
@@ -27,6 +27,7 @@ M._defaults = {
|
||||
-- Of course, you can reduce the request frequency by increasing `suggestion.debounce`.
|
||||
auto_suggestions_provider = "claude",
|
||||
cursor_applying_provider = nil,
|
||||
memory_summary_provider = nil,
|
||||
---@alias Tokenizer "tiktoken" | "hf"
|
||||
-- Used for counting tokens and encoding text.
|
||||
-- By default, we will use tiktoken.
|
||||
@@ -273,6 +274,10 @@ M._defaults = {
|
||||
temperature = 0,
|
||||
max_tokens = 8000,
|
||||
},
|
||||
["openai-gpt-4o-mini"] = {
|
||||
__inherited_from = "openai",
|
||||
model = "gpt-4o-mini",
|
||||
},
|
||||
},
|
||||
---Specify the special dual_boost mode
|
||||
---1. enabled: Whether to enable dual_boost mode. Default to false.
|
||||
|
||||
@@ -19,6 +19,71 @@ M.CANCEL_PATTERN = "AvanteLLMEscape"
|
||||
|
||||
local group = api.nvim_create_augroup("avante_llm", { clear = true })
|
||||
|
||||
---@param bufnr integer
|
||||
---@param history avante.ChatHistory
|
||||
---@param cb fun(memory: avante.ChatMemory | nil): nil
|
||||
function M.summarize_memory(bufnr, history, cb)
|
||||
local system_prompt =
|
||||
[[Summarize the following conversation to extract the most critical information (such as languages used, conversation style, tech stack, considerations, user information, etc.) for memory in subsequent conversations. Since it is for memory purposes, be detailed and rigorous to ensure that no information from previous summaries is lost in the newly generated summary.]]
|
||||
local entries = Utils.history.filter_active_entries(history.entries)
|
||||
if #entries == 0 then
|
||||
cb(nil)
|
||||
return
|
||||
end
|
||||
if history.memory then
|
||||
system_prompt = system_prompt .. "\n\nThe previous summary is:\n\n" .. history.memory.content
|
||||
entries = vim
|
||||
.iter(entries)
|
||||
:filter(function(entry) return entry.timestamp > history.memory.last_summarized_timestamp end)
|
||||
:totable()
|
||||
end
|
||||
if #entries == 0 then
|
||||
cb(history.memory)
|
||||
return
|
||||
end
|
||||
local history_messages = Utils.history.entries_to_llm_messages(entries)
|
||||
history_messages = vim.list_slice(history_messages, 1, 4)
|
||||
if #history_messages == 0 then
|
||||
cb(history.memory)
|
||||
return
|
||||
end
|
||||
Utils.debug("summarize memory", #history_messages, history_messages[#history_messages].content)
|
||||
local response_content = ""
|
||||
local provider = Providers[Config.memory_summary_provider or Config.provider]
|
||||
M.curl({
|
||||
provider = provider,
|
||||
prompt_opts = {
|
||||
system_prompt = system_prompt,
|
||||
messages = {
|
||||
{ role = "user", content = vim.json.encode(history_messages) },
|
||||
},
|
||||
},
|
||||
handler_opts = {
|
||||
on_start = function(_) end,
|
||||
on_chunk = function(chunk)
|
||||
if not chunk then return end
|
||||
response_content = response_content .. chunk
|
||||
end,
|
||||
on_stop = function(stop_opts)
|
||||
if stop_opts.error ~= nil then
|
||||
Utils.error(string.format("summarize failed: %s", vim.inspect(stop_opts.error)))
|
||||
return
|
||||
end
|
||||
if stop_opts.reason == "complete" then
|
||||
response_content = Utils.trim_think_content(response_content)
|
||||
local memory = {
|
||||
content = response_content,
|
||||
last_summarized_timestamp = entries[#entries].timestamp,
|
||||
}
|
||||
history.memory = memory
|
||||
Path.history.save(bufnr, history)
|
||||
cb(memory)
|
||||
end
|
||||
end,
|
||||
},
|
||||
})
|
||||
end
|
||||
|
||||
---@param opts AvanteGeneratePromptsOptions
|
||||
---@return AvantePromptOptions
|
||||
function M.generate_prompts(opts)
|
||||
@@ -58,6 +123,7 @@ function M.generate_prompts(opts)
|
||||
diagnostics = opts.diagnostics,
|
||||
system_info = system_info,
|
||||
model_name = provider.model or "unknown",
|
||||
memory = opts.memory,
|
||||
}
|
||||
|
||||
local system_prompt = Path.prompts.render_mode(mode, template_opts)
|
||||
@@ -80,6 +146,11 @@ function M.generate_prompts(opts)
|
||||
if code_context ~= "" then table.insert(messages, { role = "user", content = code_context }) end
|
||||
end
|
||||
|
||||
if opts.memory ~= nil and opts.memory ~= "" and opts.memory ~= "null" then
|
||||
local memory = Path.prompts.render_file("_memory.avanterules", template_opts)
|
||||
if memory ~= "" then table.insert(messages, { role = "user", content = memory }) end
|
||||
end
|
||||
|
||||
if instructions then
|
||||
if opts.use_xml_format then
|
||||
table.insert(messages, { role = "user", content = string.format("<question>%s</question>", instructions) })
|
||||
@@ -150,90 +221,17 @@ function M.calculate_tokens(opts)
|
||||
return tokens
|
||||
end
|
||||
|
||||
---@param opts AvanteLLMStreamOptions
|
||||
function M._stream(opts)
|
||||
local provider = opts.provider or Providers[Config.provider]
|
||||
---@param opts avante.CurlOpts
|
||||
function M.curl(opts)
|
||||
local provider = opts.provider
|
||||
local prompt_opts = opts.prompt_opts
|
||||
local handler_opts = opts.handler_opts
|
||||
|
||||
---@cast provider AvanteProviderFunctor
|
||||
|
||||
local prompt_opts = M.generate_prompts(opts)
|
||||
---@type AvanteCurlOutput
|
||||
local spec = provider:parse_curl_args(prompt_opts)
|
||||
|
||||
---@type string
|
||||
local current_event_state = nil
|
||||
|
||||
---@type AvanteHandlerOptions
|
||||
local handler_opts = {
|
||||
on_start = opts.on_start,
|
||||
on_chunk = opts.on_chunk,
|
||||
on_stop = function(stop_opts)
|
||||
---@param tool_use_list AvanteLLMToolUse[]
|
||||
---@param tool_use_index integer
|
||||
---@param tool_histories AvanteLLMToolHistory[]
|
||||
local function handle_next_tool_use(tool_use_list, tool_use_index, tool_histories)
|
||||
if tool_use_index > #tool_use_list then
|
||||
local new_opts = vim.tbl_deep_extend("force", opts, {
|
||||
tool_histories = tool_histories,
|
||||
})
|
||||
return M._stream(new_opts)
|
||||
end
|
||||
local tool_use = tool_use_list[tool_use_index]
|
||||
---@param result string | nil
|
||||
---@param error string | nil
|
||||
local function handle_tool_result(result, error)
|
||||
local tool_result = {
|
||||
tool_use_id = tool_use.id,
|
||||
content = error ~= nil and error or result,
|
||||
is_error = error ~= nil,
|
||||
}
|
||||
table.insert(tool_histories, { tool_result = tool_result, tool_use = tool_use })
|
||||
return handle_next_tool_use(tool_use_list, tool_use_index + 1, tool_histories)
|
||||
end
|
||||
-- Either on_complete handles the tool result asynchronously or we receive the result and error synchronously when either is not nil
|
||||
local result, error = LLMTools.process_tool_use(opts.tools, tool_use, opts.on_tool_log, handle_tool_result)
|
||||
if result ~= nil or error ~= nil then return handle_tool_result(result, error) end
|
||||
end
|
||||
if stop_opts.reason == "tool_use" and stop_opts.tool_use_list then
|
||||
local old_tool_histories = vim.deepcopy(opts.tool_histories) or {}
|
||||
local sorted_tool_use_list = {} ---@type AvanteLLMToolUse[]
|
||||
for _, tool_use in vim.spairs(stop_opts.tool_use_list) do
|
||||
table.insert(sorted_tool_use_list, tool_use)
|
||||
end
|
||||
return handle_next_tool_use(sorted_tool_use_list, 1, old_tool_histories)
|
||||
end
|
||||
if stop_opts.reason == "rate_limit" then
|
||||
local msg = "Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds ..."
|
||||
opts.on_chunk("\n*[" .. msg .. "]*\n")
|
||||
local timer = vim.loop.new_timer()
|
||||
if timer then
|
||||
local retry_after = stop_opts.retry_after
|
||||
local function countdown()
|
||||
timer:start(
|
||||
1000,
|
||||
0,
|
||||
vim.schedule_wrap(function()
|
||||
if retry_after > 0 then retry_after = retry_after - 1 end
|
||||
local msg_ = "Rate limit reached. Retrying in " .. retry_after .. " seconds ..."
|
||||
opts.on_chunk([[\033[1A\033[K]] .. "\n*[" .. msg_ .. "]*\n")
|
||||
countdown()
|
||||
end)
|
||||
)
|
||||
end
|
||||
countdown()
|
||||
end
|
||||
Utils.info("Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds", { title = "Avante" })
|
||||
vim.defer_fn(function()
|
||||
if timer then timer:stop() end
|
||||
M._stream(opts)
|
||||
end, stop_opts.retry_after * 1000)
|
||||
return
|
||||
end
|
||||
return opts.on_stop(stop_opts)
|
||||
end,
|
||||
}
|
||||
|
||||
---@type AvanteCurlOutput
|
||||
local spec = provider:parse_curl_args(provider, prompt_opts)
|
||||
|
||||
local resp_ctx = {}
|
||||
|
||||
---@param line string
|
||||
@@ -383,6 +381,91 @@ function M._stream(opts)
|
||||
return active_job
|
||||
end
|
||||
|
||||
---@param opts AvanteLLMStreamOptions
|
||||
function M._stream(opts)
|
||||
local provider = opts.provider or Providers[Config.provider]
|
||||
|
||||
---@cast provider AvanteProviderFunctor
|
||||
|
||||
local prompt_opts = M.generate_prompts(opts)
|
||||
|
||||
---@type AvanteHandlerOptions
|
||||
local handler_opts = {
|
||||
on_start = opts.on_start,
|
||||
on_chunk = opts.on_chunk,
|
||||
on_stop = function(stop_opts)
|
||||
---@param tool_use_list AvanteLLMToolUse[]
|
||||
---@param tool_use_index integer
|
||||
---@param tool_histories AvanteLLMToolHistory[]
|
||||
local function handle_next_tool_use(tool_use_list, tool_use_index, tool_histories)
|
||||
if tool_use_index > #tool_use_list then
|
||||
local new_opts = vim.tbl_deep_extend("force", opts, {
|
||||
tool_histories = tool_histories,
|
||||
})
|
||||
return M._stream(new_opts)
|
||||
end
|
||||
local tool_use = tool_use_list[tool_use_index]
|
||||
---@param result string | nil
|
||||
---@param error string | nil
|
||||
local function handle_tool_result(result, error)
|
||||
local tool_result = {
|
||||
tool_use_id = tool_use.id,
|
||||
content = error ~= nil and error or result,
|
||||
is_error = error ~= nil,
|
||||
}
|
||||
table.insert(tool_histories, { tool_result = tool_result, tool_use = tool_use })
|
||||
return handle_next_tool_use(tool_use_list, tool_use_index + 1, tool_histories)
|
||||
end
|
||||
-- Either on_complete handles the tool result asynchronously or we receive the result and error synchronously when either is not nil
|
||||
local result, error = LLMTools.process_tool_use(opts.tools, tool_use, opts.on_tool_log, handle_tool_result)
|
||||
if result ~= nil or error ~= nil then return handle_tool_result(result, error) end
|
||||
end
|
||||
if stop_opts.reason == "tool_use" and stop_opts.tool_use_list then
|
||||
local old_tool_histories = vim.deepcopy(opts.tool_histories) or {}
|
||||
local sorted_tool_use_list = {} ---@type AvanteLLMToolUse[]
|
||||
for _, tool_use in vim.spairs(stop_opts.tool_use_list) do
|
||||
table.insert(sorted_tool_use_list, tool_use)
|
||||
end
|
||||
return handle_next_tool_use(sorted_tool_use_list, 1, old_tool_histories)
|
||||
end
|
||||
if stop_opts.reason == "rate_limit" then
|
||||
local msg = "Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds ..."
|
||||
opts.on_chunk("\n*[" .. msg .. "]*\n")
|
||||
local timer = vim.loop.new_timer()
|
||||
if timer then
|
||||
local retry_after = stop_opts.retry_after
|
||||
local function countdown()
|
||||
timer:start(
|
||||
1000,
|
||||
0,
|
||||
vim.schedule_wrap(function()
|
||||
if retry_after > 0 then retry_after = retry_after - 1 end
|
||||
local msg_ = "Rate limit reached. Retrying in " .. retry_after .. " seconds ..."
|
||||
opts.on_chunk([[\033[1A\033[K]] .. "\n*[" .. msg_ .. "]*\n")
|
||||
countdown()
|
||||
end)
|
||||
)
|
||||
end
|
||||
countdown()
|
||||
end
|
||||
Utils.info("Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds", { title = "Avante" })
|
||||
vim.defer_fn(function()
|
||||
if timer then timer:stop() end
|
||||
M._stream(opts)
|
||||
end, stop_opts.retry_after * 1000)
|
||||
return
|
||||
end
|
||||
return opts.on_stop(stop_opts)
|
||||
end,
|
||||
}
|
||||
|
||||
return M.curl({
|
||||
provider = provider,
|
||||
prompt_opts = prompt_opts,
|
||||
handler_opts = handler_opts,
|
||||
})
|
||||
end
|
||||
|
||||
local function _merge_response(first_response, second_response, opts)
|
||||
local prompt = "\n" .. Config.dual_boost.prompt
|
||||
prompt = prompt
|
||||
|
||||
@@ -5,58 +5,75 @@ local Path = require("plenary.path")
|
||||
local Scan = require("plenary.scandir")
|
||||
local Config = require("avante.config")
|
||||
|
||||
---@class avante.ChatHistoryEntry
|
||||
---@field timestamp string
|
||||
---@field provider string
|
||||
---@field model string
|
||||
---@field request string
|
||||
---@field response string
|
||||
---@field original_response string
|
||||
---@field selected_file {filepath: string}?
|
||||
---@field selected_code {filetype: string, content: string}?
|
||||
---@field reset_memory boolean?
|
||||
---@field selected_filepaths string[] | nil
|
||||
|
||||
---@class avante.Path
|
||||
---@field history_path Path
|
||||
---@field cache_path Path
|
||||
local P = {}
|
||||
|
||||
local history_file_cache = LRUCache:new(12)
|
||||
|
||||
-- History path
|
||||
local History = {}
|
||||
|
||||
-- Get a chat history file name given a buffer
|
||||
---@param bufnr integer
|
||||
---@return string
|
||||
function History.filename(bufnr)
|
||||
---@param bufnr integer | nil
|
||||
---@return string dirname
|
||||
local function generate_project_dirname_in_storage(bufnr)
|
||||
local project_root = Utils.root.get({
|
||||
buf = bufnr,
|
||||
})
|
||||
-- Replace path separators with double underscores
|
||||
local path_with_separators = fn.substitute(project_root, "/", "__", "g")
|
||||
-- Replace other non-alphanumeric characters with single underscores
|
||||
return fn.substitute(path_with_separators, "[^A-Za-z0-9._]", "_", "g") .. ".json"
|
||||
local dirname = fn.substitute(path_with_separators, "[^A-Za-z0-9._]", "_", "g")
|
||||
return tostring(Path:new("projects"):joinpath(dirname))
|
||||
end
|
||||
|
||||
-- History path
|
||||
local History = {}
|
||||
|
||||
-- Get a chat history file name given a buffer
|
||||
---@param bufnr integer
|
||||
---@param new boolean
|
||||
---@return Path
|
||||
function History.filepath(bufnr, new)
|
||||
local dirname = generate_project_dirname_in_storage(bufnr)
|
||||
local history_dir = Path:new(Config.history.storage_path):joinpath(dirname):joinpath("history")
|
||||
if not history_dir:exists() then history_dir:mkdir({ parents = true }) end
|
||||
local pattern = tostring(history_dir:joinpath("*.json"))
|
||||
local files = vim.fn.glob(pattern, true, true)
|
||||
local filename = #files .. ".json"
|
||||
if #files > 0 and not new then filename = (#files - 1) .. ".json" end
|
||||
return history_dir:joinpath(filename)
|
||||
end
|
||||
|
||||
-- Returns the Path to the chat history file for the given buffer.
|
||||
---@param bufnr integer
|
||||
---@return Path
|
||||
function History.get(bufnr) return Path:new(Config.history.storage_path):joinpath(History.filename(bufnr)) end
|
||||
function History.get(bufnr) return History.filepath(bufnr, false) end
|
||||
|
||||
---@param bufnr integer
|
||||
function History.new(bufnr)
|
||||
local filepath = History.filepath(bufnr, true)
|
||||
local history = {
|
||||
title = "untitled",
|
||||
timestamp = tostring(os.date("%Y-%m-%d %H:%M:%S")),
|
||||
entries = {},
|
||||
}
|
||||
filepath:write(vim.json.encode(history), "w")
|
||||
end
|
||||
|
||||
-- Loads the chat history for the given buffer.
|
||||
---@param bufnr integer
|
||||
---@return avante.ChatHistoryEntry[]
|
||||
---@return avante.ChatHistory
|
||||
function History.load(bufnr)
|
||||
local history_file = History.get(bufnr)
|
||||
local cached_key = tostring(history_file:absolute())
|
||||
local cached_value = history_file_cache:get(cached_key)
|
||||
if cached_value ~= nil then return cached_value end
|
||||
local value = {}
|
||||
---@type avante.ChatHistory
|
||||
local value = {
|
||||
title = "untitled",
|
||||
timestamp = tostring(os.date("%Y-%m-%d %H:%M:%S")),
|
||||
entries = {},
|
||||
}
|
||||
if history_file:exists() then
|
||||
local content = history_file:read()
|
||||
value = content ~= nil and vim.json.decode(content) or {}
|
||||
value = content ~= nil and vim.json.decode(content) or value
|
||||
end
|
||||
history_file_cache:set(cached_key, value)
|
||||
return value
|
||||
@@ -64,7 +81,7 @@ end
|
||||
|
||||
-- Saves the chat history for the given buffer.
|
||||
---@param bufnr integer
|
||||
---@param history avante.ChatHistoryEntry[]
|
||||
---@param history avante.ChatHistory
|
||||
History.save = vim.schedule_wrap(function(bufnr, history)
|
||||
local history_file = History.get(bufnr)
|
||||
local cached_key = tostring(history_file:absolute())
|
||||
|
||||
@@ -18,8 +18,8 @@ M.parse_response = O.parse_response
|
||||
M.parse_response_without_stream = O.parse_response_without_stream
|
||||
M.is_disable_stream = O.is_disable_stream
|
||||
|
||||
function M:parse_curl_args(provider, prompt_opts)
|
||||
local provider_conf, request_body = P.parse_config(provider)
|
||||
function M:parse_curl_args(prompt_opts)
|
||||
local provider_conf, request_body = P.parse_config(self)
|
||||
|
||||
local headers = {
|
||||
["Content-Type"] = "application/json",
|
||||
@@ -27,9 +27,9 @@ function M:parse_curl_args(provider, prompt_opts)
|
||||
|
||||
if P.env.require_api_key(provider_conf) then
|
||||
if provider_conf.entra then
|
||||
headers["Authorization"] = "Bearer " .. provider.parse_api_key()
|
||||
headers["Authorization"] = "Bearer " .. self.parse_api_key()
|
||||
else
|
||||
headers["api-key"] = provider.parse_api_key()
|
||||
headers["api-key"] = self.parse_api_key()
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
@@ -51,13 +51,12 @@ function M:parse_response_without_stream(data, event_state, opts)
|
||||
vim.schedule(function() opts.on_stop({ reason = "complete" }) end)
|
||||
end
|
||||
|
||||
---@param provider AvanteBedrockProviderFunctor
|
||||
---@param prompt_opts AvantePromptOptions
|
||||
---@return table
|
||||
function M:parse_curl_args(provider, prompt_opts)
|
||||
local base, body_opts = P.parse_config(provider)
|
||||
function M:parse_curl_args(prompt_opts)
|
||||
local base, body_opts = P.parse_config(self)
|
||||
|
||||
local api_key = provider.parse_api_key()
|
||||
local api_key = self.parse_api_key()
|
||||
if api_key == nil then error("Cannot get the bedrock api key!") end
|
||||
local parts = vim.split(api_key, ",")
|
||||
local aws_access_key_id = parts[1]
|
||||
|
||||
@@ -262,11 +262,10 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
||||
end
|
||||
end
|
||||
|
||||
---@param provider AvanteProviderFunctor
|
||||
---@param prompt_opts AvantePromptOptions
|
||||
---@return table
|
||||
function M:parse_curl_args(provider, prompt_opts)
|
||||
local provider_conf, request_body = P.parse_config(provider)
|
||||
function M:parse_curl_args(prompt_opts)
|
||||
local provider_conf, request_body = P.parse_config(self)
|
||||
local disable_tools = provider_conf.disable_tools or false
|
||||
|
||||
local headers = {
|
||||
@@ -275,7 +274,7 @@ function M:parse_curl_args(provider, prompt_opts)
|
||||
["anthropic-beta"] = "prompt-caching-2024-07-31",
|
||||
}
|
||||
|
||||
if P.env.require_api_key(provider_conf) then headers["x-api-key"] = provider.parse_api_key() end
|
||||
if P.env.require_api_key(provider_conf) then headers["x-api-key"] = self.parse_api_key() end
|
||||
|
||||
local messages = self:parse_messages(prompt_opts)
|
||||
|
||||
|
||||
@@ -71,8 +71,8 @@ function M:parse_stream_data(ctx, data, opts)
|
||||
end
|
||||
end
|
||||
|
||||
function M:parse_curl_args(provider, prompt_opts)
|
||||
local provider_conf, request_body = P.parse_config(provider)
|
||||
function M:parse_curl_args(prompt_opts)
|
||||
local provider_conf, request_body = P.parse_config(self)
|
||||
|
||||
local headers = {
|
||||
["Accept"] = "application/json",
|
||||
@@ -84,7 +84,7 @@ function M:parse_curl_args(provider, prompt_opts)
|
||||
.. "."
|
||||
.. vim.version().patch,
|
||||
}
|
||||
if P.env.require_api_key(provider_conf) then headers["Authorization"] = "Bearer " .. provider.parse_api_key() end
|
||||
if P.env.require_api_key(provider_conf) then headers["Authorization"] = "Bearer " .. self.parse_api_key() end
|
||||
|
||||
return {
|
||||
url = Utils.url_join(provider_conf.endpoint, "/chat"),
|
||||
|
||||
@@ -215,12 +215,12 @@ M.parse_messages = OpenAI.parse_messages
|
||||
|
||||
M.parse_response = OpenAI.parse_response
|
||||
|
||||
function M:parse_curl_args(provider, prompt_opts)
|
||||
function M:parse_curl_args(prompt_opts)
|
||||
-- refresh token synchronously, only if it has expired
|
||||
-- (this should rarely happen, as we refresh the token in the background)
|
||||
H.refresh_token(false, false)
|
||||
|
||||
local provider_conf, request_body = P.parse_config(provider)
|
||||
local provider_conf, request_body = P.parse_config(self)
|
||||
local disable_tools = provider_conf.disable_tools or false
|
||||
|
||||
local tools = {}
|
||||
|
||||
@@ -83,8 +83,8 @@ function M:parse_response(ctx, data_stream, _, opts)
|
||||
end
|
||||
end
|
||||
|
||||
function M:parse_curl_args(provider, prompt_opts)
|
||||
local provider_conf, request_body = P.parse_config(provider)
|
||||
function M:parse_curl_args(prompt_opts)
|
||||
local provider_conf, request_body = P.parse_config(self)
|
||||
|
||||
request_body = vim.tbl_deep_extend("force", request_body, {
|
||||
generationConfig = {
|
||||
@@ -95,7 +95,7 @@ function M:parse_curl_args(provider, prompt_opts)
|
||||
request_body.temperature = nil
|
||||
request_body.max_tokens = nil
|
||||
|
||||
local api_key = provider.parse_api_key()
|
||||
local api_key = self.parse_api_key()
|
||||
if api_key == nil then error("Cannot get the gemini api key!") end
|
||||
|
||||
return {
|
||||
|
||||
@@ -261,6 +261,13 @@ function M.setup()
|
||||
E.setup({ provider = cursor_applying_provider })
|
||||
end
|
||||
end
|
||||
|
||||
if Config.memory_summary_provider then
|
||||
local memory_summary_provider = M[Config.memory_summary_provider]
|
||||
if memory_summary_provider and memory_summary_provider ~= provider then
|
||||
E.setup({ provider = memory_summary_provider })
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
---@param provider Provider
|
||||
|
||||
@@ -219,8 +219,8 @@ function M:parse_response_without_stream(data, _, opts)
|
||||
end
|
||||
end
|
||||
|
||||
function M:parse_curl_args(provider, prompt_opts)
|
||||
local provider_conf, request_body = P.parse_config(provider)
|
||||
function M:parse_curl_args(prompt_opts)
|
||||
local provider_conf, request_body = P.parse_config(self)
|
||||
local disable_tools = provider_conf.disable_tools or false
|
||||
|
||||
local headers = {
|
||||
@@ -228,7 +228,7 @@ function M:parse_curl_args(provider, prompt_opts)
|
||||
}
|
||||
|
||||
if P.env.require_api_key(provider_conf) then
|
||||
local api_key = provider.parse_api_key()
|
||||
local api_key = self.parse_api_key()
|
||||
if api_key == nil then
|
||||
error(Config.provider .. " API key is not set, please set it in your environment variable or config file")
|
||||
end
|
||||
|
||||
@@ -32,8 +32,8 @@ function M.parse_api_key()
|
||||
return direct_output
|
||||
end
|
||||
|
||||
function M:parse_curl_args(provider, prompt_opts)
|
||||
local provider_conf, request_body = P.parse_config(provider)
|
||||
function M:parse_curl_args(prompt_opts)
|
||||
local provider_conf, request_body = P.parse_config(self)
|
||||
local location = vim.fn.getenv("LOCATION")
|
||||
local project_id = vim.fn.getenv("PROJECT_ID")
|
||||
local model_id = provider_conf.model or "default-model-id"
|
||||
|
||||
@@ -2073,11 +2073,11 @@ function Sidebar:get_layout()
|
||||
return vim.tbl_contains({ "left", "right" }, calculate_config_window_position()) and "vertical" or "horizontal"
|
||||
end
|
||||
|
||||
---@param history avante.ChatHistoryEntry[]
|
||||
---@param history avante.ChatHistory
|
||||
---@return string
|
||||
function Sidebar:render_history_content(history)
|
||||
local content = ""
|
||||
for idx, entry in ipairs(history) do
|
||||
for idx, entry in ipairs(history.entries) do
|
||||
if entry.reset_memory then
|
||||
content = content .. "***MEMORY RESET***\n\n"
|
||||
if idx < #history then content = content .. "-------\n\n" end
|
||||
@@ -2146,11 +2146,11 @@ end
|
||||
function Sidebar:clear_history(args, cb)
|
||||
local chat_history = Path.history.load(self.code.bufnr)
|
||||
if next(chat_history) ~= nil then
|
||||
chat_history = {}
|
||||
chat_history.entries = {}
|
||||
Path.history.save(self.code.bufnr, chat_history)
|
||||
self:update_content(
|
||||
"Chat history cleared",
|
||||
{ focus = false, scroll = false, callback = function() self:focus_input() end }
|
||||
{ ignore_history = true, focus = false, scroll = false, callback = function() self:focus_input() end }
|
||||
)
|
||||
if cb then cb(args) end
|
||||
else
|
||||
@@ -2161,6 +2161,16 @@ function Sidebar:clear_history(args, cb)
|
||||
end
|
||||
end
|
||||
|
||||
function Sidebar:new_chat(args, cb)
|
||||
Path.history.new(self.code.bufnr)
|
||||
Sidebar.reload_chat_history()
|
||||
self:update_content(
|
||||
"New chat",
|
||||
{ ignore_history = true, focus = false, scroll = false, callback = function() self:focus_input() end }
|
||||
)
|
||||
if cb then cb(args) end
|
||||
end
|
||||
|
||||
function Sidebar:reset_memory(args, cb)
|
||||
local chat_history = Path.history.load(self.code.bufnr)
|
||||
if next(chat_history) ~= nil then
|
||||
@@ -2176,6 +2186,7 @@ function Sidebar:reset_memory(args, cb)
|
||||
reset_memory = true,
|
||||
})
|
||||
Path.history.save(self.code.bufnr, chat_history)
|
||||
Sidebar.reload_chat_history()
|
||||
local history_content = self:render_history_content(chat_history)
|
||||
self:update_content(history_content, {
|
||||
focus = false,
|
||||
@@ -2184,6 +2195,7 @@ function Sidebar:reset_memory(args, cb)
|
||||
})
|
||||
if cb then cb(args) end
|
||||
else
|
||||
Sidebar.reload_chat_history()
|
||||
self:update_content(
|
||||
"Chat history is already empty",
|
||||
{ focus = false, scroll = false, callback = function() self:focus_input() end }
|
||||
@@ -2191,7 +2203,7 @@ function Sidebar:reset_memory(args, cb)
|
||||
end
|
||||
end
|
||||
|
||||
---@alias AvanteSlashCommandType "clear" | "help" | "lines" | "reset" | "commit"
|
||||
---@alias AvanteSlashCommandType "clear" | "help" | "lines" | "reset" | "commit" | "new"
|
||||
---@alias AvanteSlashCommandCallback fun(args: string, cb?: fun(args: string): nil): nil
|
||||
---@alias AvanteSlashCommand {description: string, command: AvanteSlashCommandType, details: string, shorthelp?: string, callback?: AvanteSlashCommandCallback}
|
||||
---@return AvanteSlashCommand[]
|
||||
@@ -2211,6 +2223,7 @@ function Sidebar:get_commands()
|
||||
{ description = "Show help message", command = "help" },
|
||||
{ description = "Clear chat history", command = "clear" },
|
||||
{ description = "Reset memory", command = "reset" },
|
||||
{ description = "New chat", command = "new" },
|
||||
{
|
||||
shorthelp = "Ask a question about specific lines",
|
||||
description = "/lines <start>-<end> <question>",
|
||||
@@ -2228,6 +2241,7 @@ function Sidebar:get_commands()
|
||||
end,
|
||||
clear = function(args, cb) self:clear_history(args, cb) end,
|
||||
reset = function(args, cb) self:reset_memory(args, cb) end,
|
||||
new = function(args, cb) self:new_chat(args, cb) end,
|
||||
lines = function(args, cb)
|
||||
if cb then cb(args) end
|
||||
end,
|
||||
@@ -2298,6 +2312,8 @@ function Sidebar:create_input_container(opts)
|
||||
|
||||
local chat_history = Path.history.load(self.code.bufnr)
|
||||
|
||||
Sidebar.reload_chat_history = function() chat_history = Path.history.load(self.code.bufnr) end
|
||||
|
||||
local tools = vim.deepcopy(LLMTools.get_tools())
|
||||
table.insert(tools, {
|
||||
name = "add_file_to_context",
|
||||
@@ -2330,8 +2346,9 @@ function Sidebar:create_input_container(opts)
|
||||
})
|
||||
|
||||
---@param request string
|
||||
---@return AvanteGeneratePromptsOptions
|
||||
local function get_generate_prompts_options(request)
|
||||
---@param summarize_memory boolean
|
||||
---@param cb fun(opts: AvanteGeneratePromptsOptions): nil
|
||||
local function get_generate_prompts_options(request, summarize_memory, cb)
|
||||
local filetype = api.nvim_get_option_value("filetype", { buf = self.code.bufnr })
|
||||
|
||||
local selected_code_content = nil
|
||||
@@ -2355,40 +2372,19 @@ function Sidebar:create_input_container(opts)
|
||||
end
|
||||
end
|
||||
|
||||
local history_messages = {}
|
||||
for i = #chat_history, 1, -1 do
|
||||
local entry = chat_history[i]
|
||||
if entry.reset_memory then break end
|
||||
if
|
||||
entry.request == nil
|
||||
or entry.original_response == nil
|
||||
or entry.request == ""
|
||||
or entry.original_response == ""
|
||||
then
|
||||
break
|
||||
end
|
||||
table.insert(
|
||||
history_messages,
|
||||
1,
|
||||
{ role = "assistant", content = Utils.trim_think_content(entry.original_response) }
|
||||
)
|
||||
local user_content = ""
|
||||
if entry.selected_file ~= nil then
|
||||
user_content = user_content .. "SELECTED FILE: " .. entry.selected_file.filepath .. "\n\n"
|
||||
end
|
||||
if entry.selected_code ~= nil then
|
||||
user_content = user_content
|
||||
.. "SELECTED CODE:\n\n```"
|
||||
.. entry.selected_code.filetype
|
||||
.. "\n"
|
||||
.. entry.selected_code.content
|
||||
.. "\n```\n\n"
|
||||
end
|
||||
user_content = user_content .. "USER PROMPT:\n\n" .. entry.request
|
||||
table.insert(history_messages, 1, { role = "user", content = user_content })
|
||||
local entries = Utils.history.filter_active_entries(chat_history.entries)
|
||||
|
||||
if chat_history.memory then
|
||||
entries = vim
|
||||
.iter(entries)
|
||||
:filter(function(entry) return entry.timestamp > chat_history.memory.last_summarized_timestamp end)
|
||||
:totable()
|
||||
end
|
||||
|
||||
return {
|
||||
local history_messages = Utils.history.entries_to_llm_messages(entries)
|
||||
|
||||
---@type AvanteGeneratePromptsOptions
|
||||
local prompts_opts = {
|
||||
ask = opts.ask or true,
|
||||
project_context = vim.json.encode(project_context),
|
||||
selected_files = selected_files_contents,
|
||||
@@ -2400,6 +2396,20 @@ function Sidebar:create_input_container(opts)
|
||||
mode = Config.behaviour.enable_cursor_planning_mode and "cursor-planning" or "planning",
|
||||
tools = tools,
|
||||
}
|
||||
|
||||
if chat_history.memory then prompts_opts.memory = chat_history.memory.content end
|
||||
|
||||
if not summarize_memory or #history_messages < 8 then
|
||||
cb(prompts_opts)
|
||||
return
|
||||
end
|
||||
|
||||
prompts_opts.history_messages = vim.list_slice(prompts_opts.history_messages, 5)
|
||||
|
||||
Llm.summarize_memory(self.code.bufnr, chat_history, function(memory)
|
||||
if memory then prompts_opts.memory = memory.content end
|
||||
cb(prompts_opts)
|
||||
end)
|
||||
end
|
||||
|
||||
---@param request string
|
||||
@@ -2584,7 +2594,8 @@ function Sidebar:create_input_container(opts)
|
||||
end, 0)
|
||||
|
||||
-- Save chat history
|
||||
table.insert(chat_history or {}, {
|
||||
chat_history.entries = chat_history.entries or {}
|
||||
table.insert(chat_history.entries, {
|
||||
timestamp = timestamp,
|
||||
provider = Config.provider,
|
||||
model = model,
|
||||
@@ -2597,17 +2608,18 @@ function Sidebar:create_input_container(opts)
|
||||
Path.history.save(self.code.bufnr, chat_history)
|
||||
end
|
||||
|
||||
local generate_prompts_options = get_generate_prompts_options(request)
|
||||
---@type AvanteLLMStreamOptions
|
||||
---@diagnostic disable-next-line: assign-type-mismatch
|
||||
local stream_options = vim.tbl_deep_extend("force", generate_prompts_options, {
|
||||
on_start = on_start,
|
||||
on_chunk = on_chunk,
|
||||
on_stop = on_stop,
|
||||
on_tool_log = on_tool_log,
|
||||
})
|
||||
get_generate_prompts_options(request, true, function(generate_prompts_options)
|
||||
---@type AvanteLLMStreamOptions
|
||||
---@diagnostic disable-next-line: assign-type-mismatch
|
||||
local stream_options = vim.tbl_deep_extend("force", generate_prompts_options, {
|
||||
on_start = on_start,
|
||||
on_chunk = on_chunk,
|
||||
on_stop = on_stop,
|
||||
on_tool_log = on_tool_log,
|
||||
})
|
||||
|
||||
Llm.stream(stream_options)
|
||||
Llm.stream(stream_options)
|
||||
end)
|
||||
end
|
||||
|
||||
local function get_position()
|
||||
@@ -2762,37 +2774,43 @@ function Sidebar:create_input_container(opts)
|
||||
local hint_text = (fn.mode() ~= "i" and Config.mappings.submit.normal or Config.mappings.submit.insert)
|
||||
.. ": submit"
|
||||
|
||||
if Config.behaviour.enable_token_counting then
|
||||
local input_value = table.concat(api.nvim_buf_get_lines(self.input_container.bufnr, 0, -1, false), "\n")
|
||||
local generate_prompts_options = get_generate_prompts_options(input_value)
|
||||
local tokens = Llm.calculate_tokens(generate_prompts_options)
|
||||
hint_text = "Tokens: " .. tostring(tokens) .. "; " .. hint_text
|
||||
local function show()
|
||||
local buf = api.nvim_create_buf(false, true)
|
||||
api.nvim_buf_set_lines(buf, 0, -1, false, { hint_text })
|
||||
api.nvim_buf_add_highlight(buf, 0, "AvantePopupHint", 0, 0, -1)
|
||||
|
||||
-- Get the current window size
|
||||
local win_width = api.nvim_win_get_width(self.input_container.winid)
|
||||
local width = #hint_text
|
||||
|
||||
-- Set the floating window options
|
||||
local win_opts = {
|
||||
relative = "win",
|
||||
win = self.input_container.winid,
|
||||
width = width,
|
||||
height = 1,
|
||||
row = get_float_window_row(),
|
||||
col = math.max(win_width - width, 0), -- Display in the bottom right corner
|
||||
style = "minimal",
|
||||
border = "none",
|
||||
focusable = false,
|
||||
zindex = 100,
|
||||
}
|
||||
|
||||
-- Create the floating window
|
||||
hint_window = api.nvim_open_win(buf, false, win_opts)
|
||||
end
|
||||
|
||||
local buf = api.nvim_create_buf(false, true)
|
||||
api.nvim_buf_set_lines(buf, 0, -1, false, { hint_text })
|
||||
api.nvim_buf_add_highlight(buf, 0, "AvantePopupHint", 0, 0, -1)
|
||||
|
||||
-- Get the current window size
|
||||
local win_width = api.nvim_win_get_width(self.input_container.winid)
|
||||
local width = #hint_text
|
||||
|
||||
-- Set the floating window options
|
||||
local win_opts = {
|
||||
relative = "win",
|
||||
win = self.input_container.winid,
|
||||
width = width,
|
||||
height = 1,
|
||||
row = get_float_window_row(),
|
||||
col = math.max(win_width - width, 0), -- Display in the bottom right corner
|
||||
style = "minimal",
|
||||
border = "none",
|
||||
focusable = false,
|
||||
zindex = 100,
|
||||
}
|
||||
|
||||
-- Create the floating window
|
||||
hint_window = api.nvim_open_win(buf, false, win_opts)
|
||||
if Config.behaviour.enable_token_counting then
|
||||
local input_value = table.concat(api.nvim_buf_get_lines(self.input_container.bufnr, 0, -1, false), "\n")
|
||||
get_generate_prompts_options(input_value, false, function(generate_prompts_options)
|
||||
local tokens = Llm.calculate_tokens(generate_prompts_options)
|
||||
hint_text = "Tokens: " .. tostring(tokens) .. "; " .. hint_text
|
||||
show()
|
||||
end)
|
||||
else
|
||||
show()
|
||||
end
|
||||
end
|
||||
|
||||
api.nvim_create_autocmd({ "TextChanged", "TextChangedI", "VimResized" }, {
|
||||
|
||||
5
lua/avante/templates/_memory.avanterules
Normal file
5
lua/avante/templates/_memory.avanterules
Normal file
@@ -0,0 +1,5 @@
|
||||
{%- if memory -%}
|
||||
<memory>
|
||||
{{memory}}
|
||||
</memory>
|
||||
{%- endif %}
|
||||
@@ -190,13 +190,9 @@ vim.g.avante_login = vim.g.avante_login
|
||||
---@alias AvanteMessagesParser fun(self: AvanteProviderFunctor, opts: AvantePromptOptions): AvanteChatMessage[]
|
||||
---
|
||||
---@class AvanteCurlOutput: {url: string, proxy: string, insecure: boolean, body: table<string, any> | string, headers: table<string, string>, rawArgs: string[] | nil}
|
||||
---@alias AvanteCurlArgsParser fun(self: AvanteProviderFunctor, provider: AvanteProvider | AvanteProviderFunctor | AvanteBedrockProviderFunctor, prompt_opts: AvantePromptOptions): AvanteCurlOutput
|
||||
---@alias AvanteCurlArgsParser fun(self: AvanteProviderFunctor, prompt_opts: AvantePromptOptions): AvanteCurlOutput
|
||||
---
|
||||
---@class AvanteResponseParserOptions
|
||||
---@field on_start AvanteLLMStartCallback
|
||||
---@field on_chunk AvanteLLMChunkCallback
|
||||
---@field on_stop AvanteLLMStopCallback
|
||||
---@alias AvanteResponseParser fun(self: AvanteProviderFunctor, ctx: any, data_stream: string, event_state: string, opts: AvanteResponseParserOptions): nil
|
||||
---@alias AvanteResponseParser fun(self: AvanteProviderFunctor, ctx: any, data_stream: string, event_state: string, opts: AvanteHandlerOptions): nil
|
||||
---
|
||||
---@class AvanteDefaultBaseProvider: table<string, any>
|
||||
---@field endpoint? string
|
||||
@@ -305,6 +301,7 @@ vim.g.avante_login = vim.g.avante_login
|
||||
---@field selected_files AvanteSelectedFiles[] | nil
|
||||
---@field diagnostics string | nil
|
||||
---@field history_messages AvanteLLMMessage[] | nil
|
||||
---@field memory string | nil
|
||||
---
|
||||
---@class AvanteGeneratePromptsOptions: AvanteTemplateOptions
|
||||
---@field ask boolean
|
||||
@@ -358,3 +355,35 @@ vim.g.avante_login = vim.g.avante_login
|
||||
---@field description string
|
||||
---@field type 'string' | 'string[]' | 'boolean'
|
||||
---@field optional? boolean
|
||||
---
|
||||
---@class avante.ChatHistoryEntry
|
||||
---@field timestamp string
|
||||
---@field provider string
|
||||
---@field model string
|
||||
---@field request string
|
||||
---@field response string
|
||||
---@field original_response string
|
||||
---@field selected_file {filepath: string}?
|
||||
---@field selected_code {filetype: string, content: string}?
|
||||
---@field reset_memory boolean?
|
||||
---@field selected_filepaths string[] | nil
|
||||
---
|
||||
---@class avante.ChatHistory
|
||||
---@field title string
|
||||
---@field timestamp string
|
||||
---@field entries avante.ChatHistoryEntry[]
|
||||
---@field memory avante.ChatMemory | nil
|
||||
---
|
||||
---@class avante.ChatMemory
|
||||
---@field content string
|
||||
---@field last_summarized_timestamp string
|
||||
|
||||
---@class avante.Path
|
||||
---@field history_path Path
|
||||
---@field cache_path Path
|
||||
---
|
||||
---@class avante.CurlOpts
|
||||
---@field provider AvanteProviderFunctor
|
||||
---@field prompt_opts AvantePromptOptions
|
||||
---@field handler_opts AvanteHandlerOptions
|
||||
---
|
||||
|
||||
52
lua/avante/utils/history.lua
Normal file
52
lua/avante/utils/history.lua
Normal file
@@ -0,0 +1,52 @@
|
||||
local Utils = require("avante.utils")
|
||||
|
||||
---@class avante.utils.history
|
||||
local M = {}
|
||||
|
||||
---@param entries avante.ChatHistoryEntry[]
|
||||
---@return avante.ChatHistoryEntry[]
|
||||
function M.filter_active_entries(entries)
|
||||
local entries_ = {}
|
||||
|
||||
for i = #entries, 1, -1 do
|
||||
local entry = entries[i]
|
||||
if entry.reset_memory then break end
|
||||
if
|
||||
entry.request == nil
|
||||
or entry.original_response == nil
|
||||
or entry.request == ""
|
||||
or entry.original_response == ""
|
||||
then
|
||||
break
|
||||
end
|
||||
table.insert(entries_, 1, entry)
|
||||
end
|
||||
|
||||
return entries_
|
||||
end
|
||||
|
||||
---@param entries avante.ChatHistoryEntry[]
|
||||
---@return AvanteLLMMessage[]
|
||||
function M.entries_to_llm_messages(entries)
|
||||
local messages = {}
|
||||
for _, entry in ipairs(entries) do
|
||||
local user_content = ""
|
||||
if entry.selected_file ~= nil then
|
||||
user_content = user_content .. "SELECTED FILE: " .. entry.selected_file.filepath .. "\n\n"
|
||||
end
|
||||
if entry.selected_code ~= nil then
|
||||
user_content = user_content
|
||||
.. "SELECTED CODE:\n\n```"
|
||||
.. entry.selected_code.filetype
|
||||
.. "\n"
|
||||
.. entry.selected_code.content
|
||||
.. "\n```\n\n"
|
||||
end
|
||||
user_content = user_content .. "USER PROMPT:\n\n" .. entry.request
|
||||
table.insert(messages, { role = "user", content = user_content })
|
||||
table.insert(messages, { role = "assistant", content = Utils.trim_think_content(entry.original_response) })
|
||||
end
|
||||
return messages
|
||||
end
|
||||
|
||||
return M
|
||||
@@ -8,6 +8,7 @@ local lsp = vim.lsp
|
||||
---@field tokens avante.utils.tokens
|
||||
---@field root avante.utils.root
|
||||
---@field file avante.utils.file
|
||||
---@field history avante.utils.history
|
||||
local M = {}
|
||||
|
||||
setmetatable(M, {
|
||||
|
||||
Reference in New Issue
Block a user