diff --git a/lua/avante/history/init.lua b/lua/avante/history/init.lua index 0608867..174160f 100644 --- a/lua/avante/history/init.lua +++ b/lua/avante/history/init.lua @@ -257,4 +257,45 @@ M.update_history_messages = function(messages, using_ReAct_prompt, add_diagnosti return final_history_messages end +---Scans message history backwards, looking for tool invocations that have not been executed yet +---@param messages avante.HistoryMessage[] +---@return AvantePartialLLMToolUse[] +---@return avante.HistoryMessage[] +function M.get_pending_tools(messages) + local last_turn_id = nil + if #messages > 0 then last_turn_id = messages[#messages].turn_id end + + local pending_tool_uses = {} ---@type AvantePartialLLMToolUse[] + local pending_tool_uses_messages = {} ---@type avante.HistoryMessage[] + local tool_result_seen = {} + + for idx = #messages, 1, -1 do + local message = messages[idx] + + if last_turn_id and message.turn_id ~= last_turn_id then break end + + local use = Helpers.get_tool_use_data(message) + if use then + if not tool_result_seen[use.id] then + local partial_tool_use = { + name = use.name, + id = use.id, + input = use.input, + state = message.state, + } + table.insert(pending_tool_uses, 1, partial_tool_use) + table.insert(pending_tool_uses_messages, 1, message) + end + goto continue + end + + local result = Helpers.get_tool_result_data(message) + if result then tool_result_seen[result.tool_use_id] = true end + + ::continue:: + end + + return pending_tool_uses, pending_tool_uses_messages +end + return M diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index 879fcbe..e291aa5 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -147,10 +147,10 @@ function M.generate_todos(user_input, cb) return end if stop_opts.reason == "tool_use" then - local uncalled_tool_uses = Utils.get_uncalled_tool_uses(history_messages) - for _, partial_tool_use in ipairs(uncalled_tool_uses) do - if partial_tool_use.state == "generated" and partial_tool_use.name == "add_todos" then - local result = LLMTools.process_tool_use(tools, partial_tool_use, { + local pending_tools = History.get_pending_tools(history_messages) + for _, pending_tool in ipairs(pending_tools) do + if pending_tool.state == "generated" and pending_tool.name == "add_todos" then + local result = LLMTools.process_tool_use(tools, pending_tool, { session_ctx = {}, on_complete = function() cb() end, }) @@ -864,7 +864,7 @@ function M._stream(opts) return opts.on_stop({ reason = "cancelled" }) end local history_messages = opts.get_history_messages and opts.get_history_messages({ all = true }) or {} - local uncalled_tool_uses, uncalled_tool_uses_messages = Utils.get_uncalled_tool_uses(history_messages) + local pending_tools, pending_tool_use_messages = History.get_pending_tools(history_messages) if stop_opts.reason == "complete" and Config.mode == "agentic" then local completed_attempt_completion_tool_use = nil for idx = #history_messages, 1, -1 do @@ -927,13 +927,7 @@ function M._stream(opts) end if stop_opts.reason == "tool_use" then opts.session_ctx.user_reminder_count = 0 - return handle_next_tool_use( - uncalled_tool_uses, - uncalled_tool_uses_messages, - 1, - {}, - stop_opts.streaming_tool_use - ) + return handle_next_tool_use(pending_tools, pending_tool_use_messages, 1, {}, stop_opts.streaming_tool_use) end if stop_opts.reason == "rate_limit" then local msg_content = "*[Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds ...]*" diff --git a/lua/avante/utils/init.lua b/lua/avante/utils/init.lua index 794f25b..c33a948 100644 --- a/lua/avante/utils/init.lua +++ b/lua/avante/utils/init.lua @@ -1698,50 +1698,6 @@ function M.tbl_override(value, override) return vim.tbl_extend("force", value, override) end ----@param history_messages avante.HistoryMessage[] ----@return AvantePartialLLMToolUse[] ----@return avante.HistoryMessage[] -function M.get_uncalled_tool_uses(history_messages) - local HistoryHelpers = require("avante.history.helpers") - local last_turn_id = nil - if #history_messages > 0 then last_turn_id = history_messages[#history_messages].turn_id end - local uncalled_tool_uses = {} ---@type AvantePartialLLMToolUse[] - local uncalled_tool_uses_messages = {} ---@type avante.HistoryMessage[] - local tool_result_seen = {} - for idx = #history_messages, 1, -1 do - local message = history_messages[idx] - if last_turn_id then - if message.turn_id ~= last_turn_id then break end - elseif not HistoryHelpers.is_tool_use_message(message) and not HistoryHelpers.is_tool_result_message(message) then - break - end - local content = message.message.content - if type(content) ~= "table" or #content == 0 then goto continue end - local is_break = false - for _, item in ipairs(content) do - if item.type == "tool_use" then - if not tool_result_seen[item.id] then - local partial_tool_use = { - name = item.name, - id = item.id, - input = item.input, - state = message.state, - } - table.insert(uncalled_tool_uses, 1, partial_tool_use) - table.insert(uncalled_tool_uses_messages, 1, message) - else - is_break = true - break - end - end - if item.type == "tool_result" then tool_result_seen[item.tool_use_id] = true end - end - if is_break then break end - ::continue:: - end - return uncalled_tool_uses, uncalled_tool_uses_messages -end - function M.call_once(func) local called = false return function(...)