diff --git a/lua/avante/history/helpers.lua b/lua/avante/history/helpers.lua index aafd5a8..ad89d54 100644 --- a/lua/avante/history/helpers.lua +++ b/lua/avante/history/helpers.lua @@ -2,58 +2,69 @@ local Utils = require("avante.utils") local M = {} +---If message is a "tool use" message returns information about the tool invocation. ---@param message avante.HistoryMessage ----@return boolean -function M.is_tool_use_message(message) +---@return AvanteLLMToolUse | nil +function M.get_tool_use_data(message) local content = message.message.content - if type(content) == "string" then return false end - if vim.islist(content) then - for _, item in ipairs(content) do - if item.type == "tool_use" then return true end + if type(content) == "table" then + assert(#content == 1, "more than one entry in message content") + local item = content[1] + if item.type == "tool_use" then + ---@cast item AvanteLLMToolUse + return item end end - return false end ---@param message avante.HistoryMessage ---@return boolean -function M.is_tool_result_message(message) +function M.is_tool_use_message(message) return M.get_tool_use_data(message) ~= nil end + +---If message is a "tool result" message returns results of the tool invocation. +---@param message avante.HistoryMessage +---@return AvanteLLMToolResult | nil +function M.get_tool_result_data(message) local content = message.message.content - if type(content) == "string" then return false end - if vim.islist(content) then - for _, item in ipairs(content) do - if item.type == "tool_result" then return true end + if type(content) == "table" then + assert(#content == 1, "more than one entry in message content") + local item = content[1] + if item.type == "tool_result" then + ---@cast item AvanteLLMToolResult + return item end end - return false end +---Attempts to locate result of a tool execution given tool invocation ID +---@param id string +---@param messages avante.HistoryMessage[] +---@return AvanteLLMToolResult | nil +function M.get_tool_result(id, messages) + for idx = #messages, 1, -1 do + local msg = messages[idx] + local result = M.get_tool_result_data(msg) + if result and result.tool_use_id == id then return result end + end +end + +---@param message avante.HistoryMessage +---@return boolean +function M.is_tool_result_message(message) return M.get_tool_result_data(message) ~= nil end + +---Given a tool result message locate corresponding tool use message ---@param message avante.HistoryMessage ---@param messages avante.HistoryMessage[] ---@return avante.HistoryMessage | nil function M.get_tool_use_message(message, messages) - local content = message.message.content - if type(content) == "string" then return nil end - if vim.islist(content) then - local tool_id = nil - for _, item in ipairs(content) do - if item.type == "tool_result" then - tool_id = item.tool_use_id - break - end - end - if not tool_id then return nil end - for idx_ = #messages, 1, -1 do - local message_ = messages[idx_] - local content_ = message_.message.content - if type(content_) == "table" then - for _, item in ipairs(content_) do - if item.type == "tool_use" and item.id == tool_id then return message_ end - end - end + local result = M.get_tool_result_data(message) + if result then + for idx = #messages, 1, -1 do + local msg = messages[idx] + local use = M.get_tool_use_data(msg) + if use and use.id == result.tool_use_id then return msg end end end - return nil end ---@param tool_use_message avante.HistoryMessage | nil @@ -70,32 +81,19 @@ function M.is_edit_func_call_message(tool_use_message) return is_replace_func_call, is_str_replace_editor_func_call, is_str_replace_based_edit_tool_func_call, path end +---Given a tool use message locate corresponding tool result message ---@param message avante.HistoryMessage ---@param messages avante.HistoryMessage[] ---@return avante.HistoryMessage | nil function M.get_tool_result_message(message, messages) - local content = message.message.content - if type(content) == "string" then return nil end - if vim.islist(content) then - local tool_id = nil - for _, item in ipairs(content) do - if item.type == "tool_use" then - tool_id = item.id - break - end - end - if not tool_id then return nil end - for idx_ = #messages, 1, -1 do - local message_ = messages[idx_] - local content_ = message_.message.content - if type(content_) == "table" then - for _, item in ipairs(content_) do - if item.type == "tool_result" and item.tool_use_id == tool_id then return message_ end - end - end + local use = M.get_tool_use_data(message) + if use then + for idx = #messages, 1, -1 do + local msg = messages[idx] + local result = M.get_tool_result_data(msg) + if result and result.tool_use_id == use.id then return msg end end end - return nil end return M