diff --git a/lua/avante/history_message.lua b/lua/avante/history_message.lua index afb0cb4..0937dcc 100644 --- a/lua/avante/history_message.lua +++ b/lua/avante/history_message.lua @@ -5,7 +5,7 @@ local M = {} M.__index = M ---@param message AvanteLLMMessage ----@param opts? {is_user_submission?: boolean, visible?: boolean, displayed_content?: string, state?: avante.HistoryMessageState, uuid?: string, selected_filepaths?: string[], selected_code?: AvanteSelectedCode, just_for_display?: boolean, is_dummy?: boolean, turn_id?: string} +---@param opts? {is_user_submission?: boolean, visible?: boolean, displayed_content?: string, state?: avante.HistoryMessageState, uuid?: string, selected_filepaths?: string[], selected_code?: AvanteSelectedCode, just_for_display?: boolean, is_dummy?: boolean, turn_id?: string, is_calling?: boolean} ---@return avante.HistoryMessage function M:new(message, opts) opts = opts or {} @@ -23,7 +23,8 @@ function M:new(message, opts) if opts.selected_code ~= nil then obj.selected_code = opts.selected_code end if opts.just_for_display ~= nil then obj.just_for_display = opts.just_for_display end if opts.is_dummy ~= nil then obj.is_dummy = opts.is_dummy end - obj.turn_id = opts.turn_id + if opts.turn_id ~= nil then obj.turn_id = opts.turn_id end + if opts.is_calling ~= nil then obj.is_calling = opts.is_calling end return obj end diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index d53a315..ec3b9f0 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -690,11 +690,17 @@ function M._stream(opts) on_stop = function(stop_opts) if stop_opts.usage and opts.update_tokens_usage then opts.update_tokens_usage(stop_opts.usage) end - ---@param partial_tool_use_list AvantePartialLLMToolUse[] + ---@param tool_uses AvantePartialLLMToolUse[] ---@param tool_use_index integer ---@param tool_results AvanteLLMToolResult[] - local function handle_next_tool_use(partial_tool_use_list, tool_use_index, tool_results, streaming_tool_use) - if tool_use_index > #partial_tool_use_list then + local function handle_next_tool_use( + tool_uses, + tool_use_messages, + tool_use_index, + tool_results, + streaming_tool_use + ) + if tool_use_index > #tool_uses then ---@type avante.HistoryMessage[] local messages = {} for _, tool_result in ipairs(tool_results) do @@ -712,7 +718,7 @@ function M._stream(opts) }) end if opts.on_messages_add then opts.on_messages_add(messages) end - local the_last_tool_use = partial_tool_use_list[#partial_tool_use_list] + local the_last_tool_use = tool_uses[#tool_uses] if the_last_tool_use and the_last_tool_use.name == "attempt_completion" then opts.on_stop({ reason = "complete" }) return @@ -731,10 +737,13 @@ function M._stream(opts) M._stream(new_opts) return end - local partial_tool_use = partial_tool_use_list[tool_use_index] + local partial_tool_use = tool_uses[tool_use_index] + local partial_tool_use_message = tool_use_messages[tool_use_index] ---@param result string | nil ---@param error string | nil local function handle_tool_result(result, error) + partial_tool_use_message.is_calling = false + if opts.on_messages_add then opts.on_messages_add({ partial_tool_use_message }) end -- Special handling for cancellation signal from tools if error == LLMToolHelpers.CANCEL_TOKEN then Utils.debug("Tool execution was cancelled by user") @@ -759,7 +768,7 @@ function M._stream(opts) is_user_declined = is_user_declined ~= nil, } table.insert(tool_results, tool_result) - return handle_next_tool_use(partial_tool_use_list, tool_use_index + 1, tool_results) + return handle_next_tool_use(tool_uses, tool_use_messages, tool_use_index + 1, tool_results) end local is_edit_tool_use = Utils.is_edit_func_call_tool_use(partial_tool_use) local support_streaming = false @@ -782,6 +791,8 @@ function M._stream(opts) else if streaming_tool_use then return end end + partial_tool_use_message.is_calling = true + if opts.on_messages_add then opts.on_messages_add({ partial_tool_use_message }) end -- Either on_complete handles the tool result asynchronously or we receive the result and error synchronously when either is not nil local result, error = LLMTools.process_tool_use( prompt_opts.tools, @@ -806,7 +817,7 @@ function M._stream(opts) return opts.on_stop({ reason = "cancelled" }) end local history_messages = opts.get_history_messages and opts.get_history_messages({ all = true }) or {} - local uncalled_tool_uses = Utils.get_uncalled_tool_uses(history_messages) + local uncalled_tool_uses, uncalled_tool_uses_messages = Utils.get_uncalled_tool_uses(history_messages) if stop_opts.reason == "complete" and Config.mode == "agentic" then local completed_attempt_completion_tool_use = nil for idx = #history_messages, 1, -1 do @@ -868,7 +879,13 @@ function M._stream(opts) end if stop_opts.reason == "tool_use" then opts.session_ctx.user_reminder_count = 0 - return handle_next_tool_use(uncalled_tool_uses, 1, {}, stop_opts.streaming_tool_use) + return handle_next_tool_use( + uncalled_tool_uses, + uncalled_tool_uses_messages, + 1, + {}, + stop_opts.streaming_tool_use + ) end if stop_opts.reason == "rate_limit" then local msg_content = "*[Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds ...]*" diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua index 6023da2..3b826ed 100644 --- a/lua/avante/sidebar.lua +++ b/lua/avante/sidebar.lua @@ -1687,7 +1687,7 @@ local _message_to_lines_lru_cache = LRUCache:new(100) ---@param ctx table ---@return avante.ui.Line[] local function get_message_lines(message, messages, ctx) - if message.state == "generating" then return _get_message_lines(message, messages, ctx) end + if message.state == "generating" or message.is_calling then return _get_message_lines(message, messages, ctx) end local cached_lines = _message_to_lines_lru_cache:get(message.uuid) if cached_lines then return cached_lines end local lines = _get_message_lines(message, messages, ctx) @@ -2000,7 +2000,7 @@ function Sidebar:add_history_messages(messages) and messages[1].state == "generated" then local first_msg_text = Utils.message_to_text(messages[1], messages) - local lines_ = vim.split(first_msg_text, "\n") + local lines_ = vim.iter(vim.split(first_msg_text, "\n")):filter(function(line) return line ~= "" end):totable() if #lines_ > 0 then self.chat_history.title = lines_[1] self:save_history() @@ -2704,9 +2704,12 @@ function Sidebar:create_input_container() local tool_use_logs = tool_use_message.tool_use_logs or {} local content = string.format("[%s]: %s", tool_name, log) table.insert(tool_use_logs, content) + local orig_is_calling = tool_use_message.is_calling tool_use_message.tool_use_logs = tool_use_logs - self:save_history() + tool_use_message.is_calling = true self:update_content("") + tool_use_message.is_calling = orig_is_calling + self:save_history() end ---@type AvanteLLMStopCallback diff --git a/lua/avante/types.lua b/lua/avante/types.lua index ba4be87..64b3b79 100644 --- a/lua/avante/types.lua +++ b/lua/avante/types.lua @@ -112,6 +112,7 @@ vim.g.avante_login = vim.g.avante_login ---@field is_compacted boolean | nil ---@field is_deleted boolean | nil ---@field turn_id string | nil +---@field is_calling boolean | nil --- ---@class AvanteLLMToolResult ---@field tool_name string diff --git a/lua/avante/utils/init.lua b/lua/avante/utils/init.lua index 76af92d..b7964b7 100644 --- a/lua/avante/utils/init.lua +++ b/lua/avante/utils/init.lua @@ -1748,10 +1748,12 @@ end ---@param history_messages avante.HistoryMessage[] ---@return AvantePartialLLMToolUse[] +---@return avante.HistoryMessage[] function M.get_uncalled_tool_uses(history_messages) local last_turn_id = nil if #history_messages > 0 then last_turn_id = history_messages[#history_messages].turn_id end - local uncalled_tool_use_list = {} ---@type AvantePartialLLMToolUse[] + local uncalled_tool_uses = {} ---@type AvantePartialLLMToolUse[] + local uncalled_tool_uses_messages = {} ---@type avante.HistoryMessage[] local tool_result_seen = {} for idx = #history_messages, 1, -1 do local message = history_messages[idx] @@ -1772,7 +1774,8 @@ function M.get_uncalled_tool_uses(history_messages) input = item.input, state = message.state, } - table.insert(uncalled_tool_use_list, 1, partial_tool_use) + table.insert(uncalled_tool_uses, 1, partial_tool_use) + table.insert(uncalled_tool_uses_messages, 1, message) else is_break = true break @@ -1783,7 +1786,7 @@ function M.get_uncalled_tool_uses(history_messages) if is_break then break end ::continue:: end - return uncalled_tool_use_list + return uncalled_tool_uses, uncalled_tool_uses_messages end function M.call_once(func)