refactor(history): start moving history-related code into avante/history

The utils module has grown too big and contains unrelated functionality.
Start moving code related to managing history messages comprising chat
history into lua/avante/history module to keep the code more manageable.
This commit is contained in:
Dmitry Torokhov
2025-07-08 15:57:05 -07:00
committed by yetone
parent 85a9fcef95
commit 34907fc1cd
14 changed files with 448 additions and 421 deletions

View File

@@ -0,0 +1,101 @@
local Utils = require("avante.utils")
local M = {}
---@param message avante.HistoryMessage
---@return boolean
function M.is_tool_use_message(message)
local content = message.message.content
if type(content) == "string" then return false end
if vim.islist(content) then
for _, item in ipairs(content) do
if item.type == "tool_use" then return true end
end
end
return false
end
---@param message avante.HistoryMessage
---@return boolean
function M.is_tool_result_message(message)
local content = message.message.content
if type(content) == "string" then return false end
if vim.islist(content) then
for _, item in ipairs(content) do
if item.type == "tool_result" then return true end
end
end
return false
end
---@param message avante.HistoryMessage
---@param messages avante.HistoryMessage[]
---@return avante.HistoryMessage | nil
function M.get_tool_use_message(message, messages)
local content = message.message.content
if type(content) == "string" then return nil end
if vim.islist(content) then
local tool_id = nil
for _, item in ipairs(content) do
if item.type == "tool_result" then
tool_id = item.tool_use_id
break
end
end
if not tool_id then return nil end
for idx_ = #messages, 1, -1 do
local message_ = messages[idx_]
local content_ = message_.message.content
if type(content_) == "table" then
for _, item in ipairs(content_) do
if item.type == "tool_use" and item.id == tool_id then return message_ end
end
end
end
end
return nil
end
---@param tool_use_message avante.HistoryMessage | nil
function M.is_edit_func_call_message(tool_use_message)
local is_replace_func_call = false
local is_str_replace_editor_func_call = false
local is_str_replace_based_edit_tool_func_call = false
local path = nil
if tool_use_message and M.is_tool_use_message(tool_use_message) then
local tool_use = tool_use_message.message.content[1]
---@cast tool_use AvanteLLMToolUse
return Utils.is_edit_func_call_tool_use(tool_use)
end
return is_replace_func_call, is_str_replace_editor_func_call, is_str_replace_based_edit_tool_func_call, path
end
---@param message avante.HistoryMessage
---@param messages avante.HistoryMessage[]
---@return avante.HistoryMessage | nil
function M.get_tool_result_message(message, messages)
local content = message.message.content
if type(content) == "string" then return nil end
if vim.islist(content) then
local tool_id = nil
for _, item in ipairs(content) do
if item.type == "tool_use" then
tool_id = item.id
break
end
end
if not tool_id then return nil end
for idx_ = #messages, 1, -1 do
local message_ = messages[idx_]
local content_ = message_.message.content
if type(content_) == "table" then
for _, item in ipairs(content_) do
if item.type == "tool_result" and item.tool_use_id == tool_id then return message_ end
end
end
end
end
return nil
end
return M

303
lua/avante/history/init.lua Normal file
View File

@@ -0,0 +1,303 @@
local Helpers = require("avante.history.helpers")
local Message = require("avante.history.message")
local Utils = require("avante.utils")
local M = {}
M.Helpers = Helpers
M.Message = Message
---@param history avante.ChatHistory
---@return avante.HistoryMessage[]
function M.get_history_messages(history)
if history.messages then return history.messages end
local messages = {}
for _, entry in ipairs(history.entries or {}) do
if entry.request and entry.request ~= "" then
local message = Message:new({
role = "user",
content = entry.request,
}, {
timestamp = entry.timestamp,
is_user_submission = true,
visible = entry.visible,
selected_filepaths = entry.selected_filepaths,
selected_code = entry.selected_code,
})
table.insert(messages, message)
end
if entry.response and entry.response ~= "" then
local message = Message:new({
role = "assistant",
content = entry.response,
}, {
timestamp = entry.timestamp,
visible = entry.visible,
})
table.insert(messages, message)
end
end
history.messages = messages
return messages
end
---@param messages avante.HistoryMessage[]
---@param using_ReAct_prompt boolean
---@param add_diagnostic boolean Mix in LSP diagnostic info for affected files
---@return avante.HistoryMessage[]
M.update_history_messages = function(messages, using_ReAct_prompt, add_diagnostic)
local tool_id_to_tool_name = {}
local tool_id_to_path = {}
local tool_id_to_start_line = {}
local tool_id_to_end_line = {}
local viewed_files = {}
local last_modified_files = {}
local history_messages = {}
for idx, message in ipairs(messages) do
if Helpers.is_tool_result_message(message) then
local tool_use_message = Helpers.get_tool_use_message(message, messages)
local is_edit_func_call, _, _, path = Helpers.is_edit_func_call_message(tool_use_message)
-- Only track as successful modification if not an error AND not user-declined
if
is_edit_func_call
and path
and not message.message.content[1].is_error
and not message.message.content[1].is_user_declined
then
local uniformed_path = Utils.uniform_path(path)
last_modified_files[uniformed_path] = idx
end
end
end
for idx, message in ipairs(messages) do
table.insert(history_messages, message)
if Helpers.is_tool_result_message(message) then
local tool_use_message = Helpers.get_tool_use_message(message, messages)
local is_edit_func_call, is_str_replace_editor_func_call, is_str_replace_based_edit_tool_func_call, path =
Helpers.is_edit_func_call_message(tool_use_message)
--- For models like gpt-4o, the input parameter of replace_in_file is treated as the latest file content, so here we need to insert a fake view tool call to ensure it uses the latest file content
if is_edit_func_call and path and not message.message.content[1].is_error then
local uniformed_path = Utils.uniform_path(path)
local view_result, view_error = require("avante.llm_tools.view").func({ path = path }, {})
if view_error then view_result = "Error: " .. view_error end
local get_diagnostics_tool_use_id = Utils.uuid()
local view_tool_use_id = Utils.uuid()
local view_tool_name = "view"
local view_tool_input = { path = path }
if is_str_replace_editor_func_call then
view_tool_name = "str_replace_editor"
view_tool_input = { command = "view", path = path }
end
if is_str_replace_based_edit_tool_func_call then
view_tool_name = "str_replace_based_edit_tool"
view_tool_input = { command = "view", path = path }
end
history_messages = vim.list_extend(history_messages, {
Message:new({
role = "assistant",
content = string.format("Viewing file %s to get the latest content", path),
}, {
is_dummy = true,
}),
Message:new({
role = "assistant",
content = {
{
type = "tool_use",
id = view_tool_use_id,
name = view_tool_name,
input = view_tool_input,
},
},
}, {
is_dummy = true,
}),
Message:new({
role = "user",
content = {
{
type = "tool_result",
tool_use_id = view_tool_use_id,
content = view_result,
is_error = view_error ~= nil,
is_user_declined = false,
},
},
}, {
is_dummy = true,
}),
})
if last_modified_files[uniformed_path] == idx and add_diagnostic then
local diagnostics = Utils.lsp.get_diagnostics_from_filepath(path)
history_messages = vim.list_extend(history_messages, {
Message:new({
role = "assistant",
content = string.format(
"The file %s has been modified, let me check if there are any errors in the changes.",
path
),
}, {
is_dummy = true,
}),
Message:new({
role = "assistant",
content = {
{
type = "tool_use",
id = get_diagnostics_tool_use_id,
name = "get_diagnostics",
input = { path = path },
},
},
}, {
is_dummy = true,
}),
Message:new({
role = "user",
content = {
{
type = "tool_result",
tool_use_id = get_diagnostics_tool_use_id,
content = vim.json.encode(diagnostics),
is_error = false,
is_user_declined = false,
},
},
}, {
is_dummy = true,
}),
})
end
end
end
end
for _, message in ipairs(history_messages) do
local content = message.message.content
if type(content) ~= "table" then goto continue end
for _, item in ipairs(content) do
if type(item) ~= "table" then goto continue1 end
if item.type ~= "tool_use" then goto continue1 end
local tool_name = item.name
if tool_name ~= "view" then goto continue1 end
local path = item.input.path
tool_id_to_tool_name[item.id] = tool_name
if path then
local uniform_path = Utils.uniform_path(path)
tool_id_to_path[item.id] = uniform_path
tool_id_to_start_line[item.id] = item.input.start_line
tool_id_to_end_line[item.id] = item.input.end_line
viewed_files[uniform_path] = item.id
end
::continue1::
end
::continue::
end
for _, message in ipairs(history_messages) do
local content = message.message.content
if type(content) == "table" then
for _, item in ipairs(content) do
if type(item) ~= "table" then goto continue end
if item.type ~= "tool_result" then goto continue end
local tool_name = tool_id_to_tool_name[item.tool_use_id]
if tool_name ~= "view" then goto continue end
if item.is_error then goto continue end
local path = tool_id_to_path[item.tool_use_id]
local latest_tool_id = viewed_files[path]
if not latest_tool_id then goto continue end
if latest_tool_id ~= item.tool_use_id then
item.content = string.format("The file %s has been updated. Please use the latest `view` tool result!", path)
else
local start_line = tool_id_to_start_line[item.tool_use_id]
local end_line = tool_id_to_end_line[item.tool_use_id]
local view_result, view_error = require("avante.llm_tools.view").func(
{ path = path, start_line = start_line, end_line = end_line },
{}
)
if view_error then view_result = "Error: " .. view_error end
item.content = view_result
item.is_error = view_error ~= nil
end
::continue::
end
end
end
if not using_ReAct_prompt then
local picked_messages = {}
local max_tool_use_count = 25
local tool_use_count = 0
for idx = #history_messages, 1, -1 do
local msg = history_messages[idx]
if tool_use_count > max_tool_use_count then
if Helpers.is_tool_result_message(msg) then
local tool_use_message = Helpers.get_tool_use_message(msg, history_messages)
if tool_use_message then
table.insert(
picked_messages,
1,
Message:new({
role = "user",
content = {
{
type = "text",
text = string.format(
"Tool use [%s] is successful: %s",
tool_use_message.message.content[1].name,
tostring(not msg.message.content[1].is_error)
),
},
},
}, { is_dummy = true })
)
local msg_content = {}
table.insert(msg_content, {
type = "text",
text = string.format(
"Tool use %s(%s)",
tool_use_message.message.content[1].name,
vim.json.encode(tool_use_message.message.content[1].input)
),
})
table.insert(
picked_messages,
1,
Message:new({ role = "assistant", content = msg_content }, { is_dummy = true })
)
end
elseif Helpers.is_tool_use_message(msg) then
tool_use_count = tool_use_count + 1
goto continue
else
table.insert(picked_messages, 1, msg)
end
else
if Helpers.is_tool_use_message(msg) then tool_use_count = tool_use_count + 1 end
table.insert(picked_messages, 1, msg)
end
::continue::
end
history_messages = picked_messages
end
local final_history_messages = {}
for _, msg in ipairs(history_messages) do
local tool_result_message
if Helpers.is_tool_use_message(msg) then
tool_result_message = Helpers.get_tool_result_message(msg, history_messages)
if not tool_result_message then goto continue end
end
if Helpers.is_tool_result_message(msg) then goto continue end
table.insert(final_history_messages, msg)
if tool_result_message then table.insert(final_history_messages, tool_result_message) end
::continue::
end
return final_history_messages
end
return M

View File

@@ -0,0 +1,31 @@
local Utils = require("avante.utils")
---@class avante.HistoryMessage
local M = {}
M.__index = M
---@param message AvanteLLMMessage
---@param opts? {is_user_submission?: boolean, visible?: boolean, displayed_content?: string, state?: avante.HistoryMessageState, uuid?: string, selected_filepaths?: string[], selected_code?: AvanteSelectedCode, just_for_display?: boolean, is_dummy?: boolean, turn_id?: string, is_calling?: boolean}
---@return avante.HistoryMessage
function M:new(message, opts)
opts = opts or {}
local obj = setmetatable({}, M)
obj.message = message
obj.uuid = opts.uuid or Utils.uuid()
obj.state = opts.state or "generated"
obj.timestamp = Utils.get_timestamp()
obj.is_user_submission = false
obj.visible = true
if opts.is_user_submission ~= nil then obj.is_user_submission = opts.is_user_submission end
if opts.visible ~= nil then obj.visible = opts.visible end
if opts.displayed_content ~= nil then obj.displayed_content = opts.displayed_content end
if opts.selected_filepaths ~= nil then obj.selected_filepaths = opts.selected_filepaths end
if opts.selected_code ~= nil then obj.selected_code = opts.selected_code end
if opts.just_for_display ~= nil then obj.just_for_display = opts.just_for_display end
if opts.is_dummy ~= nil then obj.is_dummy = opts.is_dummy end
if opts.turn_id ~= nil then obj.turn_id = opts.turn_id end
if opts.is_calling ~= nil then obj.is_calling = opts.is_calling end
return obj
end
return M