diff --git a/lua/avante/history/helpers.lua b/lua/avante/history/helpers.lua index 66f7f4c..48aaa00 100644 --- a/lua/avante/history/helpers.lua +++ b/lua/avante/history/helpers.lua @@ -35,10 +35,6 @@ function M.get_tool_use_data(message) end end ----@param message avante.HistoryMessage ----@return boolean -function M.is_tool_use_message(message) return M.get_tool_use_data(message) ~= nil end - ---If message is a "tool result" message returns results of the tool invocation. ---@param message avante.HistoryMessage ---@return AvanteLLMToolResult | nil @@ -66,38 +62,43 @@ function M.get_tool_result(id, messages) end end +---Given a tool invocation ID locate corresponding tool use message +---@param id string +---@param messages avante.HistoryMessage[] +---@return avante.HistoryMessage | nil +function M.get_tool_use_message(id, messages) + for idx = #messages, 1, -1 do + local msg = messages[idx] + local use = M.get_tool_use_data(msg) + if use and use.id == id then return msg end + end +end + +---Given a tool invocation ID locate corresponding tool result message +---@param id string +---@param messages avante.HistoryMessage[] +---@return avante.HistoryMessage | nil +function M.get_tool_result_message(id, messages) + for idx = #messages, 1, -1 do + local msg = messages[idx] + local result = M.get_tool_result_data(msg) + if result and result.tool_use_id == id then return msg end + end +end + +---@param message avante.HistoryMessage +---@return boolean +function M.is_thinking_message(message) + local content = message.message.content + return type(content) == "table" and (content[1].type == "thinking" or content[1].type == "redacted_thinking") +end + ---@param message avante.HistoryMessage ---@return boolean function M.is_tool_result_message(message) return M.get_tool_result_data(message) ~= nil end ----Given a tool result message locate corresponding tool use message ---@param message avante.HistoryMessage ----@param messages avante.HistoryMessage[] ----@return avante.HistoryMessage | nil -function M.get_tool_use_message(message, messages) - local result = M.get_tool_result_data(message) - if result then - for idx = #messages, 1, -1 do - local msg = messages[idx] - local use = M.get_tool_use_data(msg) - if use and use.id == result.tool_use_id then return msg end - end - end -end - ----Given a tool use message locate corresponding tool result message ----@param message avante.HistoryMessage ----@param messages avante.HistoryMessage[] ----@return avante.HistoryMessage | nil -function M.get_tool_result_message(message, messages) - local use = M.get_tool_use_data(message) - if use then - for idx = #messages, 1, -1 do - local msg = messages[idx] - local result = M.get_tool_result_data(msg) - if result and result.tool_use_id == use.id then return msg end - end - end -end +---@return boolean +function M.is_tool_use_message(message) return M.get_tool_use_data(message) ~= nil end return M diff --git a/lua/avante/history/render.lua b/lua/avante/history/render.lua index 76b5e02..e5edcb3 100644 --- a/lua/avante/history/render.lua +++ b/lua/avante/history/render.lua @@ -104,7 +104,7 @@ local function message_content_item_to_lines(item, message, messages) elseif item.type == "tool_use" then local ok, llm_tool = pcall(require, "avante.llm_tools." .. item.name) if ok then - local tool_result_message = Helpers.get_tool_result_message(message, messages) + local tool_result_message = Helpers.get_tool_result_message(item.id, messages) ---@cast llm_tool AvanteLLMTool if llm_tool.on_render then return llm_tool.on_render(item.input, { diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua index 3031095..51851af 100644 --- a/lua/avante/sidebar.lua +++ b/lua/avante/sidebar.lua @@ -1724,8 +1724,7 @@ local function _get_message_lines(message, messages, ctx) return res end if message.message.role == "assistant" then - local content = message.message.content - if type(content) == "table" and content[1].type == "tool_use" then return lines end + if History.Helpers.is_tool_use_message(message) then return lines end local text = table.concat(vim.tbl_map(function(line) return tostring(line) end, lines), "\n") local transformed = transform_result_content(text, ctx.prev_filepath) ctx.prev_filepath = transformed.current_filepath @@ -2081,12 +2080,9 @@ function Sidebar:add_history_messages(messages) end local last_message = messages[#messages] if last_message then - local content = last_message.message.content - if type(content) == "table" and content[1].type == "tool_use" then + if History.Helpers.is_tool_use_message(last_message) then self.current_state = "tool calling" - elseif type(content) == "table" and content[1].type == "thinking" then - self.current_state = "thinking" - elseif type(content) == "table" and content[1].type == "redacted_thinking" then + elseif History.Helpers.is_thinking_message(last_message) then self.current_state = "thinking" else self.current_state = "generating" @@ -2514,24 +2510,18 @@ function Sidebar:create_input_container() ---@param state AvanteLLMToolUseState local function on_tool_log(tool_id, tool_name, log, state) if state == "generating" then on_state_change("tool calling") end - local tool_use_message = nil - for idx = #self.chat_history.messages, 1, -1 do - local message = self.chat_history.messages[idx] - local content = message.message.content - if type(content) == "table" and content[1].type == "tool_use" and content[1].id == tool_id then - tool_use_message = message - break - end - end + local tool_use_message = History.Helpers.get_tool_use_message(tool_id, self.chat_history.messages) if not tool_use_message then -- Utils.debug("tool_use message not found", tool_id, tool_name) return end + local tool_use_logs = tool_use_message.tool_use_logs or {} local content = string.format("[%s]: %s", tool_name, log) table.insert(tool_use_logs, content) - local orig_is_calling = tool_use_message.is_calling tool_use_message.tool_use_logs = tool_use_logs + + local orig_is_calling = tool_use_message.is_calling tool_use_message.is_calling = true self:update_content("") tool_use_message.is_calling = orig_is_calling @@ -2539,20 +2529,13 @@ function Sidebar:create_input_container() end local function set_tool_use_store(tool_id, key, value) - local tool_use_message = nil - for idx = #self.chat_history.messages, 1, -1 do - local message = self.chat_history.messages[idx] - local content = message.message.content - if type(content) == "table" and content[1].type == "tool_use" and content[1].id == tool_id then - tool_use_message = message - break - end + local tool_use_message = History.Helpers.get_tool_use_message(tool_id, self.chat_history.messages) + if tool_use_message then + local tool_use_store = tool_use_message.tool_use_store or {} + tool_use_store[key] = value + tool_use_message.tool_use_store = tool_use_store + self:save_history() end - if not tool_use_message then return end - local tool_use_store = tool_use_message.tool_use_store or {} - tool_use_store[key] = value - tool_use_message.tool_use_store = tool_use_store - self:save_history() end ---@type AvanteLLMStopCallback