From 34907fc1cd6347bd10d908f76ce100db334fbae8 Mon Sep 17 00:00:00 2001 From: Dmitry Torokhov Date: Tue, 8 Jul 2025 15:57:05 -0700 Subject: [PATCH] refactor(history): start moving history-related code into avante/history The utils module has grown too big and contains unrelated functionality. Start moving code related to managing history messages comprising chat history into lua/avante/history module to keep the code more manageable. --- lua/avante/history/helpers.lua | 101 ++++++ lua/avante/history/init.lua | 303 ++++++++++++++++++ .../message.lua} | 0 lua/avante/history_selector.lua | 3 +- lua/avante/llm.lua | 16 +- .../llm_tools/delete_tool_use_messages.lua | 6 +- lua/avante/llm_tools/dispatch_agent.lua | 6 +- lua/avante/path.lua | 5 +- lua/avante/providers/claude.lua | 2 +- lua/avante/providers/ollama.lua | 2 +- lua/avante/providers/openai.lua | 2 +- lua/avante/sidebar.lua | 282 +--------------- lua/avante/suggestion.lua | 2 +- lua/avante/utils/init.lua | 139 +------- 14 files changed, 448 insertions(+), 421 deletions(-) create mode 100644 lua/avante/history/helpers.lua create mode 100644 lua/avante/history/init.lua rename lua/avante/{history_message.lua => history/message.lua} (100%) diff --git a/lua/avante/history/helpers.lua b/lua/avante/history/helpers.lua new file mode 100644 index 0000000..aafd5a8 --- /dev/null +++ b/lua/avante/history/helpers.lua @@ -0,0 +1,101 @@ +local Utils = require("avante.utils") + +local M = {} + +---@param message avante.HistoryMessage +---@return boolean +function M.is_tool_use_message(message) + local content = message.message.content + if type(content) == "string" then return false end + if vim.islist(content) then + for _, item in ipairs(content) do + if item.type == "tool_use" then return true end + end + end + return false +end + +---@param message avante.HistoryMessage +---@return boolean +function M.is_tool_result_message(message) + local content = message.message.content + if type(content) == "string" then return false end + if vim.islist(content) then + for _, item in ipairs(content) do + if item.type == "tool_result" then return true end + end + end + return false +end + +---@param message avante.HistoryMessage +---@param messages avante.HistoryMessage[] +---@return avante.HistoryMessage | nil +function M.get_tool_use_message(message, messages) + local content = message.message.content + if type(content) == "string" then return nil end + if vim.islist(content) then + local tool_id = nil + for _, item in ipairs(content) do + if item.type == "tool_result" then + tool_id = item.tool_use_id + break + end + end + if not tool_id then return nil end + for idx_ = #messages, 1, -1 do + local message_ = messages[idx_] + local content_ = message_.message.content + if type(content_) == "table" then + for _, item in ipairs(content_) do + if item.type == "tool_use" and item.id == tool_id then return message_ end + end + end + end + end + return nil +end + +---@param tool_use_message avante.HistoryMessage | nil +function M.is_edit_func_call_message(tool_use_message) + local is_replace_func_call = false + local is_str_replace_editor_func_call = false + local is_str_replace_based_edit_tool_func_call = false + local path = nil + if tool_use_message and M.is_tool_use_message(tool_use_message) then + local tool_use = tool_use_message.message.content[1] + ---@cast tool_use AvanteLLMToolUse + return Utils.is_edit_func_call_tool_use(tool_use) + end + return is_replace_func_call, is_str_replace_editor_func_call, is_str_replace_based_edit_tool_func_call, path +end + +---@param message avante.HistoryMessage +---@param messages avante.HistoryMessage[] +---@return avante.HistoryMessage | nil +function M.get_tool_result_message(message, messages) + local content = message.message.content + if type(content) == "string" then return nil end + if vim.islist(content) then + local tool_id = nil + for _, item in ipairs(content) do + if item.type == "tool_use" then + tool_id = item.id + break + end + end + if not tool_id then return nil end + for idx_ = #messages, 1, -1 do + local message_ = messages[idx_] + local content_ = message_.message.content + if type(content_) == "table" then + for _, item in ipairs(content_) do + if item.type == "tool_result" and item.tool_use_id == tool_id then return message_ end + end + end + end + end + return nil +end + +return M diff --git a/lua/avante/history/init.lua b/lua/avante/history/init.lua new file mode 100644 index 0000000..e81f8c8 --- /dev/null +++ b/lua/avante/history/init.lua @@ -0,0 +1,303 @@ +local Helpers = require("avante.history.helpers") +local Message = require("avante.history.message") +local Utils = require("avante.utils") + +local M = {} + +M.Helpers = Helpers +M.Message = Message + +---@param history avante.ChatHistory +---@return avante.HistoryMessage[] +function M.get_history_messages(history) + if history.messages then return history.messages end + local messages = {} + for _, entry in ipairs(history.entries or {}) do + if entry.request and entry.request ~= "" then + local message = Message:new({ + role = "user", + content = entry.request, + }, { + timestamp = entry.timestamp, + is_user_submission = true, + visible = entry.visible, + selected_filepaths = entry.selected_filepaths, + selected_code = entry.selected_code, + }) + table.insert(messages, message) + end + if entry.response and entry.response ~= "" then + local message = Message:new({ + role = "assistant", + content = entry.response, + }, { + timestamp = entry.timestamp, + visible = entry.visible, + }) + table.insert(messages, message) + end + end + history.messages = messages + return messages +end + +---@param messages avante.HistoryMessage[] +---@param using_ReAct_prompt boolean +---@param add_diagnostic boolean Mix in LSP diagnostic info for affected files +---@return avante.HistoryMessage[] +M.update_history_messages = function(messages, using_ReAct_prompt, add_diagnostic) + local tool_id_to_tool_name = {} + local tool_id_to_path = {} + local tool_id_to_start_line = {} + local tool_id_to_end_line = {} + local viewed_files = {} + local last_modified_files = {} + local history_messages = {} + + for idx, message in ipairs(messages) do + if Helpers.is_tool_result_message(message) then + local tool_use_message = Helpers.get_tool_use_message(message, messages) + + local is_edit_func_call, _, _, path = Helpers.is_edit_func_call_message(tool_use_message) + + -- Only track as successful modification if not an error AND not user-declined + if + is_edit_func_call + and path + and not message.message.content[1].is_error + and not message.message.content[1].is_user_declined + then + local uniformed_path = Utils.uniform_path(path) + last_modified_files[uniformed_path] = idx + end + end + end + + for idx, message in ipairs(messages) do + table.insert(history_messages, message) + if Helpers.is_tool_result_message(message) then + local tool_use_message = Helpers.get_tool_use_message(message, messages) + local is_edit_func_call, is_str_replace_editor_func_call, is_str_replace_based_edit_tool_func_call, path = + Helpers.is_edit_func_call_message(tool_use_message) + --- For models like gpt-4o, the input parameter of replace_in_file is treated as the latest file content, so here we need to insert a fake view tool call to ensure it uses the latest file content + if is_edit_func_call and path and not message.message.content[1].is_error then + local uniformed_path = Utils.uniform_path(path) + local view_result, view_error = require("avante.llm_tools.view").func({ path = path }, {}) + if view_error then view_result = "Error: " .. view_error end + local get_diagnostics_tool_use_id = Utils.uuid() + local view_tool_use_id = Utils.uuid() + local view_tool_name = "view" + local view_tool_input = { path = path } + if is_str_replace_editor_func_call then + view_tool_name = "str_replace_editor" + view_tool_input = { command = "view", path = path } + end + if is_str_replace_based_edit_tool_func_call then + view_tool_name = "str_replace_based_edit_tool" + view_tool_input = { command = "view", path = path } + end + history_messages = vim.list_extend(history_messages, { + Message:new({ + role = "assistant", + content = string.format("Viewing file %s to get the latest content", path), + }, { + is_dummy = true, + }), + Message:new({ + role = "assistant", + content = { + { + type = "tool_use", + id = view_tool_use_id, + name = view_tool_name, + input = view_tool_input, + }, + }, + }, { + is_dummy = true, + }), + Message:new({ + role = "user", + content = { + { + type = "tool_result", + tool_use_id = view_tool_use_id, + content = view_result, + is_error = view_error ~= nil, + is_user_declined = false, + }, + }, + }, { + is_dummy = true, + }), + }) + if last_modified_files[uniformed_path] == idx and add_diagnostic then + local diagnostics = Utils.lsp.get_diagnostics_from_filepath(path) + history_messages = vim.list_extend(history_messages, { + Message:new({ + role = "assistant", + content = string.format( + "The file %s has been modified, let me check if there are any errors in the changes.", + path + ), + }, { + is_dummy = true, + }), + Message:new({ + role = "assistant", + content = { + { + type = "tool_use", + id = get_diagnostics_tool_use_id, + name = "get_diagnostics", + input = { path = path }, + }, + }, + }, { + is_dummy = true, + }), + Message:new({ + role = "user", + content = { + { + type = "tool_result", + tool_use_id = get_diagnostics_tool_use_id, + content = vim.json.encode(diagnostics), + is_error = false, + is_user_declined = false, + }, + }, + }, { + is_dummy = true, + }), + }) + end + end + end + end + for _, message in ipairs(history_messages) do + local content = message.message.content + if type(content) ~= "table" then goto continue end + for _, item in ipairs(content) do + if type(item) ~= "table" then goto continue1 end + if item.type ~= "tool_use" then goto continue1 end + local tool_name = item.name + if tool_name ~= "view" then goto continue1 end + local path = item.input.path + tool_id_to_tool_name[item.id] = tool_name + if path then + local uniform_path = Utils.uniform_path(path) + tool_id_to_path[item.id] = uniform_path + tool_id_to_start_line[item.id] = item.input.start_line + tool_id_to_end_line[item.id] = item.input.end_line + viewed_files[uniform_path] = item.id + end + ::continue1:: + end + ::continue:: + end + for _, message in ipairs(history_messages) do + local content = message.message.content + if type(content) == "table" then + for _, item in ipairs(content) do + if type(item) ~= "table" then goto continue end + if item.type ~= "tool_result" then goto continue end + local tool_name = tool_id_to_tool_name[item.tool_use_id] + if tool_name ~= "view" then goto continue end + if item.is_error then goto continue end + local path = tool_id_to_path[item.tool_use_id] + local latest_tool_id = viewed_files[path] + if not latest_tool_id then goto continue end + if latest_tool_id ~= item.tool_use_id then + item.content = string.format("The file %s has been updated. Please use the latest `view` tool result!", path) + else + local start_line = tool_id_to_start_line[item.tool_use_id] + local end_line = tool_id_to_end_line[item.tool_use_id] + local view_result, view_error = require("avante.llm_tools.view").func( + { path = path, start_line = start_line, end_line = end_line }, + {} + ) + if view_error then view_result = "Error: " .. view_error end + item.content = view_result + item.is_error = view_error ~= nil + end + ::continue:: + end + end + end + + if not using_ReAct_prompt then + local picked_messages = {} + local max_tool_use_count = 25 + local tool_use_count = 0 + for idx = #history_messages, 1, -1 do + local msg = history_messages[idx] + if tool_use_count > max_tool_use_count then + if Helpers.is_tool_result_message(msg) then + local tool_use_message = Helpers.get_tool_use_message(msg, history_messages) + if tool_use_message then + table.insert( + picked_messages, + 1, + Message:new({ + role = "user", + content = { + { + type = "text", + text = string.format( + "Tool use [%s] is successful: %s", + tool_use_message.message.content[1].name, + tostring(not msg.message.content[1].is_error) + ), + }, + }, + }, { is_dummy = true }) + ) + local msg_content = {} + table.insert(msg_content, { + type = "text", + text = string.format( + "Tool use %s(%s)", + tool_use_message.message.content[1].name, + vim.json.encode(tool_use_message.message.content[1].input) + ), + }) + table.insert( + picked_messages, + 1, + Message:new({ role = "assistant", content = msg_content }, { is_dummy = true }) + ) + end + elseif Helpers.is_tool_use_message(msg) then + tool_use_count = tool_use_count + 1 + goto continue + else + table.insert(picked_messages, 1, msg) + end + else + if Helpers.is_tool_use_message(msg) then tool_use_count = tool_use_count + 1 end + table.insert(picked_messages, 1, msg) + end + ::continue:: + end + + history_messages = picked_messages + end + + local final_history_messages = {} + for _, msg in ipairs(history_messages) do + local tool_result_message + if Helpers.is_tool_use_message(msg) then + tool_result_message = Helpers.get_tool_result_message(msg, history_messages) + if not tool_result_message then goto continue end + end + if Helpers.is_tool_result_message(msg) then goto continue end + table.insert(final_history_messages, msg) + if tool_result_message then table.insert(final_history_messages, tool_result_message) end + ::continue:: + end + + return final_history_messages +end + +return M diff --git a/lua/avante/history_message.lua b/lua/avante/history/message.lua similarity index 100% rename from lua/avante/history_message.lua rename to lua/avante/history/message.lua diff --git a/lua/avante/history_selector.lua b/lua/avante/history_selector.lua index 18c6e0a..f86db9f 100644 --- a/lua/avante/history_selector.lua +++ b/lua/avante/history_selector.lua @@ -1,3 +1,4 @@ +local History = require("avante.history") local Utils = require("avante.utils") local Path = require("avante.path") local Config = require("avante.config") @@ -9,7 +10,7 @@ local M = {} ---@param history avante.ChatHistory ---@return table? local function to_selector_item(history) - local messages = Utils.get_history_messages(history) + local messages = History.get_history_messages(history) local timestamp = #messages > 0 and messages[#messages].timestamp or history.timestamp local name = history.title .. " - " .. timestamp .. " (" .. #messages .. ")" name = name:gsub("\n", "\\n") diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index d817fdf..cabfbd2 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -11,7 +11,7 @@ local Path = require("avante.path") local Providers = require("avante.providers") local LLMToolHelpers = require("avante.llm_tools.helpers") local LLMTools = require("avante.llm_tools") -local HistoryMessage = require("avante.history_message") +local History = require("avante.history") ---@class avante.LLM local M = {} @@ -771,7 +771,7 @@ function M._stream(opts) ---@type avante.HistoryMessage[] local messages = {} for _, tool_result in ipairs(tool_results) do - messages[#messages + 1] = HistoryMessage:new({ + messages[#messages + 1] = History.Message:new({ role = "user", content = { { @@ -816,7 +816,7 @@ function M._stream(opts) Utils.debug("Tool execution was cancelled by user") if opts.on_chunk then opts.on_chunk("\n*[Request cancelled by user during tool execution.]*\n") end if opts.on_messages_add then - local message = HistoryMessage:new({ + local message = History.Message:new({ role = "assistant", content = "\n*[Request cancelled by user during tool execution.]*\n", }, { @@ -868,7 +868,7 @@ function M._stream(opts) if stop_opts.reason == "cancelled" then if opts.on_chunk then opts.on_chunk("\n*[Request cancelled by user.]*\n") end if opts.on_messages_add then - local message = HistoryMessage:new({ + local message = History.Message:new({ role = "assistant", content = "\n*[Request cancelled by user.]*\n", }, { @@ -885,7 +885,7 @@ function M._stream(opts) for idx = #history_messages, 1, -1 do local message = history_messages[idx] if message.is_user_submission then break end - if not Utils.is_tool_use_message(message) then goto continue end + if not History.Helpers.is_tool_use_message(message) then goto continue end if message.message.content[1].name ~= "attempt_completion" then break end completed_attempt_completion_tool_use = message if message then break end @@ -909,14 +909,14 @@ function M._stream(opts) Utils.debug("user reminder count", user_reminder_count) local message if #unfinished_todos > 0 then - message = HistoryMessage:new({ + message = History.Message:new({ role = "user", content = "You should use tool calls to answer the question, for example, use update_todo_status if the task step is done or cancelled.", }, { visible = false, }) else - message = HistoryMessage:new({ + message = History.Message:new({ role = "user", content = "You should use tool calls to answer the question, for example, use attempt_completion if the job is done.", }, { @@ -954,7 +954,7 @@ function M._stream(opts) if opts.on_chunk then opts.on_chunk("\n" .. msg_content .. "\n") end local message if opts.on_messages_add then - message = HistoryMessage:new({ + message = History.Message:new({ role = "assistant", content = "\n\n" .. msg_content, }, { diff --git a/lua/avante/llm_tools/delete_tool_use_messages.lua b/lua/avante/llm_tools/delete_tool_use_messages.lua index afcbbde..5348343 100644 --- a/lua/avante/llm_tools/delete_tool_use_messages.lua +++ b/lua/avante/llm_tools/delete_tool_use_messages.lua @@ -1,5 +1,5 @@ local Base = require("avante.llm_tools.base") -local Utils = require("avante.utils") +local History = require("avante.history") ---@class AvanteLLMTool local M = setmetatable({}, Base) @@ -43,10 +43,10 @@ M.returns = { function M.func(input, opts) local sidebar = require("avante").get() if not sidebar then return false, "Avante sidebar not found" end - local history_messages = Utils.get_history_messages(sidebar.chat_history) + local history_messages = History.get_history_messages(sidebar.chat_history) local the_deleted_message_uuids = {} for _, msg in ipairs(history_messages) do - if Utils.is_tool_use_message(msg) then + if History.Helpers.is_tool_use_message(msg) then local content = msg.message.content if type(content) == "table" then for _, item in ipairs(content) do diff --git a/lua/avante/llm_tools/dispatch_agent.lua b/lua/avante/llm_tools/dispatch_agent.lua index b1e9c27..97a8c07 100644 --- a/lua/avante/llm_tools/dispatch_agent.lua +++ b/lua/avante/llm_tools/dispatch_agent.lua @@ -2,7 +2,7 @@ local Providers = require("avante.providers") local Config = require("avante.config") local Utils = require("avante.utils") local Base = require("avante.llm_tools.base") -local HistoryMessage = require("avante.history_message") +local History = require("avante.history") local Line = require("avante.ui.line") local Highlights = require("avante.highlights") @@ -94,7 +94,7 @@ function M.on_render(input, opts) local content = msg.message.content local summary if type(content) == "table" and #content > 0 and content[1].type == "tool_use" then - local tool_result_message = Utils.get_tool_result_message(msg, messages) + local tool_result_message = History.Helpers.get_tool_result_message(msg, messages) if tool_result_message then local tool_name = msg.message.content[1].name if tool_name == "ls" then @@ -267,7 +267,7 @@ When you're done, provide a clear and concise summary of what you found.]]):gsub .. elapsed_time .. "s)" if session_ctx.on_messages_add then - local message = HistoryMessage:new({ + local message = History.Message:new({ role = "assistant", content = "\n\n" .. summary, }, { diff --git a/lua/avante/path.lua b/lua/avante/path.lua index 6cff139..10db0a7 100644 --- a/lua/avante/path.lua +++ b/lua/avante/path.lua @@ -53,10 +53,11 @@ function History.list(bufnr) --- sort by timestamp --- sort by latest_filename table.sort(res, function(a, b) + local H = require("avante.history") if a.filename == latest_filename then return true end if b.filename == latest_filename then return false end - local a_messages = Utils.get_history_messages(a) - local b_messages = Utils.get_history_messages(b) + local a_messages = H.get_history_messages(a) + local b_messages = H.get_history_messages(b) local timestamp_a = #a_messages > 0 and a_messages[#a_messages].timestamp or a.timestamp local timestamp_b = #b_messages > 0 and b_messages[#b_messages].timestamp or b.timestamp return timestamp_a > timestamp_b diff --git a/lua/avante/providers/claude.lua b/lua/avante/providers/claude.lua index abf7ac3..af686f8 100644 --- a/lua/avante/providers/claude.lua +++ b/lua/avante/providers/claude.lua @@ -1,7 +1,7 @@ local Utils = require("avante.utils") local Clipboard = require("avante.clipboard") local P = require("avante.providers") -local HistoryMessage = require("avante.history_message") +local HistoryMessage = require("avante.history.message") local JsonParser = require("avante.libs.jsonparser") ---@class AvanteProviderFunctor diff --git a/lua/avante/providers/ollama.lua b/lua/avante/providers/ollama.lua index dc261f9..fe49589 100644 --- a/lua/avante/providers/ollama.lua +++ b/lua/avante/providers/ollama.lua @@ -2,7 +2,7 @@ local Utils = require("avante.utils") local Providers = require("avante.providers") local Config = require("avante.config") local Clipboard = require("avante.clipboard") -local HistoryMessage = require("avante.history_message") +local HistoryMessage = require("avante.history.message") local Prompts = require("avante.utils.prompts") ---@class AvanteProviderFunctor diff --git a/lua/avante/providers/openai.lua b/lua/avante/providers/openai.lua index 05eb649..c65f208 100644 --- a/lua/avante/providers/openai.lua +++ b/lua/avante/providers/openai.lua @@ -2,7 +2,7 @@ local Utils = require("avante.utils") local Config = require("avante.config") local Clipboard = require("avante.clipboard") local Providers = require("avante.providers") -local HistoryMessage = require("avante.history_message") +local HistoryMessage = require("avante.history.message") local ReActParser = require("avante.libs.ReAct_parser") local JsonParser = require("avante.libs.jsonparser") local Prompts = require("avante.utils.prompts") diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua index abec1e9..e709469 100644 --- a/lua/avante/sidebar.lua +++ b/lua/avante/sidebar.lua @@ -16,7 +16,7 @@ local Highlights = require("avante.highlights") local RepoMap = require("avante.repo_map") local FileSelector = require("avante.file_selector") local LLMTools = require("avante.llm_tools") -local HistoryMessage = require("avante.history_message") +local History = require("avante.history") local Line = require("avante.ui.line") local LRUCache = require("avante.utils.lru_cache") @@ -1755,7 +1755,7 @@ end ---@param history avante.ChatHistory ---@return avante.ui.Line[] function Sidebar.get_history_lines(history) - local history_messages = Utils.get_history_messages(history) + local history_messages = History.get_history_messages(history) local ctx = {} ---@type avante.ui.Line[][] local group = {} @@ -1825,7 +1825,7 @@ end ---@param history avante.ChatHistory ---@return string function Sidebar.render_history_content(history) - local history_messages = Utils.get_history_messages(history) + local history_messages = History.get_history_messages(history) local ctx = {} local group = {} for _, message in ipairs(history_messages) do @@ -1985,7 +1985,7 @@ end function Sidebar:compact_history_messages(args, cb) local history_memory = self.chat_history.memory - local messages = Utils.get_history_messages(self.chat_history) + local messages = History.get_history_messages(self.chat_history) self.current_state = "compacting" self:render_state() self:update_content( @@ -2023,7 +2023,7 @@ function Sidebar:save_history() debounced_save_history(self) end ---@param uuids string[] function Sidebar:delete_history_messages(uuids) - local history_messages = Utils.get_history_messages(self.chat_history) + local history_messages = History.get_history_messages(self.chat_history) for _, msg in ipairs(history_messages) do if vim.list_contains(uuids, msg.uuid) then msg.is_deleted = true end end @@ -2041,7 +2041,7 @@ end ---@param messages avante.HistoryMessage | avante.HistoryMessage[] function Sidebar:add_history_messages(messages) - local history_messages = Utils.get_history_messages(self.chat_history) + local history_messages = History.get_history_messages(self.chat_history) messages = vim.islist(messages) and messages or { messages } for _, message in ipairs(messages) do if message.is_user_submission then @@ -2110,7 +2110,7 @@ function Sidebar:add_chat_history(messages, options) self.chat_history.system_prompt = content goto continue end - local history_message = HistoryMessage:new(message) + local history_message = History.Message:new(message) if message.role == "user" and is_first_user then is_first_user = false history_message.is_user_submission = true @@ -2254,8 +2254,7 @@ end ---@return avante.HistoryMessage[] function Sidebar:get_history_messages_for_api(opts) opts = opts or {} - local history_messages0 = Utils.get_history_messages(self.chat_history) - self.chat_history.messages = history_messages0 + local history_messages0 = History.get_history_messages(self.chat_history) history_messages0 = vim .iter(history_messages0) @@ -2264,8 +2263,6 @@ function Sidebar:get_history_messages_for_api(opts) if opts.all then return history_messages0 end - local use_ReAct_prompt = Providers[Config.provider].use_ReAct_prompt ~= nil - history_messages0 = vim .iter(history_messages0) :filter(function(message) return message.state ~= "generating" end) @@ -2281,258 +2278,11 @@ function Sidebar:get_history_messages_for_api(opts) history_messages0 = picked_messages end - local tool_id_to_tool_name = {} - local tool_id_to_path = {} - local tool_id_to_start_line = {} - local tool_id_to_end_line = {} - local viewed_files = {} - local last_modified_files = {} - local history_messages = {} - - for idx, message in ipairs(history_messages0) do - if Utils.is_tool_result_message(message) then - local tool_use_message = Utils.get_tool_use_message(message, history_messages0) - - local is_edit_func_call, _, _, path = Utils.is_edit_func_call_message(tool_use_message) - - -- Only track as successful modification if not an error AND not user-declined - if - is_edit_func_call - and path - and not message.message.content[1].is_error - and not message.message.content[1].is_user_declined - then - local uniformed_path = Utils.uniform_path(path) - last_modified_files[uniformed_path] = idx - end - end - end - - for idx, message in ipairs(history_messages0) do - table.insert(history_messages, message) - if Utils.is_tool_result_message(message) then - local tool_use_message = Utils.get_tool_use_message(message, history_messages0) - local is_edit_func_call, is_str_replace_editor_func_call, is_str_replace_based_edit_tool_func_call, path = - Utils.is_edit_func_call_message(tool_use_message) - --- For models like gpt-4o, the input parameter of replace_in_file is treated as the latest file content, so here we need to insert a fake view tool call to ensure it uses the latest file content - if is_edit_func_call and path and not message.message.content[1].is_error then - local uniformed_path = Utils.uniform_path(path) - local view_result, view_error = require("avante.llm_tools.view").func({ path = path }, {}) - if view_error then view_result = "Error: " .. view_error end - local get_diagnostics_tool_use_id = Utils.uuid() - local view_tool_use_id = Utils.uuid() - local view_tool_name = "view" - local view_tool_input = { path = path } - if is_str_replace_editor_func_call then - view_tool_name = "str_replace_editor" - view_tool_input = { command = "view", path = path } - end - if is_str_replace_based_edit_tool_func_call then - view_tool_name = "str_replace_based_edit_tool" - view_tool_input = { command = "view", path = path } - end - history_messages = vim.list_extend(history_messages, { - HistoryMessage:new({ - role = "assistant", - content = string.format("Viewing file %s to get the latest content", path), - }, { - is_dummy = true, - }), - HistoryMessage:new({ - role = "assistant", - content = { - { - type = "tool_use", - id = view_tool_use_id, - name = view_tool_name, - input = view_tool_input, - }, - }, - }, { - is_dummy = true, - }), - HistoryMessage:new({ - role = "user", - content = { - { - type = "tool_result", - tool_use_id = view_tool_use_id, - content = view_result, - is_error = view_error ~= nil, - is_user_declined = false, - }, - }, - }, { - is_dummy = true, - }), - }) - if last_modified_files[uniformed_path] == idx and Config.behaviour.auto_check_diagnostics then - local diagnostics = Utils.lsp.get_diagnostics_from_filepath(path) - history_messages = vim.list_extend(history_messages, { - HistoryMessage:new({ - role = "assistant", - content = string.format( - "The file %s has been modified, let me check if there are any errors in the changes.", - path - ), - }, { - is_dummy = true, - }), - HistoryMessage:new({ - role = "assistant", - content = { - { - type = "tool_use", - id = get_diagnostics_tool_use_id, - name = "get_diagnostics", - input = { path = path }, - }, - }, - }, { - is_dummy = true, - }), - HistoryMessage:new({ - role = "user", - content = { - { - type = "tool_result", - tool_use_id = get_diagnostics_tool_use_id, - content = vim.json.encode(diagnostics), - is_error = false, - is_user_declined = false, - }, - }, - }, { - is_dummy = true, - }), - }) - end - end - end - end - for _, message in ipairs(history_messages) do - local content = message.message.content - if type(content) ~= "table" then goto continue end - for _, item in ipairs(content) do - if type(item) ~= "table" then goto continue1 end - if item.type ~= "tool_use" then goto continue1 end - local tool_name = item.name - if tool_name ~= "view" then goto continue1 end - local path = item.input.path - tool_id_to_tool_name[item.id] = tool_name - if path then - local uniform_path = Utils.uniform_path(path) - tool_id_to_path[item.id] = uniform_path - tool_id_to_start_line[item.id] = item.input.start_line - tool_id_to_end_line[item.id] = item.input.end_line - viewed_files[uniform_path] = item.id - end - ::continue1:: - end - ::continue:: - end - for _, message in ipairs(history_messages) do - local content = message.message.content - if type(content) == "table" then - for _, item in ipairs(content) do - if type(item) ~= "table" then goto continue end - if item.type ~= "tool_result" then goto continue end - local tool_name = tool_id_to_tool_name[item.tool_use_id] - if tool_name ~= "view" then goto continue end - if item.is_error then goto continue end - local path = tool_id_to_path[item.tool_use_id] - local latest_tool_id = viewed_files[path] - if not latest_tool_id then goto continue end - if latest_tool_id ~= item.tool_use_id then - item.content = string.format("The file %s has been updated. Please use the latest `view` tool result!", path) - else - local start_line = tool_id_to_start_line[item.tool_use_id] - local end_line = tool_id_to_end_line[item.tool_use_id] - local view_result, view_error = require("avante.llm_tools.view").func( - { path = path, start_line = start_line, end_line = end_line }, - {} - ) - if view_error then view_result = "Error: " .. view_error end - item.content = view_result - item.is_error = view_error ~= nil - end - ::continue:: - end - end - end - - if not use_ReAct_prompt then - local picked_messages = {} - local max_tool_use_count = 25 - local tool_use_count = 0 - for idx = #history_messages, 1, -1 do - local msg = history_messages[idx] - if tool_use_count > max_tool_use_count then - if Utils.is_tool_result_message(msg) then - local tool_use_message = Utils.get_tool_use_message(msg, history_messages) - if tool_use_message then - table.insert( - picked_messages, - 1, - HistoryMessage:new({ - role = "user", - content = { - { - type = "text", - text = string.format( - "Tool use [%s] is successful: %s", - tool_use_message.message.content[1].name, - tostring(not msg.message.content[1].is_error) - ), - }, - }, - }, { is_dummy = true }) - ) - local msg_content = {} - table.insert(msg_content, { - type = "text", - text = string.format( - "Tool use %s(%s)", - tool_use_message.message.content[1].name, - vim.json.encode(tool_use_message.message.content[1].input) - ), - }) - table.insert( - picked_messages, - 1, - HistoryMessage:new({ role = "assistant", content = msg_content }, { is_dummy = true }) - ) - end - elseif Utils.is_tool_use_message(msg) then - tool_use_count = tool_use_count + 1 - goto continue - else - table.insert(picked_messages, 1, msg) - end - else - if Utils.is_tool_use_message(msg) then tool_use_count = tool_use_count + 1 end - table.insert(picked_messages, 1, msg) - end - ::continue:: - end - - history_messages = picked_messages - end - - local final_history_messages = {} - for _, msg in ipairs(history_messages) do - local tool_result_message - if Utils.is_tool_use_message(msg) then - tool_result_message = Utils.get_tool_result_message(msg, history_messages) - if not tool_result_message then goto continue end - end - if Utils.is_tool_result_message(msg) then goto continue end - table.insert(final_history_messages, msg) - if tool_result_message then table.insert(final_history_messages, tool_result_message) end - ::continue:: - end - - return final_history_messages + return History.update_history_messages( + history_messages0, + Providers[Config.provider].use_ReAct_prompt ~= nil, + Config.behaviour.auto_check_diagnostics + ) end ---@param request string @@ -2651,7 +2401,7 @@ function Sidebar:create_input_container() if self.is_generating then self:add_history_messages({ - HistoryMessage:new({ role = "user", content = request }), + History.Message:new({ role = "user", content = request }), }) return end @@ -2802,7 +2552,7 @@ function Sidebar:create_input_container() local msg_content = stop_opts.error if type(msg_content) ~= "string" then msg_content = vim.inspect(msg_content) end self:add_history_messages({ - HistoryMessage:new({ + History.Message:new({ role = "assistant", content = "\n\nError: " .. msg_content, }, { @@ -2831,7 +2581,7 @@ function Sidebar:create_input_container() if request and request ~= "" then self:add_history_messages({ - HistoryMessage:new({ + History.Message:new({ role = "user", content = request, }, { diff --git a/lua/avante/suggestion.lua b/lua/avante/suggestion.lua index 3635600..47b9241 100644 --- a/lua/avante/suggestion.lua +++ b/lua/avante/suggestion.lua @@ -3,7 +3,7 @@ local Llm = require("avante.llm") local Highlights = require("avante.highlights") local Config = require("avante.config") local Providers = require("avante.providers") -local HistoryMessage = require("avante.history_message") +local HistoryMessage = require("avante.history.message") local api = vim.api local fn = vim.fn diff --git a/lua/avante/utils/init.lua b/lua/avante/utils/init.lua index 2ece994..794f25b 100644 --- a/lua/avante/utils/init.lua +++ b/lua/avante/utils/init.lua @@ -1473,41 +1473,6 @@ function M.get_commands() return vim.list_extend(builtin_commands, Config.slash_commands) end ----@param history avante.ChatHistory ----@return avante.HistoryMessage[] -function M.get_history_messages(history) - local HistoryMessage = require("avante.history_message") - if history.messages then return history.messages end - local messages = {} - for _, entry in ipairs(history.entries or {}) do - if entry.request and entry.request ~= "" then - local message = HistoryMessage:new({ - role = "user", - content = entry.request, - }, { - timestamp = entry.timestamp, - is_user_submission = true, - visible = entry.visible, - selected_filepaths = entry.selected_filepaths, - selected_code = entry.selected_code, - }) - table.insert(messages, message) - end - if entry.response and entry.response ~= "" then - local message = HistoryMessage:new({ - role = "assistant", - content = entry.response, - }, { - timestamp = entry.timestamp, - visible = entry.visible, - }) - table.insert(messages, message) - end - end - history.messages = messages - return messages -end - function M.get_timestamp() return tostring(os.date("%Y-%m-%d %H:%M:%S")) end ---@param history_messages avante.HistoryMessage[] @@ -1530,60 +1495,6 @@ function M.uuid() end) end ----@param message avante.HistoryMessage ----@return boolean -function M.is_tool_use_message(message) - local content = message.message.content - if type(content) == "string" then return false end - if vim.islist(content) then - for _, item in ipairs(content) do - if item.type == "tool_use" then return true end - end - end - return false -end - ----@param message avante.HistoryMessage ----@return boolean -function M.is_tool_result_message(message) - local content = message.message.content - if type(content) == "string" then return false end - if vim.islist(content) then - for _, item in ipairs(content) do - if item.type == "tool_result" then return true end - end - end - return false -end - ----@param message avante.HistoryMessage ----@param messages avante.HistoryMessage[] ----@return avante.HistoryMessage | nil -function M.get_tool_use_message(message, messages) - local content = message.message.content - if type(content) == "string" then return nil end - if vim.islist(content) then - local tool_id = nil - for _, item in ipairs(content) do - if item.type == "tool_result" then - tool_id = item.tool_use_id - break - end - end - if not tool_id then return nil end - for idx_ = #messages, 1, -1 do - local message_ = messages[idx_] - local content_ = message_.message.content - if type(content_) == "table" then - for _, item in ipairs(content_) do - if item.type == "tool_use" and item.id == tool_id then return message_ end - end - end - end - end - return nil -end - ---@param tool_use AvanteLLMToolUse function M.tool_use_to_xml(tool_use) local xml = string.format("\n<%s>\n", tool_use.name) @@ -1621,48 +1532,6 @@ function M.is_edit_func_call_tool_use(tool_use) return is_replace_func_call, is_str_replace_editor_func_call, is_str_replace_based_edit_tool_func_call, path end ----@param tool_use_message avante.HistoryMessage | nil -function M.is_edit_func_call_message(tool_use_message) - local is_replace_func_call = false - local is_str_replace_editor_func_call = false - local is_str_replace_based_edit_tool_func_call = false - local path = nil - if tool_use_message and M.is_tool_use_message(tool_use_message) then - local tool_use = tool_use_message.message.content[1] - ---@cast tool_use AvanteLLMToolUse - return M.is_edit_func_call_tool_use(tool_use) - end - return is_replace_func_call, is_str_replace_editor_func_call, is_str_replace_based_edit_tool_func_call, path -end - ----@param message avante.HistoryMessage ----@param messages avante.HistoryMessage[] ----@return avante.HistoryMessage | nil -function M.get_tool_result_message(message, messages) - local content = message.message.content - if type(content) == "string" then return nil end - if vim.islist(content) then - local tool_id = nil - for _, item in ipairs(content) do - if item.type == "tool_use" then - tool_id = item.id - break - end - end - if not tool_id then return nil end - for idx_ = #messages, 1, -1 do - local message_ = messages[idx_] - local content_ = message_.message.content - if type(content_) == "table" then - for _, item in ipairs(content_) do - if item.type == "tool_result" and item.tool_use_id == tool_id then return message_ end - end - end - end - end - return nil -end - ---@param text string ---@param hl string | nil ---@return avante.ui.Line[] @@ -1701,6 +1570,7 @@ end ---@return avante.ui.Line[] function M.message_content_item_to_lines(item, message, messages) local Line = require("avante.ui.line") + local HistoryHelpers = require("avante.history.helpers") if type(item) == "string" then return M.text_to_lines(item) end if type(item) == "table" then if item.type == "thinking" or item.type == "redacted_thinking" then @@ -1711,7 +1581,7 @@ function M.message_content_item_to_lines(item, message, messages) return { Line:new({ { "![image](" .. item.source.media_type .. ": " .. item.source.data .. ")" } }) } end if item.type == "tool_use" then - local tool_result_message = M.get_tool_result_message(message, messages) + local tool_result_message = HistoryHelpers.get_tool_result_message(message, messages) local lines = {} local state = "generating" local hl = "AvanteStateSpinnerToolCalling" @@ -1832,6 +1702,7 @@ end ---@return AvantePartialLLMToolUse[] ---@return avante.HistoryMessage[] function M.get_uncalled_tool_uses(history_messages) + local HistoryHelpers = require("avante.history.helpers") local last_turn_id = nil if #history_messages > 0 then last_turn_id = history_messages[#history_messages].turn_id end local uncalled_tool_uses = {} ---@type AvantePartialLLMToolUse[] @@ -1841,8 +1712,8 @@ function M.get_uncalled_tool_uses(history_messages) local message = history_messages[idx] if last_turn_id then if message.turn_id ~= last_turn_id then break end - else - if not M.is_tool_use_message(message) and not M.is_tool_result_message(message) then break end + elseif not HistoryHelpers.is_tool_use_message(message) and not HistoryHelpers.is_tool_result_message(message) then + break end local content = message.message.content if type(content) ~= "table" or #content == 0 then goto continue end