refactor: llm tool parameters (#2449)

This commit is contained in:
yetone
2025-07-15 16:40:25 +08:00
committed by GitHub
parent 0c6a8f5688
commit b8bb0fd969
25 changed files with 627 additions and 381 deletions

View File

@@ -125,39 +125,44 @@ end
--- IMPORTANT: Using "the_diff" instead of "diff" is to avoid LLM streaming generating function parameters in alphabetical order, which would result in generating "path" after "diff", making it impossible to achieve a streaming diff view.
---@type AvanteLLMToolFunc<{ path: string, diff: string, the_diff?: string, streaming?: boolean, tool_use_id?: string }>
function M.func(opts, on_log, on_complete, session_ctx)
if opts.the_diff ~= nil then
opts.diff = opts.the_diff
opts.the_diff = nil
function M.func(input, opts)
local on_log = opts.on_log
local on_complete = opts.on_complete
local session_ctx = opts.session_ctx
if not on_complete then return false, "on_complete not provided" end
if input.the_diff ~= nil then
input.diff = input.the_diff
input.the_diff = nil
end
if not opts.path or not opts.diff then return false, "path and diff are required " .. vim.inspect(opts) end
if on_log then on_log("path: " .. opts.path) end
local abs_path = Helpers.get_abs_path(opts.path)
if not input.path or not input.diff then return false, "path and diff are required " .. vim.inspect(input) end
if on_log then on_log("path: " .. input.path) end
local abs_path = Helpers.get_abs_path(input.path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
local is_streaming = opts.streaming or false
local is_streaming = input.streaming or false
session_ctx.prev_streaming_diff_timestamp_map = session_ctx.prev_streaming_diff_timestamp_map or {}
local current_timestamp = os.time()
if is_streaming then
local prev_streaming_diff_timestamp = session_ctx.prev_streaming_diff_timestamp_map[opts.tool_use_id]
local prev_streaming_diff_timestamp = session_ctx.prev_streaming_diff_timestamp_map[input.tool_use_id]
if prev_streaming_diff_timestamp ~= nil then
if current_timestamp - prev_streaming_diff_timestamp < 2 then
return false, "Diff hasn't changed in the last 2 seconds"
end
end
local streaming_diff_lines_count = Utils.count_lines(opts.diff)
local streaming_diff_lines_count = Utils.count_lines(input.diff)
session_ctx.streaming_diff_lines_count_history = session_ctx.streaming_diff_lines_count_history or {}
local prev_streaming_diff_lines_count = session_ctx.streaming_diff_lines_count_history[opts.tool_use_id]
local prev_streaming_diff_lines_count = session_ctx.streaming_diff_lines_count_history[input.tool_use_id]
if streaming_diff_lines_count == prev_streaming_diff_lines_count then
return false, "Diff lines count hasn't changed"
end
session_ctx.streaming_diff_lines_count_history[opts.tool_use_id] = streaming_diff_lines_count
session_ctx.streaming_diff_lines_count_history[input.tool_use_id] = streaming_diff_lines_count
end
local diff = fix_diff(opts.diff)
local diff = fix_diff(input.diff)
if on_log and diff ~= opts.diff then on_log("diff fixed") end
if on_log and diff ~= input.diff then on_log("diff fixed") end
local diff_lines = vim.split(diff, "\n")
@@ -203,16 +208,16 @@ function M.func(opts, on_log, on_complete, session_ctx)
return false, "No diff blocks found"
end
session_ctx.prev_streaming_diff_timestamp_map[opts.tool_use_id] = current_timestamp
session_ctx.prev_streaming_diff_timestamp_map[input.tool_use_id] = current_timestamp
local bufnr, err = Helpers.get_bufnr(abs_path)
if err then return false, err end
session_ctx.undo_joined = session_ctx.undo_joined or {}
local undo_joined = session_ctx.undo_joined[opts.tool_use_id]
local undo_joined = session_ctx.undo_joined[input.tool_use_id]
if not undo_joined then
pcall(vim.cmd.undojoin)
session_ctx.undo_joined[opts.tool_use_id] = true
session_ctx.undo_joined[input.tool_use_id] = true
end
local original_lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
@@ -242,10 +247,10 @@ function M.func(opts, on_log, on_complete, session_ctx)
session_ctx.rough_diff_blocks_to_diff_blocks_cache_map = session_ctx.rough_diff_blocks_to_diff_blocks_cache_map or {}
local rough_diff_blocks_to_diff_blocks_cache =
session_ctx.rough_diff_blocks_to_diff_blocks_cache_map[opts.tool_use_id]
session_ctx.rough_diff_blocks_to_diff_blocks_cache_map[input.tool_use_id]
if not rough_diff_blocks_to_diff_blocks_cache then
rough_diff_blocks_to_diff_blocks_cache = {}
session_ctx.rough_diff_blocks_to_diff_blocks_cache_map[opts.tool_use_id] = rough_diff_blocks_to_diff_blocks_cache
session_ctx.rough_diff_blocks_to_diff_blocks_cache_map[input.tool_use_id] = rough_diff_blocks_to_diff_blocks_cache
end
local function rough_diff_blocks_to_diff_blocks(rough_diff_blocks_)
@@ -472,7 +477,7 @@ function M.func(opts, on_log, on_complete, session_ctx)
on_complete(false, "User canceled")
return
end
if session_ctx then Helpers.mark_as_not_viewed(opts.path, session_ctx) end
if session_ctx then Helpers.mark_as_not_viewed(input.path, session_ctx) end
on_complete(true, nil)
end,
})
@@ -615,35 +620,35 @@ function M.func(opts, on_log, on_complete, session_ctx)
end
session_ctx.extmark_id_map = session_ctx.extmark_id_map or {}
local extmark_id_map = session_ctx.extmark_id_map[opts.tool_use_id]
local extmark_id_map = session_ctx.extmark_id_map[input.tool_use_id]
if not extmark_id_map then
extmark_id_map = {}
session_ctx.extmark_id_map[opts.tool_use_id] = extmark_id_map
session_ctx.extmark_id_map[input.tool_use_id] = extmark_id_map
end
session_ctx.virt_lines_map = session_ctx.virt_lines_map or {}
local virt_lines_map = session_ctx.virt_lines_map[opts.tool_use_id]
local virt_lines_map = session_ctx.virt_lines_map[input.tool_use_id]
if not virt_lines_map then
virt_lines_map = {}
session_ctx.virt_lines_map[opts.tool_use_id] = virt_lines_map
session_ctx.virt_lines_map[input.tool_use_id] = virt_lines_map
end
session_ctx.last_orig_diff_end_line_map = session_ctx.last_orig_diff_end_line_map or {}
local last_orig_diff_end_line = session_ctx.last_orig_diff_end_line_map[opts.tool_use_id]
local last_orig_diff_end_line = session_ctx.last_orig_diff_end_line_map[input.tool_use_id]
if not last_orig_diff_end_line then
last_orig_diff_end_line = 1
session_ctx.last_orig_diff_end_line_map[opts.tool_use_id] = last_orig_diff_end_line
session_ctx.last_orig_diff_end_line_map[input.tool_use_id] = last_orig_diff_end_line
end
session_ctx.last_resp_diff_end_line_map = session_ctx.last_resp_diff_end_line_map or {}
local last_resp_diff_end_line = session_ctx.last_resp_diff_end_line_map[opts.tool_use_id]
local last_resp_diff_end_line = session_ctx.last_resp_diff_end_line_map[input.tool_use_id]
if not last_resp_diff_end_line then
last_resp_diff_end_line = 1
session_ctx.last_resp_diff_end_line_map[opts.tool_use_id] = last_resp_diff_end_line
session_ctx.last_resp_diff_end_line_map[input.tool_use_id] = last_resp_diff_end_line
end
session_ctx.prev_diff_blocks_map = session_ctx.prev_diff_blocks_map or {}
local prev_diff_blocks = session_ctx.prev_diff_blocks_map[opts.tool_use_id]
local prev_diff_blocks = session_ctx.prev_diff_blocks_map[input.tool_use_id]
if not prev_diff_blocks then
prev_diff_blocks = {}
session_ctx.prev_diff_blocks_map[opts.tool_use_id] = prev_diff_blocks
session_ctx.prev_diff_blocks_map[input.tool_use_id] = prev_diff_blocks
end
local function get_unstable_diff_blocks(diff_blocks_)
@@ -663,7 +668,7 @@ function M.func(opts, on_log, on_complete, session_ctx)
local function highlight_streaming_diff_blocks()
local unstable_diff_blocks = get_unstable_diff_blocks(diff_blocks)
session_ctx.prev_diff_blocks_map[opts.tool_use_id] = diff_blocks
session_ctx.prev_diff_blocks_map[input.tool_use_id] = diff_blocks
local max_col = vim.o.columns
for _, diff_block in ipairs(unstable_diff_blocks) do
local new_lines = diff_block.new_lines
@@ -747,7 +752,7 @@ function M.func(opts, on_log, on_complete, session_ctx)
--- check if the parent dir is exists, if not, create it
if vim.fn.isdirectory(parent_dir) == 0 then vim.fn.mkdir(parent_dir, "p") end
vim.api.nvim_buf_call(bufnr, function() vim.cmd("noautocmd write") end)
if session_ctx then Helpers.mark_as_not_viewed(opts.path, session_ctx) end
if session_ctx then Helpers.mark_as_not_viewed(input.path, session_ctx) end
on_complete(true, nil)
end, { focus = not Config.behaviour.auto_focus_on_diff_view }, session_ctx, M.name)
end