refactor: summarize memory (#1508)
This commit is contained in:
@@ -19,6 +19,71 @@ M.CANCEL_PATTERN = "AvanteLLMEscape"
|
||||
|
||||
local group = api.nvim_create_augroup("avante_llm", { clear = true })
|
||||
|
||||
---@param bufnr integer
|
||||
---@param history avante.ChatHistory
|
||||
---@param cb fun(memory: avante.ChatMemory | nil): nil
|
||||
function M.summarize_memory(bufnr, history, cb)
|
||||
local system_prompt =
|
||||
[[Summarize the following conversation to extract the most critical information (such as languages used, conversation style, tech stack, considerations, user information, etc.) for memory in subsequent conversations. Since it is for memory purposes, be detailed and rigorous to ensure that no information from previous summaries is lost in the newly generated summary.]]
|
||||
local entries = Utils.history.filter_active_entries(history.entries)
|
||||
if #entries == 0 then
|
||||
cb(nil)
|
||||
return
|
||||
end
|
||||
if history.memory then
|
||||
system_prompt = system_prompt .. "\n\nThe previous summary is:\n\n" .. history.memory.content
|
||||
entries = vim
|
||||
.iter(entries)
|
||||
:filter(function(entry) return entry.timestamp > history.memory.last_summarized_timestamp end)
|
||||
:totable()
|
||||
end
|
||||
if #entries == 0 then
|
||||
cb(history.memory)
|
||||
return
|
||||
end
|
||||
local history_messages = Utils.history.entries_to_llm_messages(entries)
|
||||
history_messages = vim.list_slice(history_messages, 1, 4)
|
||||
if #history_messages == 0 then
|
||||
cb(history.memory)
|
||||
return
|
||||
end
|
||||
Utils.debug("summarize memory", #history_messages, history_messages[#history_messages].content)
|
||||
local response_content = ""
|
||||
local provider = Providers[Config.memory_summary_provider or Config.provider]
|
||||
M.curl({
|
||||
provider = provider,
|
||||
prompt_opts = {
|
||||
system_prompt = system_prompt,
|
||||
messages = {
|
||||
{ role = "user", content = vim.json.encode(history_messages) },
|
||||
},
|
||||
},
|
||||
handler_opts = {
|
||||
on_start = function(_) end,
|
||||
on_chunk = function(chunk)
|
||||
if not chunk then return end
|
||||
response_content = response_content .. chunk
|
||||
end,
|
||||
on_stop = function(stop_opts)
|
||||
if stop_opts.error ~= nil then
|
||||
Utils.error(string.format("summarize failed: %s", vim.inspect(stop_opts.error)))
|
||||
return
|
||||
end
|
||||
if stop_opts.reason == "complete" then
|
||||
response_content = Utils.trim_think_content(response_content)
|
||||
local memory = {
|
||||
content = response_content,
|
||||
last_summarized_timestamp = entries[#entries].timestamp,
|
||||
}
|
||||
history.memory = memory
|
||||
Path.history.save(bufnr, history)
|
||||
cb(memory)
|
||||
end
|
||||
end,
|
||||
},
|
||||
})
|
||||
end
|
||||
|
||||
---@param opts AvanteGeneratePromptsOptions
|
||||
---@return AvantePromptOptions
|
||||
function M.generate_prompts(opts)
|
||||
@@ -58,6 +123,7 @@ function M.generate_prompts(opts)
|
||||
diagnostics = opts.diagnostics,
|
||||
system_info = system_info,
|
||||
model_name = provider.model or "unknown",
|
||||
memory = opts.memory,
|
||||
}
|
||||
|
||||
local system_prompt = Path.prompts.render_mode(mode, template_opts)
|
||||
@@ -80,6 +146,11 @@ function M.generate_prompts(opts)
|
||||
if code_context ~= "" then table.insert(messages, { role = "user", content = code_context }) end
|
||||
end
|
||||
|
||||
if opts.memory ~= nil and opts.memory ~= "" and opts.memory ~= "null" then
|
||||
local memory = Path.prompts.render_file("_memory.avanterules", template_opts)
|
||||
if memory ~= "" then table.insert(messages, { role = "user", content = memory }) end
|
||||
end
|
||||
|
||||
if instructions then
|
||||
if opts.use_xml_format then
|
||||
table.insert(messages, { role = "user", content = string.format("<question>%s</question>", instructions) })
|
||||
@@ -150,90 +221,17 @@ function M.calculate_tokens(opts)
|
||||
return tokens
|
||||
end
|
||||
|
||||
---@param opts AvanteLLMStreamOptions
|
||||
function M._stream(opts)
|
||||
local provider = opts.provider or Providers[Config.provider]
|
||||
---@param opts avante.CurlOpts
|
||||
function M.curl(opts)
|
||||
local provider = opts.provider
|
||||
local prompt_opts = opts.prompt_opts
|
||||
local handler_opts = opts.handler_opts
|
||||
|
||||
---@cast provider AvanteProviderFunctor
|
||||
|
||||
local prompt_opts = M.generate_prompts(opts)
|
||||
---@type AvanteCurlOutput
|
||||
local spec = provider:parse_curl_args(prompt_opts)
|
||||
|
||||
---@type string
|
||||
local current_event_state = nil
|
||||
|
||||
---@type AvanteHandlerOptions
|
||||
local handler_opts = {
|
||||
on_start = opts.on_start,
|
||||
on_chunk = opts.on_chunk,
|
||||
on_stop = function(stop_opts)
|
||||
---@param tool_use_list AvanteLLMToolUse[]
|
||||
---@param tool_use_index integer
|
||||
---@param tool_histories AvanteLLMToolHistory[]
|
||||
local function handle_next_tool_use(tool_use_list, tool_use_index, tool_histories)
|
||||
if tool_use_index > #tool_use_list then
|
||||
local new_opts = vim.tbl_deep_extend("force", opts, {
|
||||
tool_histories = tool_histories,
|
||||
})
|
||||
return M._stream(new_opts)
|
||||
end
|
||||
local tool_use = tool_use_list[tool_use_index]
|
||||
---@param result string | nil
|
||||
---@param error string | nil
|
||||
local function handle_tool_result(result, error)
|
||||
local tool_result = {
|
||||
tool_use_id = tool_use.id,
|
||||
content = error ~= nil and error or result,
|
||||
is_error = error ~= nil,
|
||||
}
|
||||
table.insert(tool_histories, { tool_result = tool_result, tool_use = tool_use })
|
||||
return handle_next_tool_use(tool_use_list, tool_use_index + 1, tool_histories)
|
||||
end
|
||||
-- Either on_complete handles the tool result asynchronously or we receive the result and error synchronously when either is not nil
|
||||
local result, error = LLMTools.process_tool_use(opts.tools, tool_use, opts.on_tool_log, handle_tool_result)
|
||||
if result ~= nil or error ~= nil then return handle_tool_result(result, error) end
|
||||
end
|
||||
if stop_opts.reason == "tool_use" and stop_opts.tool_use_list then
|
||||
local old_tool_histories = vim.deepcopy(opts.tool_histories) or {}
|
||||
local sorted_tool_use_list = {} ---@type AvanteLLMToolUse[]
|
||||
for _, tool_use in vim.spairs(stop_opts.tool_use_list) do
|
||||
table.insert(sorted_tool_use_list, tool_use)
|
||||
end
|
||||
return handle_next_tool_use(sorted_tool_use_list, 1, old_tool_histories)
|
||||
end
|
||||
if stop_opts.reason == "rate_limit" then
|
||||
local msg = "Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds ..."
|
||||
opts.on_chunk("\n*[" .. msg .. "]*\n")
|
||||
local timer = vim.loop.new_timer()
|
||||
if timer then
|
||||
local retry_after = stop_opts.retry_after
|
||||
local function countdown()
|
||||
timer:start(
|
||||
1000,
|
||||
0,
|
||||
vim.schedule_wrap(function()
|
||||
if retry_after > 0 then retry_after = retry_after - 1 end
|
||||
local msg_ = "Rate limit reached. Retrying in " .. retry_after .. " seconds ..."
|
||||
opts.on_chunk([[\033[1A\033[K]] .. "\n*[" .. msg_ .. "]*\n")
|
||||
countdown()
|
||||
end)
|
||||
)
|
||||
end
|
||||
countdown()
|
||||
end
|
||||
Utils.info("Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds", { title = "Avante" })
|
||||
vim.defer_fn(function()
|
||||
if timer then timer:stop() end
|
||||
M._stream(opts)
|
||||
end, stop_opts.retry_after * 1000)
|
||||
return
|
||||
end
|
||||
return opts.on_stop(stop_opts)
|
||||
end,
|
||||
}
|
||||
|
||||
---@type AvanteCurlOutput
|
||||
local spec = provider:parse_curl_args(provider, prompt_opts)
|
||||
|
||||
local resp_ctx = {}
|
||||
|
||||
---@param line string
|
||||
@@ -383,6 +381,91 @@ function M._stream(opts)
|
||||
return active_job
|
||||
end
|
||||
|
||||
---@param opts AvanteLLMStreamOptions
|
||||
function M._stream(opts)
|
||||
local provider = opts.provider or Providers[Config.provider]
|
||||
|
||||
---@cast provider AvanteProviderFunctor
|
||||
|
||||
local prompt_opts = M.generate_prompts(opts)
|
||||
|
||||
---@type AvanteHandlerOptions
|
||||
local handler_opts = {
|
||||
on_start = opts.on_start,
|
||||
on_chunk = opts.on_chunk,
|
||||
on_stop = function(stop_opts)
|
||||
---@param tool_use_list AvanteLLMToolUse[]
|
||||
---@param tool_use_index integer
|
||||
---@param tool_histories AvanteLLMToolHistory[]
|
||||
local function handle_next_tool_use(tool_use_list, tool_use_index, tool_histories)
|
||||
if tool_use_index > #tool_use_list then
|
||||
local new_opts = vim.tbl_deep_extend("force", opts, {
|
||||
tool_histories = tool_histories,
|
||||
})
|
||||
return M._stream(new_opts)
|
||||
end
|
||||
local tool_use = tool_use_list[tool_use_index]
|
||||
---@param result string | nil
|
||||
---@param error string | nil
|
||||
local function handle_tool_result(result, error)
|
||||
local tool_result = {
|
||||
tool_use_id = tool_use.id,
|
||||
content = error ~= nil and error or result,
|
||||
is_error = error ~= nil,
|
||||
}
|
||||
table.insert(tool_histories, { tool_result = tool_result, tool_use = tool_use })
|
||||
return handle_next_tool_use(tool_use_list, tool_use_index + 1, tool_histories)
|
||||
end
|
||||
-- Either on_complete handles the tool result asynchronously or we receive the result and error synchronously when either is not nil
|
||||
local result, error = LLMTools.process_tool_use(opts.tools, tool_use, opts.on_tool_log, handle_tool_result)
|
||||
if result ~= nil or error ~= nil then return handle_tool_result(result, error) end
|
||||
end
|
||||
if stop_opts.reason == "tool_use" and stop_opts.tool_use_list then
|
||||
local old_tool_histories = vim.deepcopy(opts.tool_histories) or {}
|
||||
local sorted_tool_use_list = {} ---@type AvanteLLMToolUse[]
|
||||
for _, tool_use in vim.spairs(stop_opts.tool_use_list) do
|
||||
table.insert(sorted_tool_use_list, tool_use)
|
||||
end
|
||||
return handle_next_tool_use(sorted_tool_use_list, 1, old_tool_histories)
|
||||
end
|
||||
if stop_opts.reason == "rate_limit" then
|
||||
local msg = "Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds ..."
|
||||
opts.on_chunk("\n*[" .. msg .. "]*\n")
|
||||
local timer = vim.loop.new_timer()
|
||||
if timer then
|
||||
local retry_after = stop_opts.retry_after
|
||||
local function countdown()
|
||||
timer:start(
|
||||
1000,
|
||||
0,
|
||||
vim.schedule_wrap(function()
|
||||
if retry_after > 0 then retry_after = retry_after - 1 end
|
||||
local msg_ = "Rate limit reached. Retrying in " .. retry_after .. " seconds ..."
|
||||
opts.on_chunk([[\033[1A\033[K]] .. "\n*[" .. msg_ .. "]*\n")
|
||||
countdown()
|
||||
end)
|
||||
)
|
||||
end
|
||||
countdown()
|
||||
end
|
||||
Utils.info("Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds", { title = "Avante" })
|
||||
vim.defer_fn(function()
|
||||
if timer then timer:stop() end
|
||||
M._stream(opts)
|
||||
end, stop_opts.retry_after * 1000)
|
||||
return
|
||||
end
|
||||
return opts.on_stop(stop_opts)
|
||||
end,
|
||||
}
|
||||
|
||||
return M.curl({
|
||||
provider = provider,
|
||||
prompt_opts = prompt_opts,
|
||||
handler_opts = handler_opts,
|
||||
})
|
||||
end
|
||||
|
||||
local function _merge_response(first_response, second_response, opts)
|
||||
local prompt = "\n" .. Config.dual_boost.prompt
|
||||
prompt = prompt
|
||||
|
||||
Reference in New Issue
Block a user