fix: do not modify the tool_use input

This commit is contained in:
yetone
2025-07-16 20:15:14 +08:00
parent ae06698c30
commit 2a16e7d4d9
9 changed files with 37 additions and 35 deletions

View File

@@ -827,14 +827,16 @@ function M._stream(opts)
local support_streaming = false
local llm_tool = vim.iter(prompt_opts.tools):find(function(tool) return tool.name == partial_tool_use.name end)
if llm_tool then support_streaming = llm_tool.support_streaming == true end
---@type AvanteLLMToolFuncOpts
local tool_use_opts = {
session_ctx = opts.session_ctx,
}
if partial_tool_use.state == "generating" and not is_edit_tool_use and not support_streaming then return end
if type(partial_tool_use.input) == "table" then partial_tool_use.input.tool_use_id = partial_tool_use.id end
if type(partial_tool_use.input) == "table" then tool_use_opts.tool_use_id = partial_tool_use.id end
if partial_tool_use.state == "generating" then
if type(partial_tool_use.input) == "table" then
partial_tool_use.input.streaming = true
LLMTools.process_tool_use(prompt_opts.tools, partial_tool_use, {
session_ctx = opts.session_ctx,
})
tool_use_opts.streaming = true
LLMTools.process_tool_use(prompt_opts.tools, partial_tool_use, tool_use_opts)
end
return
else

View File

@@ -75,7 +75,7 @@ function M.func(input, opts)
local sidebar = require("avante").get()
if not sidebar then return false, "Avante sidebar not found" end
local is_streaming = input.streaming or false
local is_streaming = opts.streaming or false
if is_streaming then
-- wait for stream completion as command may not be complete yet
return

View File

@@ -216,7 +216,7 @@ M.returns = {
---@type AvanteLLMToolFunc<{ path: string, command: string, streaming?: boolean }>
function M.func(input, opts)
local is_streaming = input.streaming or false
local is_streaming = opts.streaming or false
if is_streaming then
-- wait for stream completion as command may not be complete yet
return

View File

@@ -163,8 +163,6 @@ M.func = vim.schedule_wrap(function(input, opts)
path = input.target_file,
old_str = original_code,
new_str = jsn.choices[1].message.content,
streaming = input.streaming,
tool_use_id = input.tool_use_id,
}
str_replace.func(new_input, opts)
end)

View File

@@ -127,7 +127,7 @@ local function fix_diff(diff)
end
--- IMPORTANT: Using "the_diff" instead of "diff" is to avoid LLM streaming generating function parameters in alphabetical order, which would result in generating "path" after "diff", making it impossible to achieve a streaming diff view.
---@type AvanteLLMToolFunc<{ path: string, diff: string, the_diff?: string, streaming?: boolean, tool_use_id?: string }>
---@type AvanteLLMToolFunc<{ path: string, diff: string, the_diff?: string }>
function M.func(input, opts)
local on_log = opts.on_log
local on_complete = opts.on_complete
@@ -143,12 +143,12 @@ function M.func(input, opts)
local abs_path = Helpers.get_abs_path(input.path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
local is_streaming = input.streaming or false
local is_streaming = opts.streaming or false
session_ctx.prev_streaming_diff_timestamp_map = session_ctx.prev_streaming_diff_timestamp_map or {}
local current_timestamp = os.time()
if is_streaming then
local prev_streaming_diff_timestamp = session_ctx.prev_streaming_diff_timestamp_map[input.tool_use_id]
local prev_streaming_diff_timestamp = session_ctx.prev_streaming_diff_timestamp_map[opts.tool_use_id]
if prev_streaming_diff_timestamp ~= nil then
if current_timestamp - prev_streaming_diff_timestamp < 2 then
return false, "Diff hasn't changed in the last 2 seconds"
@@ -156,11 +156,11 @@ function M.func(input, opts)
end
local streaming_diff_lines_count = Utils.count_lines(input.diff)
session_ctx.streaming_diff_lines_count_history = session_ctx.streaming_diff_lines_count_history or {}
local prev_streaming_diff_lines_count = session_ctx.streaming_diff_lines_count_history[input.tool_use_id]
local prev_streaming_diff_lines_count = session_ctx.streaming_diff_lines_count_history[opts.tool_use_id]
if streaming_diff_lines_count == prev_streaming_diff_lines_count then
return false, "Diff lines count hasn't changed"
end
session_ctx.streaming_diff_lines_count_history[input.tool_use_id] = streaming_diff_lines_count
session_ctx.streaming_diff_lines_count_history[opts.tool_use_id] = streaming_diff_lines_count
end
local diff = fix_diff(input.diff)
@@ -211,16 +211,16 @@ function M.func(input, opts)
return false, "No diff blocks found"
end
session_ctx.prev_streaming_diff_timestamp_map[input.tool_use_id] = current_timestamp
session_ctx.prev_streaming_diff_timestamp_map[opts.tool_use_id] = current_timestamp
local bufnr, err = Helpers.get_bufnr(abs_path)
if err then return false, err end
session_ctx.undo_joined = session_ctx.undo_joined or {}
local undo_joined = session_ctx.undo_joined[input.tool_use_id]
local undo_joined = session_ctx.undo_joined[opts.tool_use_id]
if not undo_joined then
pcall(vim.cmd.undojoin)
session_ctx.undo_joined[input.tool_use_id] = true
session_ctx.undo_joined[opts.tool_use_id] = true
end
local original_lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
@@ -250,10 +250,10 @@ function M.func(input, opts)
session_ctx.rough_diff_blocks_to_diff_blocks_cache_map = session_ctx.rough_diff_blocks_to_diff_blocks_cache_map or {}
local rough_diff_blocks_to_diff_blocks_cache =
session_ctx.rough_diff_blocks_to_diff_blocks_cache_map[input.tool_use_id]
session_ctx.rough_diff_blocks_to_diff_blocks_cache_map[opts.tool_use_id]
if not rough_diff_blocks_to_diff_blocks_cache then
rough_diff_blocks_to_diff_blocks_cache = {}
session_ctx.rough_diff_blocks_to_diff_blocks_cache_map[input.tool_use_id] = rough_diff_blocks_to_diff_blocks_cache
session_ctx.rough_diff_blocks_to_diff_blocks_cache_map[opts.tool_use_id] = rough_diff_blocks_to_diff_blocks_cache
end
local function rough_diff_blocks_to_diff_blocks(rough_diff_blocks_)
@@ -623,35 +623,35 @@ function M.func(input, opts)
end
session_ctx.extmark_id_map = session_ctx.extmark_id_map or {}
local extmark_id_map = session_ctx.extmark_id_map[input.tool_use_id]
local extmark_id_map = session_ctx.extmark_id_map[opts.tool_use_id]
if not extmark_id_map then
extmark_id_map = {}
session_ctx.extmark_id_map[input.tool_use_id] = extmark_id_map
session_ctx.extmark_id_map[opts.tool_use_id] = extmark_id_map
end
session_ctx.virt_lines_map = session_ctx.virt_lines_map or {}
local virt_lines_map = session_ctx.virt_lines_map[input.tool_use_id]
local virt_lines_map = session_ctx.virt_lines_map[opts.tool_use_id]
if not virt_lines_map then
virt_lines_map = {}
session_ctx.virt_lines_map[input.tool_use_id] = virt_lines_map
session_ctx.virt_lines_map[opts.tool_use_id] = virt_lines_map
end
session_ctx.last_orig_diff_end_line_map = session_ctx.last_orig_diff_end_line_map or {}
local last_orig_diff_end_line = session_ctx.last_orig_diff_end_line_map[input.tool_use_id]
local last_orig_diff_end_line = session_ctx.last_orig_diff_end_line_map[opts.tool_use_id]
if not last_orig_diff_end_line then
last_orig_diff_end_line = 1
session_ctx.last_orig_diff_end_line_map[input.tool_use_id] = last_orig_diff_end_line
session_ctx.last_orig_diff_end_line_map[opts.tool_use_id] = last_orig_diff_end_line
end
session_ctx.last_resp_diff_end_line_map = session_ctx.last_resp_diff_end_line_map or {}
local last_resp_diff_end_line = session_ctx.last_resp_diff_end_line_map[input.tool_use_id]
local last_resp_diff_end_line = session_ctx.last_resp_diff_end_line_map[opts.tool_use_id]
if not last_resp_diff_end_line then
last_resp_diff_end_line = 1
session_ctx.last_resp_diff_end_line_map[input.tool_use_id] = last_resp_diff_end_line
session_ctx.last_resp_diff_end_line_map[opts.tool_use_id] = last_resp_diff_end_line
end
session_ctx.prev_diff_blocks_map = session_ctx.prev_diff_blocks_map or {}
local prev_diff_blocks = session_ctx.prev_diff_blocks_map[input.tool_use_id]
local prev_diff_blocks = session_ctx.prev_diff_blocks_map[opts.tool_use_id]
if not prev_diff_blocks then
prev_diff_blocks = {}
session_ctx.prev_diff_blocks_map[input.tool_use_id] = prev_diff_blocks
session_ctx.prev_diff_blocks_map[opts.tool_use_id] = prev_diff_blocks
end
local function get_unstable_diff_blocks(diff_blocks_)
@@ -671,7 +671,7 @@ function M.func(input, opts)
local function highlight_streaming_diff_blocks()
local unstable_diff_blocks = get_unstable_diff_blocks(diff_blocks)
session_ctx.prev_diff_blocks_map[input.tool_use_id] = diff_blocks
session_ctx.prev_diff_blocks_map[opts.tool_use_id] = diff_blocks
local max_col = vim.o.columns
for _, diff_block in ipairs(unstable_diff_blocks) do
local new_lines = diff_block.new_lines

View File

@@ -57,12 +57,10 @@ M.returns = {
function M.func(input, opts)
local replace_in_file = require("avante.llm_tools.replace_in_file")
local diff = "------- SEARCH\n" .. input.old_str .. "\n=======\n" .. input.new_str
if not input.streaming then diff = diff .. "\n+++++++ REPLACE" end
if not opts.streaming then diff = diff .. "\n+++++++ REPLACE" end
local new_input = {
path = input.path,
diff = diff,
streaming = input.streaming,
tool_use_id = input.tool_use_id,
}
return replace_in_file.func(new_input, opts)
end

View File

@@ -79,8 +79,6 @@ function M.func(input, opts)
path = input.path,
old_str = old_content,
new_str = input.content,
streaming = input.streaming,
tool_use_id = input.tool_use_id,
}
return str_replace.func(new_input, opts)
end

View File

@@ -413,6 +413,10 @@ function M:parse_response(ctx, data_stream, _, opts)
if usage then opts.update_tokens_usage(usage) end
end
end
if jsn.error and jsn.error ~= vim.NIL then
opts.on_stop({ reason = "error", error = vim.inspect(jsn.error) })
return
end
---@cast jsn AvanteOpenAIChatResponse
if not jsn.choices then return end
local choice = jsn.choices[1]

View File

@@ -419,6 +419,8 @@ vim.g.avante_login = vim.g.avante_login
---@field on_complete? fun(result: boolean | string | nil, error: string | nil): nil
---@field on_log? fun(log: string): nil
---@field set_store? fun(key: string, value: any): nil
---@field tool_use_id? string
---@field streaming? boolean
---
---@alias AvanteLLMToolFunc<T> fun(
--- input: T,