diff --git a/lua/avante/history/helpers.lua b/lua/avante/history/helpers.lua index ad89d54..4284f71 100644 --- a/lua/avante/history/helpers.lua +++ b/lua/avante/history/helpers.lua @@ -67,20 +67,6 @@ function M.get_tool_use_message(message, messages) end end ----@param tool_use_message avante.HistoryMessage | nil -function M.is_edit_func_call_message(tool_use_message) - local is_replace_func_call = false - local is_str_replace_editor_func_call = false - local is_str_replace_based_edit_tool_func_call = false - local path = nil - if tool_use_message and M.is_tool_use_message(tool_use_message) then - local tool_use = tool_use_message.message.content[1] - ---@cast tool_use AvanteLLMToolUse - return Utils.is_edit_func_call_tool_use(tool_use) - end - return is_replace_func_call, is_str_replace_editor_func_call, is_str_replace_based_edit_tool_func_call, path -end - ---Given a tool use message locate corresponding tool result message ---@param message avante.HistoryMessage ---@param messages avante.HistoryMessage[] diff --git a/lua/avante/history/init.lua b/lua/avante/history/init.lua index 3e43ca1..b88ba53 100644 --- a/lua/avante/history/init.lua +++ b/lua/avante/history/init.lua @@ -41,219 +41,286 @@ function M.get_history_messages(history) return messages end +---Represents information about tool use: invocation, result, affected file (for "view" or "edit" tools). +---@class HistoryToolInfo +---@field kind "edit" | "view" | "other" +---@field use AvanteLLMToolUse +---@field result? AvanteLLMToolResult +---@field result_message? avante.HistoryMessage Complete result message +---@field path? string Uniform (normalized) path of the affected file + +---@class HistoryFileInfo +---@field last_tool_id? string ID of the tool with most up-to-date state of the file +---@field edit_tool_id? string ID of the last tool done edit on the file + +---Collects information about all uses of tools in the history: their invocations, results, and affected files. ---@param messages avante.HistoryMessage[] ----@param using_ReAct_prompt boolean ----@param add_diagnostic boolean Mix in LSP diagnostic info for affected files ----@return avante.HistoryMessage[] -M.update_history_messages = function(messages, using_ReAct_prompt, add_diagnostic) - local tool_id_to_tool_name = {} - local tool_id_to_path = {} - local tool_id_to_start_line = {} - local tool_id_to_end_line = {} - local viewed_files = {} - local last_modified_files = {} - local history_messages = {} +---@return table +---@return table +local function collect_tool_info(messages) + ---@type table Maps tool ID to tool information + local tools = {} + ---@type table Maps file path to file information + local files = {} - for idx, message in ipairs(messages) do - if Helpers.is_tool_result_message(message) then - local tool_use_message = Helpers.get_tool_use_message(message, messages) - - local is_edit_func_call, _, _, path = Helpers.is_edit_func_call_message(tool_use_message) - - -- Only track as successful modification if not an error AND not user-declined - if - is_edit_func_call - and path - and not message.message.content[1].is_error - and not message.message.content[1].is_user_declined - then - local uniformed_path = Utils.uniform_path(path) - last_modified_files[uniformed_path] = idx - end - end - end - - for idx, message in ipairs(messages) do - table.insert(history_messages, message) - if Helpers.is_tool_result_message(message) then - local tool_use_message = Helpers.get_tool_use_message(message, messages) - local is_edit_func_call, is_str_replace_editor_func_call, is_str_replace_based_edit_tool_func_call, path = - Helpers.is_edit_func_call_message(tool_use_message) - --- For models like gpt-4o, the input parameter of replace_in_file is treated as the latest file content, so here we need to insert a fake view tool call to ensure it uses the latest file content - if is_edit_func_call and path and not message.message.content[1].is_error then - local uniformed_path = Utils.uniform_path(path) - local view_result, view_error = require("avante.llm_tools.view").func({ path = path }, {}) - if view_error then view_result = "Error: " .. view_error end - local get_diagnostics_tool_use_id = Utils.uuid() - local view_tool_use_id = Utils.uuid() - local view_tool_name = "view" - local view_tool_input = { path = path } - if is_str_replace_editor_func_call then - view_tool_name = "str_replace_editor" - view_tool_input = { command = "view", path = path } - end - if is_str_replace_based_edit_tool_func_call then - view_tool_name = "str_replace_based_edit_tool" - view_tool_input = { command = "view", path = path } - end - history_messages = vim.list_extend(history_messages, { - Message:new_assistant_synthetic(string.format("Viewing file %s to get the latest content", path)), - Message:new_assistant_synthetic({ - type = "tool_use", - id = view_tool_use_id, - name = view_tool_name, - input = view_tool_input, - }), - Message:new_user_synthetic({ - type = "tool_result", - tool_use_id = view_tool_use_id, - content = view_result, - is_error = view_error ~= nil, - is_user_declined = false, - }), - }) - if last_modified_files[uniformed_path] == idx and add_diagnostic then - local diagnostics = Utils.lsp.get_diagnostics_from_filepath(path) - history_messages = vim.list_extend(history_messages, { - Message:new_assistant_synthetic( - string.format( - "The file %s has been modified, let me check if there are any errors in the changes.", - path - ) - ), - Message:new_assistant_synthetic({ - type = "tool_use", - id = get_diagnostics_tool_use_id, - name = "get_diagnostics", - input = { path = path }, - }), - Message:new_user_synthetic({ - type = "tool_result", - tool_use_id = get_diagnostics_tool_use_id, - content = vim.json.encode(diagnostics), - is_error = false, - is_user_declined = false, - }), - }) - end - end - end - end - for _, message in ipairs(history_messages) do - local content = message.message.content - if type(content) ~= "table" then goto continue end - for _, item in ipairs(content) do - if type(item) ~= "table" then goto continue1 end - if item.type ~= "tool_use" then goto continue1 end - local tool_name = item.name - if tool_name ~= "view" then goto continue1 end - local path = item.input.path - tool_id_to_tool_name[item.id] = tool_name - if path then - local uniform_path = Utils.uniform_path(path) - tool_id_to_path[item.id] = uniform_path - tool_id_to_start_line[item.id] = item.input.start_line - tool_id_to_end_line[item.id] = item.input.end_line - viewed_files[uniform_path] = item.id - end - ::continue1:: - end - ::continue:: - end - for _, message in ipairs(history_messages) do - local content = message.message.content - if type(content) == "table" then - for _, item in ipairs(content) do - if type(item) ~= "table" then goto continue end - if item.type ~= "tool_result" then goto continue end - local tool_name = tool_id_to_tool_name[item.tool_use_id] - if tool_name ~= "view" then goto continue end - if item.is_error then goto continue end - local path = tool_id_to_path[item.tool_use_id] - local latest_tool_id = viewed_files[path] - if not latest_tool_id then goto continue end - if latest_tool_id ~= item.tool_use_id then - item.content = string.format("The file %s has been updated. Please use the latest `view` tool result!", path) - else - local start_line = tool_id_to_start_line[item.tool_use_id] - local end_line = tool_id_to_end_line[item.tool_use_id] - local view_result, view_error = require("avante.llm_tools.view").func( - { path = path, start_line = start_line, end_line = end_line }, - {} - ) - if view_error then view_result = "Error: " .. view_error end - item.content = view_result - item.is_error = view_error ~= nil - end - ::continue:: - end - end - end - - if not using_ReAct_prompt then - local picked_messages = {} - local max_tool_use_count = 25 - local tool_use_count = 0 - for idx = #history_messages, 1, -1 do - local msg = history_messages[idx] - if tool_use_count > max_tool_use_count then - if Helpers.is_tool_result_message(msg) then - local tool_use_message = Helpers.get_tool_use_message(msg, history_messages) - if tool_use_message then - table.insert( - picked_messages, - 1, - Message:new_user_synthetic({ - type = "text", - text = string.format( - "Tool use [%s] is successful: %s", - tool_use_message.message.content[1].name, - tostring(not msg.message.content[1].is_error) - ), - }) - ) - table.insert( - picked_messages, - 1, - Message:new_assistant_synthetic( - string.format( - "Tool use %s(%s)", - tool_use_message.message.content[1].name, - vim.json.encode(tool_use_message.message.content[1].input) - ) - ) - ) - end - elseif Helpers.is_tool_use_message(msg) then - tool_use_count = tool_use_count + 1 - goto continue - else - table.insert(picked_messages, 1, msg) + -- Collect invocations of all tools, and also build a list of viewed or edited files. + for _, message in ipairs(messages) do + local use = Helpers.get_tool_use_data(message) + if use then + if use.name == "view" or Utils.is_edit_tool_use(use) then + if use.input.path then + local path = Utils.uniform_path(use.input.path) + tools[use.id] = { kind = use.name == "view" and "view" or "edit", use = use, path = path } end else - if Helpers.is_tool_use_message(msg) then tool_use_count = tool_use_count + 1 end - table.insert(picked_messages, 1, msg) + tools[use.id] = { kind = "other", use = use } end - ::continue:: + goto continue end - history_messages = picked_messages - end - - local final_history_messages = {} - for _, msg in ipairs(history_messages) do - local tool_result_message - if Helpers.is_tool_use_message(msg) then - tool_result_message = Helpers.get_tool_result_message(msg, history_messages) - if not tool_result_message then goto continue end + local result = Helpers.get_tool_result_data(message) + if result then + -- We assume that "result" entries always come after corresponding "use" entries. + local info = tools[result.tool_use_id] + if info then + info.result = result + info.result_message = message + if info.path then + local f = files[info.path] + if not f then + f = {} + files[info.path] = f + end + f.last_tool_id = result.tool_use_id + if info.kind == "edit" and not (result.is_error or result.is_user_declined) then + f.edit_tool_id = result.tool_use_id + end + end + end end - if Helpers.is_tool_result_message(msg) then goto continue end - table.insert(final_history_messages, msg) - if tool_result_message then table.insert(final_history_messages, tool_result_message) end + ::continue:: end - return final_history_messages + return tools, files +end + +---Converts a tool invocation (use + result) into a simple request/response pair of text messages +---@param tool_info HistoryToolInfo +---@return avante.HistoryMessage[] +local function convert_tool_to_text(tool_info) + return { + Message:new_assistant_synthetic( + string.format("Tool use %s(%s)", tool_info.use.name, vim.json.encode(tool_info.use.input)) + ), + Message:new_user_synthetic({ + type = "text", + text = string.format( + "Tool use [%s] is successful: %s", + tool_info.use.name, + tostring(not tool_info.result.is_error) + ), + }), + } +end + +---Generates a fake file "content" telling LLM to look further for up-to-date data +---@param path string +---@return string +local function stale_view_content(path) + return string.format("The file %s has been updated. Please use the latest `view` tool result!", path) +end + +---Updates the result of "view" tool invocation with latest contents of a buffer or file, +---or a stub message if this result will be superseded by another one. +---@param tool_info HistoryToolInfo +---@param stale_view boolean +local function update_view_result(tool_info, stale_view) + local use = tool_info.use + local result = tool_info.result + + if stale_view then + result.content = stale_view_content(tool_info.path) + else + local view_result, view_error = require("avante.llm_tools.view").func( + { path = tool_info.path, start_line = use.input.start_line, end_line = use.input.end_line }, + {} + ) + result.content = view_error and ("Error: " .. view_error) or view_result + result.is_error = view_error ~= nil + end +end + +---Generates synthetic "view" tool invocation to tell LLM to refresh its view of a file after editing +---@param tool_use AvanteLLMToolUse +---@param path any +---@param stale_view any +---@return avante.HistoryMessage[] +local function generate_view_messages(tool_use, path, stale_view) + local view_result, view_error + if stale_view then + view_result = stale_view_content(path) + else + view_result, view_error = require("avante.llm_tools.view").func({ path = path }, {}) + end + + if view_error then view_result = "Error: " .. view_error end + + local view_tool_use_id = Utils.uuid() + local view_tool_name = "view" + local view_tool_input = { path = path } + + if tool_use.name == "str_replace_editor" and tool_use.input.command == "str_replace" then + view_tool_name = "str_replace_editor" + view_tool_input.command = "view" + elseif tool_use.name == "str_replace_based_edit_tool" and tool_use.input.command == "str_replace" then + view_tool_name = "str_replace_based_edit_tool" + view_tool_input.command = "view" + end + + return { + Message:new_assistant_synthetic(string.format("Viewing file %s to get the latest content", path)), + Message:new_assistant_synthetic({ + type = "tool_use", + id = view_tool_use_id, + name = view_tool_name, + input = view_tool_input, + }), + Message:new_user_synthetic({ + type = "tool_result", + tool_use_id = view_tool_use_id, + content = view_result, + is_error = view_error ~= nil, + is_user_declined = false, + }), + } +end + +---Generates "diagnostic" for a file after it has been edited to help catching errors +---@param path string +---@return avante.HistoryMessage[] +local function generate_diagnostic_messages(path) + local get_diagnostics_tool_use_id = Utils.uuid() + local diagnostics = Utils.lsp.get_diagnostics_from_filepath(path) + return { + Message:new_assistant_synthetic( + string.format("The file %s has been modified, let me check if there are any errors in the changes.", path) + ), + Message:new_assistant_synthetic({ + type = "tool_use", + id = get_diagnostics_tool_use_id, + name = "get_diagnostics", + input = { path = path }, + }), + Message:new_user_synthetic({ + type = "tool_result", + tool_use_id = get_diagnostics_tool_use_id, + content = vim.json.encode(diagnostics), + is_error = false, + is_user_declined = false, + }), + } +end + +---Iterate through history messages and generate a new list containing updated history +---that has up-to-date file contents and potentially updated diagnostic for modified +---files. +---@param messages avante.HistoryMessage[] +---@param tools HistoryToolInfo[] +---@param files HistoryFileInfo[] +---@param add_diagnostic boolean Whether to generate and add diagnostic info to "edit" invocations +---@param tools_to_text integer Number of tool invocations to be converted to simple text +---@return avante.HistoryMessage[] +local function refresh_history(messages, tools, files, add_diagnostic, tools_to_text) + ---@type avante.HistoryMessage[] + local updated_messages = {} + local tool_count = 0 + + for _, message in ipairs(messages) do + local use = Helpers.get_tool_use_data(message) + if use then + -- This is a tool invocation message. We will be handling both use and result together. + local tool_info = tools[use.id] + if not tool_info.result then goto continue end + + if tool_count < tools_to_text then + local text_msgs = convert_tool_to_text(tool_info) + Utils.debug("Converted", use.name, "invocation to", #text_msgs, "messages") + updated_messages = vim.list_extend(updated_messages, text_msgs) + else + table.insert(updated_messages, message) + table.insert(updated_messages, tool_info.result_message) + tool_count = tool_count + 1 + + if tool_info.kind == "view" then + local path = tool_info.path + assert(path, "encountered 'view' tool invocation without path") + update_view_result(tool_info, use.id ~= files[tool_info.path].last_tool_id) + end + end + + if tool_info.kind == "edit" then + local path = tool_info.path + assert(path, "encountered 'edit' tool invocation without path") + local file_info = files[path] + + -- If this is the last operation for this file, generate synthetic "view" + -- invocation to provide the up-to-date file contents. + if not tool_info.result.is_error then + local view_msgs = generate_view_messages(use, path, use.id == file_info.last_tool_id) + Utils.debug("Added", #view_msgs, "'view' tool messages for", path) + updated_messages = vim.list_extend(updated_messages, view_msgs) + tool_count = tool_count + 1 + end + + if add_diagnostic and use.id == file_info.edit_tool_id then + local diag_msgs = generate_diagnostic_messages(path) + Utils.debug("Added", #diag_msgs, "'diagnostics' tool messages for", path) + updated_messages = vim.list_extend(updated_messages, diag_msgs) + tool_count = tool_count + 1 + end + end + elseif not Helpers.get_tool_result_data(message) then + -- Skip the tool result messages (since we process them together with their "use"s. + -- All other (non-tool-related) messages we simply keep. + table.insert(updated_messages, message) + end + + ::continue:: + end + + return updated_messages +end + +---Analyzes the history looking for tool invocations, drops incomplete invocations, +---and updates complete ones with the latest data available. +---@param messages avante.HistoryMessage[] +---@param max_tool_use integer | nil Maximum number of tool invocations to keep +---@param add_diagnostic boolean Mix in LSP diagnostic info for affected files +---@return avante.HistoryMessage[] +M.update_tool_invocation_history = function(messages, max_tool_use, add_diagnostic) + local tools, files = collect_tool_info(messages) + + -- Figure number of tool invocations that should be converted to simple "text" + -- messages to reduce prompt costs. + local tools_to_text = 0 + if max_tool_use then + local n_edits = vim.iter(files):fold( + 0, + ---@param count integer + ---@param file_info HistoryFileInfo + function(count, file_info) + if file_info.edit_tool_id then count = count + 1 end + return count + end + ) + -- Each valid "edit" invocation will result in synthetic "view" and also + -- in "diagnostic" if it is requested by the caller. + local expected = #tools + n_edits + (add_diagnostic and n_edits or 0) + tools_to_text = expected - max_tool_use + end + + return refresh_history(messages, tools, files, add_diagnostic, tools_to_text) end ---Scans message history backwards, looking for tool invocations that have not been executed yet diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index ef6a637..22695b7 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -824,7 +824,7 @@ function M._stream(opts) table.insert(tool_results, tool_result) return handle_next_tool_use(tool_uses, tool_use_messages, tool_use_index + 1, tool_results) end - local is_edit_tool_use = Utils.is_edit_func_call_tool_use(partial_tool_use) + local is_edit_tool_use = Utils.is_edit_tool_use(partial_tool_use) local support_streaming = false local llm_tool = vim.iter(prompt_opts.tools):find(function(tool) return tool.name == partial_tool_use.name end) if llm_tool then support_streaming = llm_tool.support_streaming == true end diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua index 7ca1617..016274c 100644 --- a/lua/avante/sidebar.lua +++ b/lua/avante/sidebar.lua @@ -2255,35 +2255,50 @@ end ---@return avante.HistoryMessage[] function Sidebar:get_history_messages_for_api(opts) opts = opts or {} - local history_messages0 = History.get_history_messages(self.chat_history) + local messages = History.get_history_messages(self.chat_history) - history_messages0 = vim - .iter(history_messages0) - :filter(function(message) return not message.just_for_display and not message.is_compacted end) + -- Scan the initial set of messages, filtering out "uninteresting" ones, but also + -- check if the last message mentioned in the chat memory is actually present. + local last_message = self.chat_history.memory and self.chat_history.memory.last_message_uuid + local last_message_present = false + messages = vim + .iter(messages) + :filter(function(message) + if message.just_for_display or message.is_compacted then return false end + if not opts.all then + if message.state == "generating" then return false end + if last_message and message.uuid == last_message then last_message_present = true end + end + return true + end) :totable() - if opts.all then return history_messages0 end - - history_messages0 = vim - .iter(history_messages0) - :filter(function(message) return message.state ~= "generating" end) - :totable() - - if self.chat_history and self.chat_history.memory then - local picked_messages = {} - for idx = #history_messages0, 1, -1 do - local message = history_messages0[idx] - if message.uuid == self.chat_history.memory.last_message_uuid then break end - table.insert(picked_messages, 1, message) + if not opts.all then + if last_message and last_message_present then + -- Drop all old messages preceding the "last" one from the memory + local last_message_seen = false + messages = vim + .iter(messages) + :filter(function(message) + if not last_message_seen then + if message.uuid == last_message then last_message_seen = true end + return false + end + return true + end) + :totable() end - history_messages0 = picked_messages + + local tool_limit + if Providers[Config.provider].use_ReAct_prompt then + tool_limit = nil + else + tool_limit = 25 + end + messages = History.update_tool_invocation_history(messages, tool_limit, Config.behaviour.auto_check_diagnostics) end - return History.update_history_messages( - history_messages0, - Providers[Config.provider].use_ReAct_prompt ~= nil, - Config.behaviour.auto_check_diagnostics - ) + return messages end ---@param request string diff --git a/lua/avante/utils/init.lua b/lua/avante/utils/init.lua index db359c9..7720987 100644 --- a/lua/avante/utils/init.lua +++ b/lua/avante/utils/init.lua @@ -1496,30 +1496,10 @@ function M.tool_use_to_xml(tool_use) end ---@param tool_use AvanteLLMToolUse -function M.is_edit_func_call_tool_use(tool_use) - local is_replace_func_call = false - local is_str_replace_editor_func_call = false - local is_str_replace_based_edit_tool_func_call = false - local path = nil - if tool_use.name == "replace_in_file" then - is_replace_func_call = true - path = tool_use.input.path - end - if tool_use.name == "str_replace_editor" then - if tool_use.input.command == "str_replace" then - is_replace_func_call = true - is_str_replace_editor_func_call = true - path = tool_use.input.path - end - end - if tool_use.name == "str_replace_based_edit_tool" then - if tool_use.input.command == "str_replace" then - is_replace_func_call = true - is_str_replace_based_edit_tool_func_call = true - path = tool_use.input.path - end - end - return is_replace_func_call, is_str_replace_editor_func_call, is_str_replace_based_edit_tool_func_call, path +function M.is_edit_tool_use(tool_use) + return tool_use.name == "replace_in_file" + or (tool_use.name == "str_replace_editor" and tool_use.input.command == "str_replace") + or (tool_use.name == "str_replace_based_edit_tool" and tool_use.input.command == "str_replace") end ---Counts number of strings in text, accounting for possibility of a trailing newline