feat: streaming diff (#2107)
This commit is contained in:
@@ -25,38 +25,14 @@ end
|
||||
|
||||
---@type AvanteLLMToolFunc<{ command: "view" | "str_replace" | "create" | "insert" | "undo_edit", path: string, old_str?: string, new_str?: string, file_text?: string, insert_line?: integer, new_str?: string, view_range?: integer[] }>
|
||||
function M.str_replace_editor(opts, on_log, on_complete, session_ctx)
|
||||
if on_log then on_log("command: " .. opts.command) end
|
||||
if not on_complete then return false, "on_complete not provided" end
|
||||
local abs_path = Helpers.get_abs_path(opts.path)
|
||||
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||
if opts.command == "view" then
|
||||
local view = require("avante.llm_tools.view")
|
||||
local opts_ = { path = opts.path }
|
||||
if opts.view_range then
|
||||
local start_line, end_line = unpack(opts.view_range)
|
||||
opts_.view_range = {
|
||||
start_line = start_line,
|
||||
end_line = end_line,
|
||||
}
|
||||
end
|
||||
return view(opts_, on_log, on_complete, session_ctx)
|
||||
end
|
||||
if opts.command == "str_replace" then
|
||||
return require("avante.llm_tools.str_replace").func(opts, on_log, on_complete, session_ctx)
|
||||
end
|
||||
if opts.command == "create" then
|
||||
return require("avante.llm_tools.create").func(opts, on_log, on_complete, session_ctx)
|
||||
end
|
||||
if opts.command == "insert" then
|
||||
return require("avante.llm_tools.insert").func(opts, on_log, on_complete, session_ctx)
|
||||
end
|
||||
if opts.command == "undo_edit" then
|
||||
return require("avante.llm_tools.undo_edit").func(opts, on_log, on_complete, session_ctx)
|
||||
end
|
||||
return false, "Unknown command: " .. opts.command
|
||||
---@cast opts any
|
||||
return M.str_replace_based_edit_tool(opts, on_log, on_complete, session_ctx)
|
||||
end
|
||||
|
||||
---@type AvanteLLMToolFunc<{ command: "view" | "str_replace" | "create" | "insert", path: string, old_str?: string, new_str?: string, file_text?: string, insert_line?: integer, new_str?: string, view_range?: integer[] }>
|
||||
---@type AvanteLLMToolFunc<{ command: "view" | "str_replace" | "create" | "insert", path: string, old_str?: string, new_str?: string, file_text?: string, insert_line?: integer, new_str?: string, view_range?: integer[], streaming?: boolean }>
|
||||
function M.str_replace_based_edit_tool(opts, on_log, on_complete, session_ctx)
|
||||
if on_log then on_log("command: " .. opts.command) end
|
||||
if not on_complete then return false, "on_complete not provided" end
|
||||
@@ -67,10 +43,8 @@ function M.str_replace_based_edit_tool(opts, on_log, on_complete, session_ctx)
|
||||
local opts_ = { path = opts.path }
|
||||
if opts.view_range then
|
||||
local start_line, end_line = unpack(opts.view_range)
|
||||
opts_.view_range = {
|
||||
start_line = start_line,
|
||||
end_line = end_line,
|
||||
}
|
||||
opts_.start_line = start_line
|
||||
opts_.end_line = end_line
|
||||
end
|
||||
return view(opts_, on_log, on_complete, session_ctx)
|
||||
end
|
||||
@@ -1161,6 +1135,10 @@ M._tools = {
|
||||
default = false,
|
||||
},
|
||||
},
|
||||
usage = {
|
||||
symbol_name = "The name of the symbol to retrieve the definition for, example: fibonacci",
|
||||
show_line_numbers = "true or false",
|
||||
},
|
||||
},
|
||||
returns = {
|
||||
{
|
||||
|
||||
@@ -28,7 +28,8 @@ M.param = {
|
||||
type = "string",
|
||||
},
|
||||
{
|
||||
name = "diff",
|
||||
--- IMPORTANT: Using "the_diff" instead of "diff" is to avoid LLM streaming generating function parameters in alphabetical order, which would result in generating "path" after "diff", making it impossible to achieve a streaming diff view.
|
||||
name = "the_diff",
|
||||
description = [[
|
||||
One or more SEARCH/REPLACE blocks following this exact format:
|
||||
\`\`\`
|
||||
@@ -61,7 +62,7 @@ One or more SEARCH/REPLACE blocks following this exact format:
|
||||
},
|
||||
usage = {
|
||||
path = "File path here",
|
||||
diff = "Search and replace blocks here",
|
||||
the_diff = "Search and replace blocks here",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -101,13 +102,17 @@ local function fix_diff(diff)
|
||||
return table.concat(fixed_diff_lines, "\n")
|
||||
end
|
||||
|
||||
---@type AvanteLLMToolFunc<{ path: string, diff: string }>
|
||||
--- IMPORTANT: Using "the_diff" instead of "diff" is to avoid LLM streaming generating function parameters in alphabetical order, which would result in generating "path" after "diff", making it impossible to achieve a streaming diff view.
|
||||
---@type AvanteLLMToolFunc<{ path: string, diff: string, the_diff?: string, streaming?: boolean, tool_use_id?: string }>
|
||||
function M.func(opts, on_log, on_complete, session_ctx)
|
||||
if not opts.path or not opts.diff then return false, "path and diff are required" end
|
||||
if opts.the_diff ~= nil then opts.diff = opts.the_diff end
|
||||
if not opts.path or not opts.diff then return false, "path and diff are required " .. vim.inspect(opts) end
|
||||
if on_log then on_log("path: " .. opts.path) end
|
||||
local abs_path = Helpers.get_abs_path(opts.path)
|
||||
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||
|
||||
local is_streaming = opts.streaming or false
|
||||
|
||||
local diff = fix_diff(opts.diff)
|
||||
|
||||
if on_log and diff ~= opts.diff then on_log("diff fixed") end
|
||||
@@ -141,14 +146,31 @@ function M.func(opts, on_log, on_complete, session_ctx)
|
||||
end
|
||||
end
|
||||
|
||||
-- Handle streaming mode: if we're still in replace mode at the end, include the partial block
|
||||
if is_streaming and is_replacing and #current_search > 0 then
|
||||
if #current_search > #current_replace then current_search = vim.list_slice(current_search, 1, #current_replace) end
|
||||
table.insert(
|
||||
rough_diff_blocks,
|
||||
{ search = table.concat(current_search, "\n"), replace = table.concat(current_replace, "\n") }
|
||||
)
|
||||
end
|
||||
|
||||
if #rough_diff_blocks == 0 then
|
||||
Utils.debug("opts.diff", opts.diff)
|
||||
Utils.debug("diff", diff)
|
||||
-- Utils.debug("opts.diff", opts.diff)
|
||||
-- Utils.debug("diff", diff)
|
||||
return false, "No diff blocks found"
|
||||
end
|
||||
|
||||
local bufnr, err = Helpers.get_bufnr(abs_path)
|
||||
if err then return false, err end
|
||||
|
||||
session_ctx.undo_joined = session_ctx.undo_joined or {}
|
||||
local undo_joined = session_ctx.undo_joined[opts.tool_use_id]
|
||||
if not undo_joined then
|
||||
pcall(vim.cmd.undojoin)
|
||||
session_ctx.undo_joined[opts.tool_use_id] = true
|
||||
end
|
||||
|
||||
local original_lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
|
||||
local sidebar = require("avante").get()
|
||||
if not sidebar then return false, "Avante sidebar not found" end
|
||||
@@ -519,18 +541,87 @@ function M.func(opts, on_log, on_complete, session_ctx)
|
||||
end
|
||||
end
|
||||
|
||||
insert_diff_blocks_new_lines()
|
||||
highlight_diff_blocks()
|
||||
register_cursor_move_events()
|
||||
register_keybinding_events()
|
||||
register_buf_write_events()
|
||||
session_ctx.extmark_id_map = session_ctx.extmark_id_map or {}
|
||||
local extmark_id_map = session_ctx.extmark_id_map[opts.tool_use_id]
|
||||
if not extmark_id_map then
|
||||
extmark_id_map = {}
|
||||
session_ctx.extmark_id_map[opts.tool_use_id] = extmark_id_map
|
||||
end
|
||||
session_ctx.virt_lines_map = session_ctx.virt_lines_map or {}
|
||||
local virt_lines_map = session_ctx.virt_lines_map[opts.tool_use_id]
|
||||
if not virt_lines_map then
|
||||
virt_lines_map = {}
|
||||
session_ctx.virt_lines_map[opts.tool_use_id] = virt_lines_map
|
||||
end
|
||||
|
||||
local function highlight_streaming_diff_blocks()
|
||||
vim.api.nvim_buf_clear_namespace(bufnr, NAMESPACE, 0, -1)
|
||||
local max_col = vim.o.columns
|
||||
for _, diff_block in ipairs(diff_blocks) do
|
||||
local start_line = diff_block.start_line
|
||||
if #diff_block.old_lines > 0 then
|
||||
vim.api.nvim_buf_set_extmark(bufnr, NAMESPACE, start_line - 1, 0, {
|
||||
hl_group = Highlights.TO_BE_DELETED_WITHOUT_STRIKETHROUGH,
|
||||
hl_eol = true,
|
||||
hl_mode = "combine",
|
||||
end_row = start_line + #diff_block.old_lines - 1,
|
||||
})
|
||||
end
|
||||
if #diff_block.new_lines == 0 then goto continue end
|
||||
local virt_lines = vim
|
||||
.iter(diff_block.new_lines)
|
||||
:map(function(line)
|
||||
--- append spaces to the end of the line
|
||||
local line_ = line .. string.rep(" ", max_col - #line)
|
||||
return { { line_, Highlights.INCOMING } }
|
||||
end)
|
||||
:totable()
|
||||
local extmark_line
|
||||
if #diff_block.old_lines > 0 then
|
||||
extmark_line = math.max(0, start_line - 2 + #diff_block.old_lines)
|
||||
else
|
||||
extmark_line = math.max(0, start_line - 1 + #diff_block.old_lines)
|
||||
end
|
||||
vim.api.nvim_buf_set_extmark(bufnr, NAMESPACE, extmark_line, 0, {
|
||||
virt_lines = virt_lines,
|
||||
hl_eol = true,
|
||||
hl_mode = "combine",
|
||||
})
|
||||
::continue::
|
||||
end
|
||||
end
|
||||
|
||||
if not is_streaming then
|
||||
insert_diff_blocks_new_lines()
|
||||
highlight_diff_blocks()
|
||||
register_cursor_move_events()
|
||||
register_keybinding_events()
|
||||
register_buf_write_events()
|
||||
else
|
||||
highlight_streaming_diff_blocks()
|
||||
end
|
||||
|
||||
if diff_blocks[1] then
|
||||
local winnr = Utils.get_winid(bufnr)
|
||||
vim.api.nvim_win_set_cursor(winnr, { diff_blocks[1].new_start_line, 0 })
|
||||
vim.api.nvim_win_call(winnr, function() vim.cmd("normal! zz") end)
|
||||
if is_streaming then
|
||||
-- In streaming mode, focus on the last diff block
|
||||
local last_diff_block = diff_blocks[#diff_blocks]
|
||||
vim.api.nvim_win_set_cursor(winnr, { last_diff_block.start_line, 0 })
|
||||
vim.api.nvim_win_call(winnr, function() vim.cmd("normal! zz") end)
|
||||
else
|
||||
-- In normal mode, focus on the first diff block
|
||||
vim.api.nvim_win_set_cursor(winnr, { diff_blocks[1].new_start_line, 0 })
|
||||
vim.api.nvim_win_call(winnr, function() vim.cmd("normal! zz") end)
|
||||
end
|
||||
end
|
||||
|
||||
if is_streaming then
|
||||
-- In streaming mode, don't show confirmation dialog, just apply changes
|
||||
return
|
||||
end
|
||||
|
||||
pcall(vim.cmd.undojoin)
|
||||
|
||||
confirm = Helpers.confirm("Are you sure you want to apply this modification?", function(ok, reason)
|
||||
clear()
|
||||
if not ok then
|
||||
|
||||
@@ -54,13 +54,16 @@ M.returns = {
|
||||
},
|
||||
}
|
||||
|
||||
---@type AvanteLLMToolFunc<{ path: string, old_str: string, new_str: string }>
|
||||
---@type AvanteLLMToolFunc<{ path: string, old_str: string, new_str: string, streaming?: boolean, tool_use_id?: string }>
|
||||
function M.func(opts, on_log, on_complete, session_ctx)
|
||||
local replace_in_file = require("avante.llm_tools.replace_in_file")
|
||||
local diff = "<<<<<<< SEARCH\n" .. opts.old_str .. "\n=======\n" .. opts.new_str .. "\n>>>>>>> REPLACE"
|
||||
local diff = "<<<<<<< SEARCH\n" .. opts.old_str .. "\n=======\n" .. opts.new_str
|
||||
if not opts.streaming then diff = diff .. "\n>>>>>>> REPLACE" end
|
||||
local new_opts = {
|
||||
path = opts.path,
|
||||
diff = diff,
|
||||
streaming = opts.streaming,
|
||||
tool_use_id = opts.tool_use_id,
|
||||
}
|
||||
return replace_in_file.func(new_opts, on_log, on_complete, session_ctx)
|
||||
end
|
||||
|
||||
@@ -28,14 +28,15 @@ M.param = {
|
||||
type = "string",
|
||||
},
|
||||
{
|
||||
name = "content",
|
||||
--- IMPORTANT: Using "the_content" instead of "content" is to avoid LLM streaming generating function parameters in alphabetical order, which would result in generating "path" after "content", making it impossible to achieve a stream diff view.
|
||||
name = "the_content",
|
||||
description = "The content to write to the file. ALWAYS provide the COMPLETE intended content of the file, without any truncation or omissions. You MUST include ALL parts of the file, even if they haven't been modified.",
|
||||
type = "string",
|
||||
},
|
||||
},
|
||||
usage = {
|
||||
path = "File path here",
|
||||
content = "File content here",
|
||||
the_content = "File content here",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -54,21 +55,25 @@ M.returns = {
|
||||
},
|
||||
}
|
||||
|
||||
---@type AvanteLLMToolFunc<{ path: string, content: string }>
|
||||
--- IMPORTANT: Using "the_content" instead of "content" is to avoid LLM streaming generating function parameters in alphabetical order, which would result in generating "path" after "content", making it impossible to achieve a stream diff view.
|
||||
---@type AvanteLLMToolFunc<{ path: string, content: string, the_content?: string, streaming?: boolean, tool_use_id?: string }>
|
||||
function M.func(opts, on_log, on_complete, session_ctx)
|
||||
if opts.the_content ~= nil then opts.content = opts.the_content end
|
||||
if not on_complete then return false, "on_complete not provided" end
|
||||
local abs_path = Helpers.get_abs_path(opts.path)
|
||||
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||
if opts.content == nil then return false, "content not provided" end
|
||||
local old_lines = Utils.read_file_from_buf_or_disk(abs_path)
|
||||
local old_content = table.concat(old_lines or {}, "\n")
|
||||
local replace_in_file = require("avante.llm_tools.replace_in_file")
|
||||
local diff = "<<<<<<< SEARCH\n" .. old_content .. "\n=======\n" .. opts.content .. "\n>>>>>>> REPLACE"
|
||||
local str_replace = require("avante.llm_tools.str_replace")
|
||||
local new_opts = {
|
||||
path = opts.path,
|
||||
diff = diff,
|
||||
old_str = old_content,
|
||||
new_str = opts.content,
|
||||
streaming = opts.streaming,
|
||||
tool_use_id = opts.tool_use_id,
|
||||
}
|
||||
return replace_in_file.func(new_opts, on_log, on_complete, session_ctx)
|
||||
return str_replace.func(new_opts, on_log, on_complete, session_ctx)
|
||||
end
|
||||
|
||||
return M
|
||||
|
||||
Reference in New Issue
Block a user