refactor(history): move code collecting pending tools to history
Code dealing with scanning history messages and extracting some data or state belongs to avante/history/ so move it there. Along with the move update the implementation to make use of get_tool_use_data() and get_tool_result_data() helpers to simplify it. Also rename the function to History.get_pending_tools() to better reflect its purpose: "uncalled" means something that was done without request, unsolicited, not something that has not completed yet.
This commit is contained in:
@@ -257,4 +257,45 @@ M.update_history_messages = function(messages, using_ReAct_prompt, add_diagnosti
|
|||||||
return final_history_messages
|
return final_history_messages
|
||||||
end
|
end
|
||||||
|
|
||||||
|
---Scans message history backwards, looking for tool invocations that have not been executed yet
|
||||||
|
---@param messages avante.HistoryMessage[]
|
||||||
|
---@return AvantePartialLLMToolUse[]
|
||||||
|
---@return avante.HistoryMessage[]
|
||||||
|
function M.get_pending_tools(messages)
|
||||||
|
local last_turn_id = nil
|
||||||
|
if #messages > 0 then last_turn_id = messages[#messages].turn_id end
|
||||||
|
|
||||||
|
local pending_tool_uses = {} ---@type AvantePartialLLMToolUse[]
|
||||||
|
local pending_tool_uses_messages = {} ---@type avante.HistoryMessage[]
|
||||||
|
local tool_result_seen = {}
|
||||||
|
|
||||||
|
for idx = #messages, 1, -1 do
|
||||||
|
local message = messages[idx]
|
||||||
|
|
||||||
|
if last_turn_id and message.turn_id ~= last_turn_id then break end
|
||||||
|
|
||||||
|
local use = Helpers.get_tool_use_data(message)
|
||||||
|
if use then
|
||||||
|
if not tool_result_seen[use.id] then
|
||||||
|
local partial_tool_use = {
|
||||||
|
name = use.name,
|
||||||
|
id = use.id,
|
||||||
|
input = use.input,
|
||||||
|
state = message.state,
|
||||||
|
}
|
||||||
|
table.insert(pending_tool_uses, 1, partial_tool_use)
|
||||||
|
table.insert(pending_tool_uses_messages, 1, message)
|
||||||
|
end
|
||||||
|
goto continue
|
||||||
|
end
|
||||||
|
|
||||||
|
local result = Helpers.get_tool_result_data(message)
|
||||||
|
if result then tool_result_seen[result.tool_use_id] = true end
|
||||||
|
|
||||||
|
::continue::
|
||||||
|
end
|
||||||
|
|
||||||
|
return pending_tool_uses, pending_tool_uses_messages
|
||||||
|
end
|
||||||
|
|
||||||
return M
|
return M
|
||||||
|
|||||||
@@ -147,10 +147,10 @@ function M.generate_todos(user_input, cb)
|
|||||||
return
|
return
|
||||||
end
|
end
|
||||||
if stop_opts.reason == "tool_use" then
|
if stop_opts.reason == "tool_use" then
|
||||||
local uncalled_tool_uses = Utils.get_uncalled_tool_uses(history_messages)
|
local pending_tools = History.get_pending_tools(history_messages)
|
||||||
for _, partial_tool_use in ipairs(uncalled_tool_uses) do
|
for _, pending_tool in ipairs(pending_tools) do
|
||||||
if partial_tool_use.state == "generated" and partial_tool_use.name == "add_todos" then
|
if pending_tool.state == "generated" and pending_tool.name == "add_todos" then
|
||||||
local result = LLMTools.process_tool_use(tools, partial_tool_use, {
|
local result = LLMTools.process_tool_use(tools, pending_tool, {
|
||||||
session_ctx = {},
|
session_ctx = {},
|
||||||
on_complete = function() cb() end,
|
on_complete = function() cb() end,
|
||||||
})
|
})
|
||||||
@@ -864,7 +864,7 @@ function M._stream(opts)
|
|||||||
return opts.on_stop({ reason = "cancelled" })
|
return opts.on_stop({ reason = "cancelled" })
|
||||||
end
|
end
|
||||||
local history_messages = opts.get_history_messages and opts.get_history_messages({ all = true }) or {}
|
local history_messages = opts.get_history_messages and opts.get_history_messages({ all = true }) or {}
|
||||||
local uncalled_tool_uses, uncalled_tool_uses_messages = Utils.get_uncalled_tool_uses(history_messages)
|
local pending_tools, pending_tool_use_messages = History.get_pending_tools(history_messages)
|
||||||
if stop_opts.reason == "complete" and Config.mode == "agentic" then
|
if stop_opts.reason == "complete" and Config.mode == "agentic" then
|
||||||
local completed_attempt_completion_tool_use = nil
|
local completed_attempt_completion_tool_use = nil
|
||||||
for idx = #history_messages, 1, -1 do
|
for idx = #history_messages, 1, -1 do
|
||||||
@@ -927,13 +927,7 @@ function M._stream(opts)
|
|||||||
end
|
end
|
||||||
if stop_opts.reason == "tool_use" then
|
if stop_opts.reason == "tool_use" then
|
||||||
opts.session_ctx.user_reminder_count = 0
|
opts.session_ctx.user_reminder_count = 0
|
||||||
return handle_next_tool_use(
|
return handle_next_tool_use(pending_tools, pending_tool_use_messages, 1, {}, stop_opts.streaming_tool_use)
|
||||||
uncalled_tool_uses,
|
|
||||||
uncalled_tool_uses_messages,
|
|
||||||
1,
|
|
||||||
{},
|
|
||||||
stop_opts.streaming_tool_use
|
|
||||||
)
|
|
||||||
end
|
end
|
||||||
if stop_opts.reason == "rate_limit" then
|
if stop_opts.reason == "rate_limit" then
|
||||||
local msg_content = "*[Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds ...]*"
|
local msg_content = "*[Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds ...]*"
|
||||||
|
|||||||
@@ -1698,50 +1698,6 @@ function M.tbl_override(value, override)
|
|||||||
return vim.tbl_extend("force", value, override)
|
return vim.tbl_extend("force", value, override)
|
||||||
end
|
end
|
||||||
|
|
||||||
---@param history_messages avante.HistoryMessage[]
|
|
||||||
---@return AvantePartialLLMToolUse[]
|
|
||||||
---@return avante.HistoryMessage[]
|
|
||||||
function M.get_uncalled_tool_uses(history_messages)
|
|
||||||
local HistoryHelpers = require("avante.history.helpers")
|
|
||||||
local last_turn_id = nil
|
|
||||||
if #history_messages > 0 then last_turn_id = history_messages[#history_messages].turn_id end
|
|
||||||
local uncalled_tool_uses = {} ---@type AvantePartialLLMToolUse[]
|
|
||||||
local uncalled_tool_uses_messages = {} ---@type avante.HistoryMessage[]
|
|
||||||
local tool_result_seen = {}
|
|
||||||
for idx = #history_messages, 1, -1 do
|
|
||||||
local message = history_messages[idx]
|
|
||||||
if last_turn_id then
|
|
||||||
if message.turn_id ~= last_turn_id then break end
|
|
||||||
elseif not HistoryHelpers.is_tool_use_message(message) and not HistoryHelpers.is_tool_result_message(message) then
|
|
||||||
break
|
|
||||||
end
|
|
||||||
local content = message.message.content
|
|
||||||
if type(content) ~= "table" or #content == 0 then goto continue end
|
|
||||||
local is_break = false
|
|
||||||
for _, item in ipairs(content) do
|
|
||||||
if item.type == "tool_use" then
|
|
||||||
if not tool_result_seen[item.id] then
|
|
||||||
local partial_tool_use = {
|
|
||||||
name = item.name,
|
|
||||||
id = item.id,
|
|
||||||
input = item.input,
|
|
||||||
state = message.state,
|
|
||||||
}
|
|
||||||
table.insert(uncalled_tool_uses, 1, partial_tool_use)
|
|
||||||
table.insert(uncalled_tool_uses_messages, 1, message)
|
|
||||||
else
|
|
||||||
is_break = true
|
|
||||||
break
|
|
||||||
end
|
|
||||||
end
|
|
||||||
if item.type == "tool_result" then tool_result_seen[item.tool_use_id] = true end
|
|
||||||
end
|
|
||||||
if is_break then break end
|
|
||||||
::continue::
|
|
||||||
end
|
|
||||||
return uncalled_tool_uses, uncalled_tool_uses_messages
|
|
||||||
end
|
|
||||||
|
|
||||||
function M.call_once(func)
|
function M.call_once(func)
|
||||||
local called = false
|
local called = false
|
||||||
return function(...)
|
return function(...)
|
||||||
|
|||||||
Reference in New Issue
Block a user