diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index ccb9217..522d54b 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -145,18 +145,17 @@ M._stream = function(opts) on_start = opts.on_start, on_chunk = opts.on_chunk, on_stop = function(stop_opts) - if stop_opts.reason == "tool_use" and stop_opts.tool_use then - local result, error = LLMTools.process_tool_use(opts.tools, stop_opts.tool_use, opts.on_tool_log) - local tool_result = { - tool_use_id = stop_opts.tool_use.id, - content = error ~= nil and error or result, - is_error = error ~= nil, - } + if stop_opts.reason == "tool_use" and stop_opts.tool_use_list then local old_tool_histories = vim.deepcopy(opts.tool_histories) or {} - table.insert( - old_tool_histories, - { tool_result = tool_result, tool_use = stop_opts.tool_use, response_content = stop_opts.response_content } - ) + for _, tool_use in ipairs(stop_opts.tool_use_list) do + local result, error = LLMTools.process_tool_use(opts.tools, tool_use, opts.on_tool_log) + local tool_result = { + tool_use_id = tool_use.id, + content = error ~= nil and error or result, + is_error = error ~= nil, + } + table.insert(old_tool_histories, { tool_result = tool_result, tool_use = tool_use }) + end local new_opts = vim.tbl_deep_extend("force", opts, { tool_histories = old_tool_histories, }) @@ -418,7 +417,6 @@ end ---@class AvanteLLMToolHistory ---@field tool_result? AvanteLLMToolResult ---@field tool_use? AvanteLLMToolUse ----@field response_content? string --- ---@class StreamOptions: GeneratePromptsOptions ---@field on_start AvanteLLMStartCallback diff --git a/lua/avante/providers/claude.lua b/lua/avante/providers/claude.lua index 3e48d92..ecc97dd 100644 --- a/lua/avante/providers/claude.lua +++ b/lua/avante/providers/claude.lua @@ -119,10 +119,10 @@ M.parse_messages = function(opts) role = "assistant", content = {}, } - if tool_history.response_content then + if tool_history.tool_use.response_content then msg.content[#msg.content + 1] = { type = "text", - text = tool_history.response_content, + text = tool_history.tool_use.response_content, } end msg.content[#msg.content + 1] = { @@ -177,24 +177,33 @@ M.parse_response = function(ctx, data_stream, event_state, opts) local ok, jsn = pcall(vim.json.decode, data_stream) if not ok then return end if jsn.content_block.type == "tool_use" then - ctx.tool_use = { + if not ctx.tool_use_list then ctx.tool_use_list = {} end + local tool_use = { name = jsn.content_block.name, id = jsn.content_block.id, input_json = "", + response_content = nil, } + table.insert(ctx.tool_use_list, tool_use) elseif jsn.content_block.type == "text" then ctx.response_content = "" end elseif event_state == "content_block_delta" then local ok, jsn = pcall(vim.json.decode, data_stream) if not ok then return end - if ctx.tool_use and jsn.delta.type == "input_json_delta" then - ctx.tool_use.input_json = ctx.tool_use.input_json .. jsn.delta.partial_json + if ctx.tool_use_list and jsn.delta.type == "input_json_delta" then + local tool_use = ctx.tool_use_list[#ctx.tool_use_list] + tool_use.input_json = tool_use.input_json .. jsn.delta.partial_json return elseif ctx.response_content and jsn.delta.type == "text_delta" then ctx.response_content = ctx.response_content .. jsn.delta.text end opts.on_chunk(jsn.delta.text) + elseif event_state == "content_block_stop" then + if ctx.tool_use_list then + local tool_use = ctx.tool_use_list[#ctx.tool_use_list] + if tool_use.response_content == nil then tool_use.response_content = ctx.response_content end + end elseif event_state == "message_delta" then local ok, jsn = pcall(vim.json.decode, data_stream) if not ok then return end @@ -204,8 +213,7 @@ M.parse_response = function(ctx, data_stream, event_state, opts) opts.on_stop({ reason = "tool_use", usage = jsn.usage, - tool_use = ctx.tool_use, - response_content = ctx.response_content, + tool_use_list = ctx.tool_use_list, }) end return diff --git a/lua/avante/providers/init.lua b/lua/avante/providers/init.lua index f9509ab..fdb1cd3 100644 --- a/lua/avante/providers/init.lua +++ b/lua/avante/providers/init.lua @@ -75,6 +75,7 @@ local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil } ---@field name string ---@field id string ---@field input_json string +---@field response_content? string --- ---@class AvanteLLMStartCallbackOptions ---@field usage? AvanteLLMUsage @@ -83,8 +84,7 @@ local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil } ---@field reason "complete" | "tool_use" | "error" ---@field error? string | table ---@field usage? AvanteLLMUsage ----@field tool_use? AvanteLLMToolUse ----@field response_content? string +---@field tool_use_list? AvanteLLMToolUse[] --- ---@alias AvanteStreamParser fun(line: string, handler_opts: AvanteHandlerOptions): nil ---@alias AvanteLLMStartCallback fun(opts: AvanteLLMStartCallbackOptions): nil diff --git a/lua/avante/providers/openai.lua b/lua/avante/providers/openai.lua index 80647ff..0543046 100644 --- a/lua/avante/providers/openai.lua +++ b/lua/avante/providers/openai.lua @@ -29,6 +29,7 @@ local P = require("avante.providers") ---@field arguments string --- ---@class OpenAIMessageToolCall +---@field index integer ---@field id string ---@field type "function" ---@field function OpenAIMessageToolCallFunction @@ -210,8 +211,7 @@ M.parse_response = function(ctx, data_stream, _, opts) opts.on_stop({ reason = "tool_use", usage = jsn.usage, - tool_use = ctx.tool_use, - response_content = ctx.response_content, + tool_use_list = ctx.tool_use_list, }) elseif choice.delta.reasoning_content and choice.delta.reasoning_content ~= vim.NIL then if ctx.returned_think_start_tag == nil or not ctx.returned_think_start_tag then @@ -229,14 +229,17 @@ M.parse_response = function(ctx, data_stream, _, opts) opts.on_chunk(choice.delta.reasoning) elseif choice.delta.tool_calls then local tool_call = choice.delta.tool_calls[1] - if not ctx.tool_use then - ctx.tool_use = { + if not ctx.tool_use_list then ctx.tool_use_list = {} end + if not ctx.tool_use_list[tool_call.index + 1] then + local tool_use = { name = tool_call["function"].name, id = tool_call.id, input_json = "", } + ctx.tool_use_list[tool_call.index + 1] = tool_use else - ctx.tool_use.input_json = ctx.tool_use.input_json .. tool_call["function"].arguments + local tool_use = ctx.tool_use_list[tool_call.index + 1] + tool_use.input_json = tool_use.input_json .. tool_call["function"].arguments end elseif choice.delta.content then if diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua index fd0a79d..17abcba 100644 --- a/lua/avante/sidebar.lua +++ b/lua/avante/sidebar.lua @@ -1847,11 +1847,11 @@ function Sidebar:create_input_container(opts) local on_tool_log = function(tool_name, log) if transformed_response:sub(-1) ~= "\n" then transformed_response = transformed_response .. "\n" end - transformed_response = transformed_response .. "[" .. tool_name .. "]: " .. log + transformed_response = transformed_response .. "[" .. tool_name .. "]: " .. log .. "\n" local breakline = "" if displayed_response:sub(-1) ~= "\n" then breakline = "\n" end - displayed_response = displayed_response .. breakline .. "[" .. tool_name .. "]: " .. log - self:update_content(content_prefix .. displayed_response .. "\n", { + displayed_response = displayed_response .. breakline .. "[" .. tool_name .. "]: " .. log .. "\n" + self:update_content(content_prefix .. displayed_response, { scroll = scroll, }) end