feat: streaming diff (#2107)

This commit is contained in:
yetone
2025-06-02 16:44:33 +08:00
committed by GitHub
parent bc403ddcbf
commit 746f071b37
12 changed files with 1449 additions and 130 deletions

View File

@@ -25,38 +25,14 @@ end
---@type AvanteLLMToolFunc<{ command: "view" | "str_replace" | "create" | "insert" | "undo_edit", path: string, old_str?: string, new_str?: string, file_text?: string, insert_line?: integer, new_str?: string, view_range?: integer[] }>
function M.str_replace_editor(opts, on_log, on_complete, session_ctx)
if on_log then on_log("command: " .. opts.command) end
if not on_complete then return false, "on_complete not provided" end
local abs_path = Helpers.get_abs_path(opts.path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if opts.command == "view" then
local view = require("avante.llm_tools.view")
local opts_ = { path = opts.path }
if opts.view_range then
local start_line, end_line = unpack(opts.view_range)
opts_.view_range = {
start_line = start_line,
end_line = end_line,
}
end
return view(opts_, on_log, on_complete, session_ctx)
end
if opts.command == "str_replace" then
return require("avante.llm_tools.str_replace").func(opts, on_log, on_complete, session_ctx)
end
if opts.command == "create" then
return require("avante.llm_tools.create").func(opts, on_log, on_complete, session_ctx)
end
if opts.command == "insert" then
return require("avante.llm_tools.insert").func(opts, on_log, on_complete, session_ctx)
end
if opts.command == "undo_edit" then
return require("avante.llm_tools.undo_edit").func(opts, on_log, on_complete, session_ctx)
end
return false, "Unknown command: " .. opts.command
---@cast opts any
return M.str_replace_based_edit_tool(opts, on_log, on_complete, session_ctx)
end
---@type AvanteLLMToolFunc<{ command: "view" | "str_replace" | "create" | "insert", path: string, old_str?: string, new_str?: string, file_text?: string, insert_line?: integer, new_str?: string, view_range?: integer[] }>
---@type AvanteLLMToolFunc<{ command: "view" | "str_replace" | "create" | "insert", path: string, old_str?: string, new_str?: string, file_text?: string, insert_line?: integer, new_str?: string, view_range?: integer[], streaming?: boolean }>
function M.str_replace_based_edit_tool(opts, on_log, on_complete, session_ctx)
if on_log then on_log("command: " .. opts.command) end
if not on_complete then return false, "on_complete not provided" end
@@ -67,10 +43,8 @@ function M.str_replace_based_edit_tool(opts, on_log, on_complete, session_ctx)
local opts_ = { path = opts.path }
if opts.view_range then
local start_line, end_line = unpack(opts.view_range)
opts_.view_range = {
start_line = start_line,
end_line = end_line,
}
opts_.start_line = start_line
opts_.end_line = end_line
end
return view(opts_, on_log, on_complete, session_ctx)
end
@@ -1161,6 +1135,10 @@ M._tools = {
default = false,
},
},
usage = {
symbol_name = "The name of the symbol to retrieve the definition for, example: fibonacci",
show_line_numbers = "true or false",
},
},
returns = {
{

View File

@@ -28,7 +28,8 @@ M.param = {
type = "string",
},
{
name = "diff",
--- IMPORTANT: Using "the_diff" instead of "diff" is to avoid LLM streaming generating function parameters in alphabetical order, which would result in generating "path" after "diff", making it impossible to achieve a streaming diff view.
name = "the_diff",
description = [[
One or more SEARCH/REPLACE blocks following this exact format:
\`\`\`
@@ -61,7 +62,7 @@ One or more SEARCH/REPLACE blocks following this exact format:
},
usage = {
path = "File path here",
diff = "Search and replace blocks here",
the_diff = "Search and replace blocks here",
},
}
@@ -101,13 +102,17 @@ local function fix_diff(diff)
return table.concat(fixed_diff_lines, "\n")
end
---@type AvanteLLMToolFunc<{ path: string, diff: string }>
--- IMPORTANT: Using "the_diff" instead of "diff" is to avoid LLM streaming generating function parameters in alphabetical order, which would result in generating "path" after "diff", making it impossible to achieve a streaming diff view.
---@type AvanteLLMToolFunc<{ path: string, diff: string, the_diff?: string, streaming?: boolean, tool_use_id?: string }>
function M.func(opts, on_log, on_complete, session_ctx)
if not opts.path or not opts.diff then return false, "path and diff are required" end
if opts.the_diff ~= nil then opts.diff = opts.the_diff end
if not opts.path or not opts.diff then return false, "path and diff are required " .. vim.inspect(opts) end
if on_log then on_log("path: " .. opts.path) end
local abs_path = Helpers.get_abs_path(opts.path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
local is_streaming = opts.streaming or false
local diff = fix_diff(opts.diff)
if on_log and diff ~= opts.diff then on_log("diff fixed") end
@@ -141,14 +146,31 @@ function M.func(opts, on_log, on_complete, session_ctx)
end
end
-- Handle streaming mode: if we're still in replace mode at the end, include the partial block
if is_streaming and is_replacing and #current_search > 0 then
if #current_search > #current_replace then current_search = vim.list_slice(current_search, 1, #current_replace) end
table.insert(
rough_diff_blocks,
{ search = table.concat(current_search, "\n"), replace = table.concat(current_replace, "\n") }
)
end
if #rough_diff_blocks == 0 then
Utils.debug("opts.diff", opts.diff)
Utils.debug("diff", diff)
-- Utils.debug("opts.diff", opts.diff)
-- Utils.debug("diff", diff)
return false, "No diff blocks found"
end
local bufnr, err = Helpers.get_bufnr(abs_path)
if err then return false, err end
session_ctx.undo_joined = session_ctx.undo_joined or {}
local undo_joined = session_ctx.undo_joined[opts.tool_use_id]
if not undo_joined then
pcall(vim.cmd.undojoin)
session_ctx.undo_joined[opts.tool_use_id] = true
end
local original_lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
local sidebar = require("avante").get()
if not sidebar then return false, "Avante sidebar not found" end
@@ -519,18 +541,87 @@ function M.func(opts, on_log, on_complete, session_ctx)
end
end
insert_diff_blocks_new_lines()
highlight_diff_blocks()
register_cursor_move_events()
register_keybinding_events()
register_buf_write_events()
session_ctx.extmark_id_map = session_ctx.extmark_id_map or {}
local extmark_id_map = session_ctx.extmark_id_map[opts.tool_use_id]
if not extmark_id_map then
extmark_id_map = {}
session_ctx.extmark_id_map[opts.tool_use_id] = extmark_id_map
end
session_ctx.virt_lines_map = session_ctx.virt_lines_map or {}
local virt_lines_map = session_ctx.virt_lines_map[opts.tool_use_id]
if not virt_lines_map then
virt_lines_map = {}
session_ctx.virt_lines_map[opts.tool_use_id] = virt_lines_map
end
local function highlight_streaming_diff_blocks()
vim.api.nvim_buf_clear_namespace(bufnr, NAMESPACE, 0, -1)
local max_col = vim.o.columns
for _, diff_block in ipairs(diff_blocks) do
local start_line = diff_block.start_line
if #diff_block.old_lines > 0 then
vim.api.nvim_buf_set_extmark(bufnr, NAMESPACE, start_line - 1, 0, {
hl_group = Highlights.TO_BE_DELETED_WITHOUT_STRIKETHROUGH,
hl_eol = true,
hl_mode = "combine",
end_row = start_line + #diff_block.old_lines - 1,
})
end
if #diff_block.new_lines == 0 then goto continue end
local virt_lines = vim
.iter(diff_block.new_lines)
:map(function(line)
--- append spaces to the end of the line
local line_ = line .. string.rep(" ", max_col - #line)
return { { line_, Highlights.INCOMING } }
end)
:totable()
local extmark_line
if #diff_block.old_lines > 0 then
extmark_line = math.max(0, start_line - 2 + #diff_block.old_lines)
else
extmark_line = math.max(0, start_line - 1 + #diff_block.old_lines)
end
vim.api.nvim_buf_set_extmark(bufnr, NAMESPACE, extmark_line, 0, {
virt_lines = virt_lines,
hl_eol = true,
hl_mode = "combine",
})
::continue::
end
end
if not is_streaming then
insert_diff_blocks_new_lines()
highlight_diff_blocks()
register_cursor_move_events()
register_keybinding_events()
register_buf_write_events()
else
highlight_streaming_diff_blocks()
end
if diff_blocks[1] then
local winnr = Utils.get_winid(bufnr)
vim.api.nvim_win_set_cursor(winnr, { diff_blocks[1].new_start_line, 0 })
vim.api.nvim_win_call(winnr, function() vim.cmd("normal! zz") end)
if is_streaming then
-- In streaming mode, focus on the last diff block
local last_diff_block = diff_blocks[#diff_blocks]
vim.api.nvim_win_set_cursor(winnr, { last_diff_block.start_line, 0 })
vim.api.nvim_win_call(winnr, function() vim.cmd("normal! zz") end)
else
-- In normal mode, focus on the first diff block
vim.api.nvim_win_set_cursor(winnr, { diff_blocks[1].new_start_line, 0 })
vim.api.nvim_win_call(winnr, function() vim.cmd("normal! zz") end)
end
end
if is_streaming then
-- In streaming mode, don't show confirmation dialog, just apply changes
return
end
pcall(vim.cmd.undojoin)
confirm = Helpers.confirm("Are you sure you want to apply this modification?", function(ok, reason)
clear()
if not ok then

View File

@@ -54,13 +54,16 @@ M.returns = {
},
}
---@type AvanteLLMToolFunc<{ path: string, old_str: string, new_str: string }>
---@type AvanteLLMToolFunc<{ path: string, old_str: string, new_str: string, streaming?: boolean, tool_use_id?: string }>
function M.func(opts, on_log, on_complete, session_ctx)
local replace_in_file = require("avante.llm_tools.replace_in_file")
local diff = "<<<<<<< SEARCH\n" .. opts.old_str .. "\n=======\n" .. opts.new_str .. "\n>>>>>>> REPLACE"
local diff = "<<<<<<< SEARCH\n" .. opts.old_str .. "\n=======\n" .. opts.new_str
if not opts.streaming then diff = diff .. "\n>>>>>>> REPLACE" end
local new_opts = {
path = opts.path,
diff = diff,
streaming = opts.streaming,
tool_use_id = opts.tool_use_id,
}
return replace_in_file.func(new_opts, on_log, on_complete, session_ctx)
end

View File

@@ -28,14 +28,15 @@ M.param = {
type = "string",
},
{
name = "content",
--- IMPORTANT: Using "the_content" instead of "content" is to avoid LLM streaming generating function parameters in alphabetical order, which would result in generating "path" after "content", making it impossible to achieve a stream diff view.
name = "the_content",
description = "The content to write to the file. ALWAYS provide the COMPLETE intended content of the file, without any truncation or omissions. You MUST include ALL parts of the file, even if they haven't been modified.",
type = "string",
},
},
usage = {
path = "File path here",
content = "File content here",
the_content = "File content here",
},
}
@@ -54,21 +55,25 @@ M.returns = {
},
}
---@type AvanteLLMToolFunc<{ path: string, content: string }>
--- IMPORTANT: Using "the_content" instead of "content" is to avoid LLM streaming generating function parameters in alphabetical order, which would result in generating "path" after "content", making it impossible to achieve a stream diff view.
---@type AvanteLLMToolFunc<{ path: string, content: string, the_content?: string, streaming?: boolean, tool_use_id?: string }>
function M.func(opts, on_log, on_complete, session_ctx)
if opts.the_content ~= nil then opts.content = opts.the_content end
if not on_complete then return false, "on_complete not provided" end
local abs_path = Helpers.get_abs_path(opts.path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if opts.content == nil then return false, "content not provided" end
local old_lines = Utils.read_file_from_buf_or_disk(abs_path)
local old_content = table.concat(old_lines or {}, "\n")
local replace_in_file = require("avante.llm_tools.replace_in_file")
local diff = "<<<<<<< SEARCH\n" .. old_content .. "\n=======\n" .. opts.content .. "\n>>>>>>> REPLACE"
local str_replace = require("avante.llm_tools.str_replace")
local new_opts = {
path = opts.path,
diff = diff,
old_str = old_content,
new_str = opts.content,
streaming = opts.streaming,
tool_use_id = opts.tool_use_id,
}
return replace_in_file.func(new_opts, on_log, on_complete, session_ctx)
return str_replace.func(new_opts, on_log, on_complete, session_ctx)
end
return M