refactor: history messages (#1934)

This commit is contained in:
yetone
2025-04-30 03:07:18 +08:00
committed by GitHub
parent f9aa75459d
commit f10b8383e3
36 changed files with 1699 additions and 1462 deletions

View File

@@ -10,6 +10,7 @@ local Path = require("avante.path")
local Providers = require("avante.providers")
local LLMToolHelpers = require("avante.llm_tools.helpers")
local LLMTools = require("avante.llm_tools")
local HistoryMessage = require("avante.history_message")
---@class avante.LLM
local M = {}
@@ -26,7 +27,7 @@ function M.summarize_chat_thread_title(content, cb)
local system_prompt =
[[Summarize the content as a title for the chat thread. The title should be a concise and informative summary of the conversation, capturing the main points and key takeaways. It should be no longer than 100 words and should be written in a clear and engaging style. The title should be suitable for use as the title of a chat thread on a messaging platform or other communication medium.]]
local response_content = ""
local provider = Providers[Config.memory_summary_provider or Config.provider]
local provider = Providers.get_memory_summary_provider()
M.curl({
provider = provider,
prompt_opts = {
@@ -58,73 +59,49 @@ function M.summarize_chat_thread_title(content, cb)
})
end
---@param messages AvanteLLMMessage[]
---@return AvanteLLMMessage[]
local function filter_out_tool_use_messages(messages)
  -- Produce a copy of `messages` with every `tool_use`/`tool_result` item
  -- stripped out of table-typed content. String content passes through
  -- unchanged; a message whose table content ends up empty is dropped.
  local kept = {}
  for _, message in ipairs(messages) do
    local content = message.content
    if type(content) ~= "table" then
      -- Plain string (or other scalar) content: keep the message as-is.
      kept[#kept + 1] = { role = message.role, content = content }
    else
      local remaining = {}
      for _, item in ipairs(content) do
        local is_tool_item = item.type == "tool_use" or item.type == "tool_result"
        if not is_tool_item then remaining[#remaining + 1] = item end
      end
      -- Only retain the message when something survived the filter.
      if #remaining > 0 then kept[#kept + 1] = { role = message.role, content = remaining } end
    end
  end
  return kept
end
---@param bufnr integer
---@param history avante.ChatHistory
---@param entries? avante.ChatHistoryEntry[]
---@param history_messages avante.HistoryMessage[]
---@param cb fun(memory: avante.ChatMemory | nil): nil
function M.summarize_memory(bufnr, history, entries, cb)
local system_prompt = [[You are a helpful AI assistant tasked with summarizing conversations.]]
if not entries then entries = Utils.history.filter_active_entries(history.entries) end
if #entries == 0 then
cb(nil)
return
end
if history.memory then
entries = vim
.iter(entries)
:filter(function(entry) return entry.timestamp > history.memory.last_summarized_timestamp end)
:totable()
end
if #entries == 0 then
cb(history.memory)
return
end
local history_messages = Utils.history.entries_to_llm_messages(entries)
history_messages = filter_out_tool_use_messages(history_messages)
history_messages = vim.list_slice(history_messages, 1, 4)
function M.summarize_memory(bufnr, history, history_messages, cb)
local system_prompt =
[[You are an expert coding assistant. Your goal is to generate a concise, structured summary of the conversation below that captures all essential information needed to continue development after context replacement. Include tasks performed, code areas modified or reviewed, key decisions or assumptions, test results or errors, and outstanding tasks or next steps.]]
if #history_messages == 0 then
cb(history.memory)
return
end
Utils.debug("summarize memory", #history_messages, history_messages[#history_messages].content)
local user_prompt =
[[Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next.]]
local latest_timestamp = history_messages[#history_messages].timestamp
local latest_message_uuid = history_messages[#history_messages].uuid
local conversation_items = vim
.iter(history_messages)
:filter(function(msg)
if msg.just_for_display then return false end
if msg.message.role ~= "assistant" and msg.message.role ~= "user" then return false end
local content = msg.message.content
if type(content) == "table" and content[1].type == "tool_result" then return false end
if type(content) == "table" and content[1].type == "tool_use" then return false end
return true
end)
:map(function(msg) return msg.message.role .. ": " .. Utils.message_to_text(msg) end)
:totable()
local conversation_text = table.concat(conversation_items, "\n")
local user_prompt = "Here is the conversation so far:\n"
.. conversation_text
.. "\n\nPlease summarize this conversation, covering:\n1. Tasks performed and outcomes\n2. Code files, modules, or functions modified or examined\n3. Important decisions or assumptions made\n4. Errors encountered and test or build results\n5. Remaining tasks, open questions, or next steps\nProvide the summary in a clear, concise format."
if history.memory then user_prompt = user_prompt .. "\n\nThe previous summary is:\n\n" .. history.memory.content end
table.insert(history_messages, {
role = "user",
content = user_prompt,
})
local messages = {
{
role = "user",
content = user_prompt,
},
}
local response_content = ""
local provider = Providers[Config.memory_summary_provider or Config.provider]
local provider = Providers.get_memory_summary_provider()
M.curl({
provider = provider,
prompt_opts = {
system_prompt = system_prompt,
messages = history_messages,
messages = messages,
},
handler_opts = {
on_start = function(_) end,
@@ -141,11 +118,14 @@ function M.summarize_memory(bufnr, history, entries, cb)
response_content = Utils.trim_think_content(response_content)
local memory = {
content = response_content,
last_summarized_timestamp = entries[#entries].timestamp,
last_summarized_timestamp = latest_timestamp,
last_message_uuid = latest_message_uuid,
}
history.memory = memory
Path.history.save(bufnr, history)
cb(memory)
else
cb(history.memory)
end
end,
},
@@ -156,7 +136,7 @@ end
---@return AvantePromptOptions
function M.generate_prompts(opts)
local provider = opts.provider or Providers[Config.provider]
local mode = opts.mode or "planning"
local mode = opts.mode or Config.mode
---@type AvanteProviderFunctor | AvanteBedrockProviderFunctor
local _, request_body = Providers.parse_config(provider)
local max_tokens = request_body.max_tokens or 4096
@@ -166,22 +146,58 @@ function M.generate_prompts(opts)
if opts.prompt_opts and opts.prompt_opts.image_paths then
image_paths = vim.list_extend(image_paths, opts.prompt_opts.image_paths)
end
local instructions = opts.instructions
if instructions and instructions:match("image: ") then
local lines = vim.split(opts.instructions, "\n")
for i, line in ipairs(lines) do
if line:match("^image: ") then
local image_path = line:gsub("^image: ", "")
table.insert(image_paths, image_path)
table.remove(lines, i)
end
end
instructions = table.concat(lines, "\n")
end
local project_root = Utils.root.get()
Path.prompts.initialize(Path.prompts.get_templates_dir(project_root))
local tool_id_to_tool_name = {}
local tool_id_to_path = {}
local viewed_files = {}
if opts.history_messages then
for _, message in ipairs(opts.history_messages) do
local content = message.message.content
if type(content) ~= "table" then goto continue end
for _, item in ipairs(content) do
if type(item) ~= "table" then goto continue1 end
if item.type ~= "tool_use" then goto continue1 end
local tool_name = item.name
if tool_name ~= "view" then goto continue1 end
local path = item.input.path
tool_id_to_tool_name[item.id] = tool_name
if path then
local uniform_path = Utils.uniform_path(path)
tool_id_to_path[item.id] = uniform_path
viewed_files[uniform_path] = item.id
end
::continue1::
end
::continue::
end
for _, message in ipairs(opts.history_messages) do
local content = message.message.content
if type(content) == "table" then
for _, item in ipairs(content) do
if type(item) ~= "table" then goto continue end
if item.type ~= "tool_result" then goto continue end
local tool_name = tool_id_to_tool_name[item.tool_use_id]
if tool_name ~= "view" then goto continue end
local path = tool_id_to_path[item.tool_use_id]
local latest_tool_id = viewed_files[path]
if not latest_tool_id then goto continue end
if latest_tool_id ~= item.tool_use_id then
item.content = string.format("The file %s has been updated. Please use the latest view tool result!", path)
else
local lines, error = Utils.read_file_from_buf_or_disk(path)
if error ~= nil then Utils.error("error reading file: " .. error) end
lines = lines or {}
item.content = table.concat(lines, "\n")
end
::continue::
end
end
end
end
local system_info = Utils.get_system_info()
local selected_files = opts.selected_files or {}
@@ -200,6 +216,8 @@ function M.generate_prompts(opts)
end
end
selected_files = vim.iter(selected_files):filter(function(file) return viewed_files[file.path] == nil end):totable()
local template_opts = {
ask = opts.ask, -- TODO: add mode without ask instruction
code_lang = opts.code_lang,
@@ -229,36 +247,42 @@ function M.generate_prompts(opts)
end
---@type AvanteLLMMessage[]
local messages = {}
local context_messages = {}
if opts.prompt_opts and opts.prompt_opts.messages then
messages = vim.list_extend(messages, opts.prompt_opts.messages)
context_messages = vim.list_extend(context_messages, opts.prompt_opts.messages)
end
if opts.project_context ~= nil and opts.project_context ~= "" and opts.project_context ~= "null" then
local project_context = Path.prompts.render_file("_project.avanterules", template_opts)
if project_context ~= "" then table.insert(messages, { role = "user", content = project_context }) end
if project_context ~= "" then
table.insert(context_messages, { role = "user", content = project_context, visible = false, is_context = true })
end
end
if opts.diagnostics ~= nil and opts.diagnostics ~= "" and opts.diagnostics ~= "null" then
local diagnostics = Path.prompts.render_file("_diagnostics.avanterules", template_opts)
if diagnostics ~= "" then table.insert(messages, { role = "user", content = diagnostics }) end
if diagnostics ~= "" then
table.insert(context_messages, { role = "user", content = diagnostics, visible = false, is_context = true })
end
end
if #selected_files > 0 or opts.selected_code ~= nil then
local code_context = Path.prompts.render_file("_context.avanterules", template_opts)
if code_context ~= "" then table.insert(messages, { role = "user", content = code_context }) end
if code_context ~= "" then
table.insert(context_messages, { role = "user", content = code_context, visible = false, is_context = true })
end
end
if opts.memory ~= nil and opts.memory ~= "" and opts.memory ~= "null" then
local memory = Path.prompts.render_file("_memory.avanterules", template_opts)
if memory ~= "" then table.insert(messages, { role = "user", content = memory }) end
if memory ~= "" then
table.insert(context_messages, { role = "user", content = memory, visible = false, is_context = true })
end
end
if instructions then table.insert(messages, { role = "user", content = instructions }) end
local remaining_tokens = max_tokens - Utils.tokens.calculate_tokens(system_prompt)
for _, message in ipairs(messages) do
for _, message in ipairs(context_messages) do
remaining_tokens = remaining_tokens - Utils.tokens.calculate_tokens(message.content)
end
@@ -267,47 +291,49 @@ function M.generate_prompts(opts)
dropped_history_messages = vim.list_extend(dropped_history_messages, opts.prompt_opts.dropped_history_messages)
end
local final_history_messages = {}
if opts.history_messages then
if Config.history.max_tokens > 0 then remaining_tokens = math.min(Config.history.max_tokens, remaining_tokens) end
-- Traverse the history in reverse, keeping only the latest history until the remaining tokens are exhausted and the first message role is "user"
local history_messages = {}
for i = #opts.history_messages, 1, -1 do
local message = opts.history_messages[i]
if Config.history.carried_entry_count ~= nil then
if #history_messages > Config.history.carried_entry_count then break end
table.insert(history_messages, message)
local tokens = Utils.tokens.calculate_tokens(message.message.content)
remaining_tokens = remaining_tokens - tokens
if remaining_tokens > 0 then
table.insert(history_messages, 1, message)
else
local tokens = Utils.tokens.calculate_tokens(message.content)
remaining_tokens = remaining_tokens - tokens
if remaining_tokens > 0 then
table.insert(history_messages, message)
else
break
end
break
end
end
if #history_messages == 0 then
history_messages = vim.list_slice(opts.history_messages, #opts.history_messages - 1, #opts.history_messages)
end
dropped_history_messages = vim.list_slice(opts.history_messages, 1, #opts.history_messages - #history_messages)
-- prepend the history messages to the messages table
vim.iter(history_messages):each(function(msg) table.insert(messages, 1, msg) end)
if #messages > 0 and messages[1].role == "assistant" then table.remove(messages, 1) end
vim.iter(history_messages):each(function(msg) table.insert(final_history_messages, msg) end)
end
if opts.mode == "cursor-applying" then
local user_prompt = [[
Merge all changes from the <update> snippet into the <code> below.
- Preserve the code's structure, order, comments, and indentation exactly.
- Output only the updated code, enclosed within <updated-code> and </updated-code> tags.
- Do not include any additional text, explanations, placeholders, ellipses, or code fences.
-- Utils.debug("opts.history_messages", opts.history_messages)
-- Utils.debug("final_history_messages", final_history_messages)
]]
user_prompt = user_prompt .. string.format("<code>\n%s\n</code>\n", opts.original_code)
for _, snippet in ipairs(opts.update_snippets) do
user_prompt = user_prompt .. string.format("<update>\n%s\n</update>\n", snippet)
end
user_prompt = user_prompt .. "Provide the complete updated code."
table.insert(messages, { role = "user", content = user_prompt })
---@type AvanteLLMMessage[]
local messages = vim.deepcopy(context_messages)
for _, msg in ipairs(final_history_messages) do
local message = msg.message
table.insert(messages, message)
end
messages = vim
.iter(messages)
:filter(function(msg) return type(msg.content) ~= "string" or msg.content ~= "" end)
:totable()
if opts.instructions ~= nil and opts.instructions ~= "" then
messages = vim.list_extend(messages, { { role = "user", content = opts.instructions } })
end
opts.session_ctx = opts.session_ctx or {}
@@ -318,19 +344,12 @@ Merge all changes from the <update> snippet into the <code> below.
if opts.tools then tools = vim.list_extend(tools, opts.tools) end
if opts.prompt_opts and opts.prompt_opts.tools then tools = vim.list_extend(tools, opts.prompt_opts.tools) end
local tool_histories = {}
if opts.tool_histories then tool_histories = vim.list_extend(tool_histories, opts.tool_histories) end
if opts.prompt_opts and opts.prompt_opts.tool_histories then
tool_histories = vim.list_extend(tool_histories, opts.prompt_opts.tool_histories)
end
---@type AvantePromptOptions
return {
system_prompt = system_prompt,
messages = messages,
image_paths = image_paths,
tools = tools,
tool_histories = tool_histories,
dropped_history_messages = dropped_history_messages,
}
end
@@ -372,7 +391,9 @@ function M.curl(opts)
---@type string
local current_event_state = nil
local resp_ctx = {}
resp_ctx.session_id = Utils.uuid()
local response_body = ""
---@param line string
local function parse_stream_data(line)
local event = line:match("^event:%s*(.+)$")
@@ -381,7 +402,16 @@ function M.curl(opts)
return
end
local data_match = line:match("^data:%s*(.+)$")
if data_match then provider:parse_response(resp_ctx, data_match, current_event_state, handler_opts) end
if data_match then
provider:parse_response(resp_ctx, data_match, current_event_state, handler_opts)
else
response_body = response_body .. line
local ok, jsn = pcall(vim.json.decode, response_body)
if ok then
response_body = ""
if jsn.error then handler_opts.on_stop({ reason = "error", error = jsn.error }) end
end
end
end
local function parse_response_without_stream(data)
@@ -394,10 +424,13 @@ function M.curl(opts)
local temp_file = fn.tempname()
local curl_body_file = temp_file .. "-request-body.json"
local resp_body_file = temp_file .. "-response-body.json"
local json_content = vim.json.encode(spec.body)
fn.writefile(vim.split(json_content, "\n"), curl_body_file)
Utils.debug("curl body file:", curl_body_file)
Utils.debug("curl request body file:", curl_body_file)
Utils.debug("curl response body file:", resp_body_file)
local headers_file = temp_file .. "-headers.txt"
@@ -407,6 +440,7 @@ function M.curl(opts)
if Config.debug then return end
vim.schedule(function()
fn.delete(curl_body_file)
pcall(fn.delete, resp_body_file)
fn.delete(headers_file)
end)
end
@@ -431,6 +465,15 @@ function M.curl(opts)
return
end
if not data then return end
if Config.debug then
if type(data) == "string" then
local file = io.open(resp_body_file, "a")
if file then
file:write(data .. "\n")
file:close()
end
end
end
vim.schedule(function()
if Config[Config.provider] == nil and provider.parse_stream_data ~= nil then
if provider.parse_response ~= nil then
@@ -495,6 +538,17 @@ function M.curl(opts)
Utils.error("API request failed with status " .. result.status, { once = true, title = "Avante" })
end
if result.status == 429 then
Utils.debug("result", result)
if result.body then
local ok, jsn = pcall(vim.json.decode, result.body)
if ok then
if jsn.error and jsn.error.message then
handler_opts.on_stop({ reason = "error", error = jsn.error.message })
return
end
end
end
Utils.debug("result", result)
local retry_after = 10
if headers_map["retry-after"] then retry_after = tonumber(headers_map["retry-after"]) or 10 end
handler_opts.on_stop({ reason = "rate_limit", retry_after = retry_after })
@@ -585,17 +639,34 @@ function M._stream(opts)
---@type AvanteHandlerOptions
local handler_opts = {
on_partial_tool_use = opts.on_partial_tool_use,
on_messages_add = opts.on_messages_add,
on_state_change = opts.on_state_change,
on_start = opts.on_start,
on_chunk = opts.on_chunk,
on_stop = function(stop_opts)
---@param tool_use_list AvanteLLMToolUse[]
---@param tool_use_index integer
---@param tool_histories AvanteLLMToolHistory[]
local function handle_next_tool_use(tool_use_list, tool_use_index, tool_histories)
---@param tool_results AvanteLLMToolResult[]
local function handle_next_tool_use(tool_use_list, tool_use_index, tool_results)
if tool_use_index > #tool_use_list then
---@type avante.HistoryMessage[]
local messages = {}
for _, tool_result in ipairs(tool_results) do
messages[#messages + 1] = HistoryMessage:new({
role = "user",
content = {
{
type = "tool_result",
tool_use_id = tool_result.tool_use_id,
content = tool_result.content,
is_error = tool_result.is_error,
},
},
})
end
opts.on_messages_add(messages)
local new_opts = vim.tbl_deep_extend("force", opts, {
tool_histories = tool_histories,
history_messages = opts.get_history_messages(),
})
if provider.get_rate_limit_sleep_time then
local sleep_time = provider:get_rate_limit_sleep_time(resp_headers)
@@ -616,7 +687,7 @@ function M._stream(opts)
if error == LLMToolHelpers.CANCEL_TOKEN then
Utils.debug("Tool execution was cancelled by user")
opts.on_chunk("\n*[Request cancelled by user during tool execution.]*\n")
return opts.on_stop({ reason = "cancelled", tool_histories = tool_histories })
return opts.on_stop({ reason = "cancelled" })
end
local tool_result = {
@@ -624,8 +695,8 @@ function M._stream(opts)
content = error ~= nil and error or result,
is_error = error ~= nil,
}
table.insert(tool_histories, { tool_result = tool_result, tool_use = tool_use })
return handle_next_tool_use(tool_use_list, tool_use_index + 1, tool_histories)
table.insert(tool_results, tool_result)
return handle_next_tool_use(tool_use_list, tool_use_index + 1, tool_results)
end
-- Either on_complete handles the tool result asynchronously or we receive the result and error synchronously when either is not nil
local result, error = LLMTools.process_tool_use(
@@ -638,20 +709,53 @@ function M._stream(opts)
if result ~= nil or error ~= nil then return handle_tool_result(result, error) end
end
if stop_opts.reason == "cancelled" then
opts.on_chunk("\n*[Request cancelled by user.]*\n")
return opts.on_stop({ reason = "cancelled", tool_histories = opts.tool_histories })
if opts.on_chunk then opts.on_chunk("\n*[Request cancelled by user.]*\n") end
if opts.on_messages_add then
local message = HistoryMessage:new({
role = "user",
content = "[Request cancelled by user.]",
})
opts.on_messages_add({ message })
end
return opts.on_stop({ reason = "cancelled" })
end
if stop_opts.reason == "tool_use" and stop_opts.tool_use_list then
local old_tool_histories = vim.deepcopy(opts.tool_histories) or {}
if stop_opts.reason == "tool_use" then
local tool_use_list = {} ---@type AvanteLLMToolUse[]
local tool_result_seen = {}
local history_messages = opts.get_history_messages and opts.get_history_messages() or {}
for idx = #history_messages, 1, -1 do
local message = history_messages[idx]
local content = message.message.content
if type(content) ~= "table" or #content == 0 then goto continue end
if content[1].type == "tool_use" then
if not tool_result_seen[content[1].id] then
table.insert(tool_use_list, 1, content[1])
else
break
end
end
if content[1].type == "tool_result" then tool_result_seen[content[1].tool_use_id] = true end
::continue::
end
local sorted_tool_use_list = {} ---@type AvanteLLMToolUse[]
for _, tool_use in vim.spairs(stop_opts.tool_use_list) do
for _, tool_use in vim.spairs(tool_use_list) do
table.insert(sorted_tool_use_list, tool_use)
end
return handle_next_tool_use(sorted_tool_use_list, 1, old_tool_histories)
return handle_next_tool_use(sorted_tool_use_list, 1, {})
end
if stop_opts.reason == "rate_limit" then
local msg = "Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds ..."
opts.on_chunk("\n*[" .. msg .. "]*\n")
local msg_content = "*[Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds ...]*"
if opts.on_chunk then opts.on_chunk("\n" .. msg_content .. "\n") end
local message
if opts.on_messages_add then
message = HistoryMessage:new({
role = "assistant",
content = "\n\n" .. msg_content,
}, {
just_for_display = true,
})
opts.on_messages_add({ message })
end
local timer = vim.loop.new_timer()
if timer then
local retry_after = stop_opts.retry_after
@@ -661,8 +765,12 @@ function M._stream(opts)
0,
vim.schedule_wrap(function()
if retry_after > 0 then retry_after = retry_after - 1 end
local msg_ = "Rate limit reached. Retrying in " .. retry_after .. " seconds ..."
opts.on_chunk([[\033[1A\033[K]] .. "\n*[" .. msg_ .. "]*\n")
local msg_content_ = "*[Rate limit reached. Retrying in " .. retry_after .. " seconds ...]*"
if opts.on_chunk then opts.on_chunk([[\033[1A\033[K]] .. "\n" .. msg_content_ .. "\n") end
if opts.on_messages_add and message then
message.message.content = "\n\n" .. msg_content_
opts.on_messages_add({ message })
end
countdown()
end)
)
@@ -676,7 +784,6 @@ function M._stream(opts)
end, stop_opts.retry_after * 1000)
return
end
stop_opts.tool_histories = opts.tool_histories
return opts.on_stop(stop_opts)
end,
}
@@ -697,6 +804,8 @@ local function _merge_response(first_response, second_response, opts)
prompt = prompt .. "\n"
if opts.instructions == nil then opts.instructions = "" end
-- append this reference prompt to the prompt_opts messages at last
opts.instructions = opts.instructions .. prompt
@@ -802,20 +911,12 @@ function M.stream(opts)
return original_on_stop(stop_opts)
end)
end
if opts.on_partial_tool_use ~= nil then
local original_on_partial_tool_use = opts.on_partial_tool_use
opts.on_partial_tool_use = vim.schedule_wrap(function(tool_use)
if is_completed then return end
return original_on_partial_tool_use(tool_use)
end)
end
local valid_dual_boost_modes = {
planning = true,
["cursor-planning"] = true,
legacy = true,
}
opts.mode = opts.mode or "planning"
opts.mode = opts.mode or Config.mode
if Config.dual_boost.enabled and valid_dual_boost_modes[opts.mode] then
M._dual_boost_stream(