From 2a16e7d4d90c8b3a82c0f3c3bd7f3008c882c8df Mon Sep 17 00:00:00 2001 From: yetone Date: Wed, 16 Jul 2025 20:15:14 +0800 Subject: [PATCH] fix: do not modify the tool_use input --- lua/avante/llm.lua | 12 +++--- lua/avante/llm_tools/attempt_completion.lua | 2 +- lua/avante/llm_tools/bash.lua | 2 +- lua/avante/llm_tools/edit_file.lua | 2 - lua/avante/llm_tools/replace_in_file.lua | 42 ++++++++++----------- lua/avante/llm_tools/str_replace.lua | 4 +- lua/avante/llm_tools/write_to_file.lua | 2 - lua/avante/providers/openai.lua | 4 ++ lua/avante/types.lua | 2 + 9 files changed, 37 insertions(+), 35 deletions(-) diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index 4378dbf..038672a 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -827,14 +827,16 @@ function M._stream(opts) local support_streaming = false local llm_tool = vim.iter(prompt_opts.tools):find(function(tool) return tool.name == partial_tool_use.name end) if llm_tool then support_streaming = llm_tool.support_streaming == true end + ---@type AvanteLLMToolFuncOpts + local tool_use_opts = { + session_ctx = opts.session_ctx, + } if partial_tool_use.state == "generating" and not is_edit_tool_use and not support_streaming then return end - if type(partial_tool_use.input) == "table" then partial_tool_use.input.tool_use_id = partial_tool_use.id end + if type(partial_tool_use.input) == "table" then tool_use_opts.tool_use_id = partial_tool_use.id end if partial_tool_use.state == "generating" then if type(partial_tool_use.input) == "table" then - partial_tool_use.input.streaming = true - LLMTools.process_tool_use(prompt_opts.tools, partial_tool_use, { - session_ctx = opts.session_ctx, - }) + tool_use_opts.streaming = true + LLMTools.process_tool_use(prompt_opts.tools, partial_tool_use, tool_use_opts) end return else diff --git a/lua/avante/llm_tools/attempt_completion.lua b/lua/avante/llm_tools/attempt_completion.lua index c72ad2e..5ddcbff 100644 --- a/lua/avante/llm_tools/attempt_completion.lua +++ b/lua/avante/llm_tools/attempt_completion.lua @@ -75,7 +75,7 @@ function M.func(input, opts) local sidebar = require("avante").get() if not sidebar then return false, "Avante sidebar not found" end - local is_streaming = input.streaming or false + local is_streaming = opts.streaming or false if is_streaming then -- wait for stream completion as command may not be complete yet return diff --git a/lua/avante/llm_tools/bash.lua b/lua/avante/llm_tools/bash.lua index 3e40169..52e8c1c 100644 --- a/lua/avante/llm_tools/bash.lua +++ b/lua/avante/llm_tools/bash.lua @@ -216,7 +216,7 @@ M.returns = { ---@type AvanteLLMToolFunc<{ path: string, command: string, streaming?: boolean }> function M.func(input, opts) - local is_streaming = input.streaming or false + local is_streaming = opts.streaming or false if is_streaming then -- wait for stream completion as command may not be complete yet return diff --git a/lua/avante/llm_tools/edit_file.lua b/lua/avante/llm_tools/edit_file.lua index 164711d..eb0ca64 100644 --- a/lua/avante/llm_tools/edit_file.lua +++ b/lua/avante/llm_tools/edit_file.lua @@ -163,8 +163,6 @@ M.func = vim.schedule_wrap(function(input, opts) path = input.target_file, old_str = original_code, new_str = jsn.choices[1].message.content, - streaming = input.streaming, - tool_use_id = input.tool_use_id, } str_replace.func(new_input, opts) end) diff --git a/lua/avante/llm_tools/replace_in_file.lua b/lua/avante/llm_tools/replace_in_file.lua index 21313c3..50b172d 100644 --- a/lua/avante/llm_tools/replace_in_file.lua +++ b/lua/avante/llm_tools/replace_in_file.lua @@ -127,7 +127,7 @@ local function fix_diff(diff) end --- IMPORTANT: Using "the_diff" instead of "diff" is to avoid LLM streaming generating function parameters in alphabetical order, which would result in generating "path" after "diff", making it impossible to achieve a streaming diff view. ----@type AvanteLLMToolFunc<{ path: string, diff: string, the_diff?: string, streaming?: boolean, tool_use_id?: string }> +---@type AvanteLLMToolFunc<{ path: string, diff: string, the_diff?: string }> function M.func(input, opts) local on_log = opts.on_log local on_complete = opts.on_complete @@ -143,12 +143,12 @@ function M.func(input, opts) local abs_path = Helpers.get_abs_path(input.path) if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end - local is_streaming = input.streaming or false + local is_streaming = opts.streaming or false session_ctx.prev_streaming_diff_timestamp_map = session_ctx.prev_streaming_diff_timestamp_map or {} local current_timestamp = os.time() if is_streaming then - local prev_streaming_diff_timestamp = session_ctx.prev_streaming_diff_timestamp_map[input.tool_use_id] + local prev_streaming_diff_timestamp = session_ctx.prev_streaming_diff_timestamp_map[opts.tool_use_id] if prev_streaming_diff_timestamp ~= nil then if current_timestamp - prev_streaming_diff_timestamp < 2 then return false, "Diff hasn't changed in the last 2 seconds" @@ -156,11 +156,11 @@ function M.func(input, opts) end local streaming_diff_lines_count = Utils.count_lines(input.diff) session_ctx.streaming_diff_lines_count_history = session_ctx.streaming_diff_lines_count_history or {} - local prev_streaming_diff_lines_count = session_ctx.streaming_diff_lines_count_history[input.tool_use_id] + local prev_streaming_diff_lines_count = session_ctx.streaming_diff_lines_count_history[opts.tool_use_id] if streaming_diff_lines_count == prev_streaming_diff_lines_count then return false, "Diff lines count hasn't changed" end - session_ctx.streaming_diff_lines_count_history[input.tool_use_id] = streaming_diff_lines_count + session_ctx.streaming_diff_lines_count_history[opts.tool_use_id] = streaming_diff_lines_count end local diff = fix_diff(input.diff) @@ -211,16 +211,16 @@ function M.func(input, opts) return false, "No diff blocks found" end - session_ctx.prev_streaming_diff_timestamp_map[input.tool_use_id] = current_timestamp + session_ctx.prev_streaming_diff_timestamp_map[opts.tool_use_id] = current_timestamp local bufnr, err = Helpers.get_bufnr(abs_path) if err then return false, err end session_ctx.undo_joined = session_ctx.undo_joined or {} - local undo_joined = session_ctx.undo_joined[input.tool_use_id] + local undo_joined = session_ctx.undo_joined[opts.tool_use_id] if not undo_joined then pcall(vim.cmd.undojoin) - session_ctx.undo_joined[input.tool_use_id] = true + session_ctx.undo_joined[opts.tool_use_id] = true end local original_lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) @@ -250,10 +250,10 @@ function M.func(input, opts) session_ctx.rough_diff_blocks_to_diff_blocks_cache_map = session_ctx.rough_diff_blocks_to_diff_blocks_cache_map or {} local rough_diff_blocks_to_diff_blocks_cache = - session_ctx.rough_diff_blocks_to_diff_blocks_cache_map[input.tool_use_id] + session_ctx.rough_diff_blocks_to_diff_blocks_cache_map[opts.tool_use_id] if not rough_diff_blocks_to_diff_blocks_cache then rough_diff_blocks_to_diff_blocks_cache = {} - session_ctx.rough_diff_blocks_to_diff_blocks_cache_map[input.tool_use_id] = rough_diff_blocks_to_diff_blocks_cache + session_ctx.rough_diff_blocks_to_diff_blocks_cache_map[opts.tool_use_id] = rough_diff_blocks_to_diff_blocks_cache end local function rough_diff_blocks_to_diff_blocks(rough_diff_blocks_) @@ -623,35 +623,35 @@ function M.func(input, opts) end session_ctx.extmark_id_map = session_ctx.extmark_id_map or {} - local extmark_id_map = session_ctx.extmark_id_map[input.tool_use_id] + local extmark_id_map = session_ctx.extmark_id_map[opts.tool_use_id] if not extmark_id_map then extmark_id_map = {} - session_ctx.extmark_id_map[input.tool_use_id] = extmark_id_map + session_ctx.extmark_id_map[opts.tool_use_id] = extmark_id_map end session_ctx.virt_lines_map = session_ctx.virt_lines_map or {} - local virt_lines_map = session_ctx.virt_lines_map[input.tool_use_id] + local virt_lines_map = session_ctx.virt_lines_map[opts.tool_use_id] if not virt_lines_map then virt_lines_map = {} - session_ctx.virt_lines_map[input.tool_use_id] = virt_lines_map + session_ctx.virt_lines_map[opts.tool_use_id] = virt_lines_map end session_ctx.last_orig_diff_end_line_map = session_ctx.last_orig_diff_end_line_map or {} - local last_orig_diff_end_line = session_ctx.last_orig_diff_end_line_map[input.tool_use_id] + local last_orig_diff_end_line = session_ctx.last_orig_diff_end_line_map[opts.tool_use_id] if not last_orig_diff_end_line then last_orig_diff_end_line = 1 - session_ctx.last_orig_diff_end_line_map[input.tool_use_id] = last_orig_diff_end_line + session_ctx.last_orig_diff_end_line_map[opts.tool_use_id] = last_orig_diff_end_line end session_ctx.last_resp_diff_end_line_map = session_ctx.last_resp_diff_end_line_map or {} - local last_resp_diff_end_line = session_ctx.last_resp_diff_end_line_map[input.tool_use_id] + local last_resp_diff_end_line = session_ctx.last_resp_diff_end_line_map[opts.tool_use_id] if not last_resp_diff_end_line then last_resp_diff_end_line = 1 - session_ctx.last_resp_diff_end_line_map[input.tool_use_id] = last_resp_diff_end_line + session_ctx.last_resp_diff_end_line_map[opts.tool_use_id] = last_resp_diff_end_line end session_ctx.prev_diff_blocks_map = session_ctx.prev_diff_blocks_map or {} - local prev_diff_blocks = session_ctx.prev_diff_blocks_map[input.tool_use_id] + local prev_diff_blocks = session_ctx.prev_diff_blocks_map[opts.tool_use_id] if not prev_diff_blocks then prev_diff_blocks = {} - session_ctx.prev_diff_blocks_map[input.tool_use_id] = prev_diff_blocks + session_ctx.prev_diff_blocks_map[opts.tool_use_id] = prev_diff_blocks end local function get_unstable_diff_blocks(diff_blocks_) @@ -671,7 +671,7 @@ function M.func(input, opts) local function highlight_streaming_diff_blocks() local unstable_diff_blocks = get_unstable_diff_blocks(diff_blocks) - session_ctx.prev_diff_blocks_map[input.tool_use_id] = diff_blocks + session_ctx.prev_diff_blocks_map[opts.tool_use_id] = diff_blocks local max_col = vim.o.columns for _, diff_block in ipairs(unstable_diff_blocks) do local new_lines = diff_block.new_lines diff --git a/lua/avante/llm_tools/str_replace.lua b/lua/avante/llm_tools/str_replace.lua index 1d0daf6..f5f78e9 100644 --- a/lua/avante/llm_tools/str_replace.lua +++ b/lua/avante/llm_tools/str_replace.lua @@ -57,12 +57,10 @@ M.returns = { function M.func(input, opts) local replace_in_file = require("avante.llm_tools.replace_in_file") local diff = "------- SEARCH\n" .. input.old_str .. "\n=======\n" .. input.new_str - if not input.streaming then diff = diff .. "\n+++++++ REPLACE" end + if not opts.streaming then diff = diff .. "\n+++++++ REPLACE" end local new_input = { path = input.path, diff = diff, - streaming = input.streaming, - tool_use_id = input.tool_use_id, } return replace_in_file.func(new_input, opts) end diff --git a/lua/avante/llm_tools/write_to_file.lua b/lua/avante/llm_tools/write_to_file.lua index fa22781..076ef09 100644 --- a/lua/avante/llm_tools/write_to_file.lua +++ b/lua/avante/llm_tools/write_to_file.lua @@ -79,8 +79,6 @@ function M.func(input, opts) path = input.path, old_str = old_content, new_str = input.content, - streaming = input.streaming, - tool_use_id = input.tool_use_id, } return str_replace.func(new_input, opts) end diff --git a/lua/avante/providers/openai.lua b/lua/avante/providers/openai.lua index c65f208..6f6223f 100644 --- a/lua/avante/providers/openai.lua +++ b/lua/avante/providers/openai.lua @@ -413,6 +413,10 @@ function M:parse_response(ctx, data_stream, _, opts) if usage then opts.update_tokens_usage(usage) end end end + if jsn.error and jsn.error ~= vim.NIL then + opts.on_stop({ reason = "error", error = vim.inspect(jsn.error) }) + return + end ---@cast jsn AvanteOpenAIChatResponse if not jsn.choices then return end local choice = jsn.choices[1] diff --git a/lua/avante/types.lua b/lua/avante/types.lua index 4b77e99..0744c20 100644 --- a/lua/avante/types.lua +++ b/lua/avante/types.lua @@ -419,6 +419,8 @@ vim.g.avante_login = vim.g.avante_login ---@field on_complete? fun(result: boolean | string | nil, error: string | nil): nil ---@field on_log? fun(log: string): nil ---@field set_store? fun(key: string, value: any): nil +---@field tool_use_id? string +---@field streaming? boolean --- ---@alias AvanteLLMToolFunc fun( --- input: T,