There are many places in the code that want to work with the content of "tool use" and "tool result" messages. Currently such code uses is_tool_use_message() and is_tool_result_message() to check whether a message is of the right kind, and then pokes into the message internals. This is not very efficient. Introduce get_tool_use_data() and get_tool_result_data(), which return the contents of the message if it is of the right kind, or nil otherwise. Also introduce get_tool_result(), which attempts to locate the result of a tool execution by its invocation ID.
100 lines
3.3 KiB
Lua
100 lines
3.3 KiB
Lua
local Utils = require("avante.utils")
|
|
|
|
-- Module table: helpers for inspecting "tool use" / "tool result" history messages.
local M = {}
|
|
|
|
---If message is a "tool use" message returns information about the tool invocation.
---Returns nil when the message content is not a table, is an empty list, or its
---single item is not of type "tool_use".
---Raises an assertion error if the content list has more than one entry.
---@param message avante.HistoryMessage
---@return AvanteLLMToolUse | nil
function M.get_tool_use_data(message)
  local content = message.message.content
  -- Non-table content, or an empty content list, cannot carry a tool invocation.
  -- Checking for emptiness first keeps the assertion below from reporting
  -- "more than one entry" for a list that actually has zero entries.
  if type(content) ~= "table" or #content == 0 then return end
  -- Only the first item is inspected, so more than one item would silently be ignored.
  assert(#content == 1, "more than one entry in message content")
  local item = content[1]
  if item.type == "tool_use" then
    ---@cast item AvanteLLMToolUse
    return item
  end
end
|
|
|
|
---Tells whether the given history message is a "tool use" message.
---@param message avante.HistoryMessage
---@return boolean
function M.is_tool_use_message(message)
  local tool_use = M.get_tool_use_data(message)
  return tool_use ~= nil
end
|
|
|
|
---If message is a "tool result" message returns results of the tool invocation.
---Returns nil when the message content is not a table, is an empty list, or its
---single item is not of type "tool_result".
---Raises an assertion error if the content list has more than one entry.
---@param message avante.HistoryMessage
---@return AvanteLLMToolResult | nil
function M.get_tool_result_data(message)
  local content = message.message.content
  -- Non-table content, or an empty content list, cannot carry a tool result.
  -- Checking for emptiness first keeps the assertion below from reporting
  -- "more than one entry" for a list that actually has zero entries.
  if type(content) ~= "table" or #content == 0 then return end
  -- Only the first item is inspected, so more than one item would silently be ignored.
  assert(#content == 1, "more than one entry in message content")
  local item = content[1]
  if item.type == "tool_result" then
    ---@cast item AvanteLLMToolResult
    return item
  end
end
|
|
|
|
---Attempts to locate result of a tool execution given tool invocation ID.
---The list is scanned from its last element back to its first, so the match
---closest to the end of `messages` is returned. Returns nil if none matches.
---@param id string
---@param messages avante.HistoryMessage[]
---@return AvanteLLMToolResult | nil
function M.get_tool_result(id, messages)
  local i = #messages
  while i >= 1 do
    local tool_result = M.get_tool_result_data(messages[i])
    if tool_result and tool_result.tool_use_id == id then return tool_result end
    i = i - 1
  end
end
|
|
|
|
---Tells whether the given history message is a "tool result" message.
---@param message avante.HistoryMessage
---@return boolean
function M.is_tool_result_message(message)
  local tool_result = M.get_tool_result_data(message)
  return tool_result ~= nil
end
|
|
|
|
---Given a tool result message locate corresponding tool use message.
---Returns nil when `message` is not a tool result message, or when no message in
---`messages` is a tool use whose id equals the result's tool_use_id.
---`messages` is scanned from its last element back to its first.
---@param message avante.HistoryMessage
---@param messages avante.HistoryMessage[]
---@return avante.HistoryMessage | nil
function M.get_tool_use_message(message, messages)
  local tool_result = M.get_tool_result_data(message)
  if not tool_result then return end

  local wanted_id = tool_result.tool_use_id
  for i = #messages, 1, -1 do
    local candidate = messages[i]
    local tool_use = M.get_tool_use_data(candidate)
    if tool_use and tool_use.id == wanted_id then return candidate end
  end
end
|
|
|
|
---Checks whether a history message is a tool use that invokes a file-editing tool.
---For a tool use message, returns whatever Utils.is_edit_func_call_tool_use()
---returns for its tool invocation. Otherwise (nil message, or a message that is
---not a tool use) returns false, false, false, nil. These values correspond to
---the "replace", "str_replace_editor" and "str_replace_based_edit_tool" call flags,
---followed by the path.
---@param tool_use_message avante.HistoryMessage | nil
function M.is_edit_func_call_message(tool_use_message)
  -- Use the accessor instead of checking the kind and then indexing
  -- message.content[1] by hand, which performed the lookup twice.
  local tool_use = tool_use_message and M.get_tool_use_data(tool_use_message)
  if tool_use then return Utils.is_edit_func_call_tool_use(tool_use) end
  -- Not a tool use: no edit-tool flags are set and there is no path.
  return false, false, false, nil
end
|
|
|
|
---Given a tool use message locate corresponding tool result message.
---Returns nil when `message` is not a tool use message, or when no message in
---`messages` is a tool result whose tool_use_id equals the invocation's id.
---`messages` is scanned from its last element back to its first.
---@param message avante.HistoryMessage
---@param messages avante.HistoryMessage[]
---@return avante.HistoryMessage | nil
function M.get_tool_result_message(message, messages)
  local tool_use = M.get_tool_use_data(message)
  if not tool_use then return end

  local wanted_id = tool_use.id
  for i = #messages, 1, -1 do
    local candidate = messages[i]
    local tool_result = M.get_tool_result_data(candidate)
    if tool_result and tool_result.tool_use_id == wanted_id then return candidate end
  end
end
|
|
|
|
-- Export the helper table as the module's value.
return M
|