refactor(history): add helpers to get content of tool "use" and "result"
There are many places in the code that wants to work with content of "tool use" and "tool result" messages. Currently such code uses is_tool_use_message() and is_tool_result() message to check if message is of right kind, and then pokes into message internals. This is not very efficient. Introduce get_tool_use_data() and get_tool_result_data() that would return contents of the message if it is of right kind, or nil otherwise. Also introduce get_tool_result() that attempts to locate result of a tool execution by its invocation ID.
This commit is contained in:
@@ -2,58 +2,69 @@ local Utils = require("avante.utils")
|
|||||||
|
|
||||||
local M = {}
|
local M = {}
|
||||||
|
|
||||||
|
---If message is a "tool use" message returns information about the tool invocation.
|
||||||
---@param message avante.HistoryMessage
|
---@param message avante.HistoryMessage
|
||||||
---@return boolean
|
---@return AvanteLLMToolUse | nil
|
||||||
function M.is_tool_use_message(message)
|
function M.get_tool_use_data(message)
|
||||||
local content = message.message.content
|
local content = message.message.content
|
||||||
if type(content) == "string" then return false end
|
if type(content) == "table" then
|
||||||
if vim.islist(content) then
|
assert(#content == 1, "more than one entry in message content")
|
||||||
for _, item in ipairs(content) do
|
local item = content[1]
|
||||||
if item.type == "tool_use" then return true end
|
if item.type == "tool_use" then
|
||||||
|
---@cast item AvanteLLMToolUse
|
||||||
|
return item
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
return false
|
|
||||||
end
|
end
|
||||||
|
|
||||||
---@param message avante.HistoryMessage
|
---@param message avante.HistoryMessage
|
||||||
---@return boolean
|
---@return boolean
|
||||||
function M.is_tool_result_message(message)
|
function M.is_tool_use_message(message) return M.get_tool_use_data(message) ~= nil end
|
||||||
|
|
||||||
|
---If message is a "tool result" message returns results of the tool invocation.
|
||||||
|
---@param message avante.HistoryMessage
|
||||||
|
---@return AvanteLLMToolResult | nil
|
||||||
|
function M.get_tool_result_data(message)
|
||||||
local content = message.message.content
|
local content = message.message.content
|
||||||
if type(content) == "string" then return false end
|
if type(content) == "table" then
|
||||||
if vim.islist(content) then
|
assert(#content == 1, "more than one entry in message content")
|
||||||
for _, item in ipairs(content) do
|
local item = content[1]
|
||||||
if item.type == "tool_result" then return true end
|
if item.type == "tool_result" then
|
||||||
|
---@cast item AvanteLLMToolResult
|
||||||
|
return item
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
return false
|
|
||||||
end
|
end
|
||||||
|
|
||||||
|
---Attempts to locate result of a tool execution given tool invocation ID
|
||||||
|
---@param id string
|
||||||
|
---@param messages avante.HistoryMessage[]
|
||||||
|
---@return AvanteLLMToolResult | nil
|
||||||
|
function M.get_tool_result(id, messages)
|
||||||
|
for idx = #messages, 1, -1 do
|
||||||
|
local msg = messages[idx]
|
||||||
|
local result = M.get_tool_result_data(msg)
|
||||||
|
if result and result.tool_use_id == id then return result end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
---@param message avante.HistoryMessage
|
||||||
|
---@return boolean
|
||||||
|
function M.is_tool_result_message(message) return M.get_tool_result_data(message) ~= nil end
|
||||||
|
|
||||||
|
---Given a tool result message locate corresponding tool use message
|
||||||
---@param message avante.HistoryMessage
|
---@param message avante.HistoryMessage
|
||||||
---@param messages avante.HistoryMessage[]
|
---@param messages avante.HistoryMessage[]
|
||||||
---@return avante.HistoryMessage | nil
|
---@return avante.HistoryMessage | nil
|
||||||
function M.get_tool_use_message(message, messages)
|
function M.get_tool_use_message(message, messages)
|
||||||
local content = message.message.content
|
local result = M.get_tool_result_data(message)
|
||||||
if type(content) == "string" then return nil end
|
if result then
|
||||||
if vim.islist(content) then
|
for idx = #messages, 1, -1 do
|
||||||
local tool_id = nil
|
local msg = messages[idx]
|
||||||
for _, item in ipairs(content) do
|
local use = M.get_tool_use_data(msg)
|
||||||
if item.type == "tool_result" then
|
if use and use.id == result.tool_use_id then return msg end
|
||||||
tool_id = item.tool_use_id
|
|
||||||
break
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
if not tool_id then return nil end
|
|
||||||
for idx_ = #messages, 1, -1 do
|
|
||||||
local message_ = messages[idx_]
|
|
||||||
local content_ = message_.message.content
|
|
||||||
if type(content_) == "table" then
|
|
||||||
for _, item in ipairs(content_) do
|
|
||||||
if item.type == "tool_use" and item.id == tool_id then return message_ end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
return nil
|
|
||||||
end
|
end
|
||||||
|
|
||||||
---@param tool_use_message avante.HistoryMessage | nil
|
---@param tool_use_message avante.HistoryMessage | nil
|
||||||
@@ -70,32 +81,19 @@ function M.is_edit_func_call_message(tool_use_message)
|
|||||||
return is_replace_func_call, is_str_replace_editor_func_call, is_str_replace_based_edit_tool_func_call, path
|
return is_replace_func_call, is_str_replace_editor_func_call, is_str_replace_based_edit_tool_func_call, path
|
||||||
end
|
end
|
||||||
|
|
||||||
|
---Given a tool use message locate corresponding tool result message
|
||||||
---@param message avante.HistoryMessage
|
---@param message avante.HistoryMessage
|
||||||
---@param messages avante.HistoryMessage[]
|
---@param messages avante.HistoryMessage[]
|
||||||
---@return avante.HistoryMessage | nil
|
---@return avante.HistoryMessage | nil
|
||||||
function M.get_tool_result_message(message, messages)
|
function M.get_tool_result_message(message, messages)
|
||||||
local content = message.message.content
|
local use = M.get_tool_use_data(message)
|
||||||
if type(content) == "string" then return nil end
|
if use then
|
||||||
if vim.islist(content) then
|
for idx = #messages, 1, -1 do
|
||||||
local tool_id = nil
|
local msg = messages[idx]
|
||||||
for _, item in ipairs(content) do
|
local result = M.get_tool_result_data(msg)
|
||||||
if item.type == "tool_use" then
|
if result and result.tool_use_id == use.id then return msg end
|
||||||
tool_id = item.id
|
|
||||||
break
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
if not tool_id then return nil end
|
|
||||||
for idx_ = #messages, 1, -1 do
|
|
||||||
local message_ = messages[idx_]
|
|
||||||
local content_ = message_.message.content
|
|
||||||
if type(content_) == "table" then
|
|
||||||
for _, item in ipairs(content_) do
|
|
||||||
if item.type == "tool_result" and item.tool_use_id == tool_id then return message_ end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
return nil
|
|
||||||
end
|
end
|
||||||
|
|
||||||
return M
|
return M
|
||||||
|
|||||||
Reference in New Issue
Block a user