optimize: streaming diff performance (#2145)

This commit is contained in:
yetone
2025-06-04 15:45:12 +08:00
committed by GitHub
parent b54b806f71
commit 220594a66f
2 changed files with 110 additions and 40 deletions

View File

@@ -120,38 +120,37 @@ function M.func(opts, on_log, on_complete, session_ctx)
local diff_lines = vim.split(diff, "\n")
local is_searching = false
local is_replacing = false
local current_search = {}
local current_replace = {}
local current_old_lines = {}
local current_new_lines = {}
local rough_diff_blocks = {}
for _, line in ipairs(diff_lines) do
if line:match("^%s*-------* SEARCH") then
is_searching = true
is_replacing = false
current_search = {}
current_old_lines = {}
elseif line:match("^%s*=======*") and is_searching then
is_searching = false
is_replacing = true
current_replace = {}
current_new_lines = {}
elseif line:match("^%s*+++++++* REPLACE") and is_replacing then
is_replacing = false
table.insert(
rough_diff_blocks,
{ search = table.concat(current_search, "\n"), replace = table.concat(current_replace, "\n") }
)
table.insert(rough_diff_blocks, { old_lines = current_old_lines, new_lines = current_new_lines })
elseif is_searching then
table.insert(current_search, line)
table.insert(current_old_lines, line)
elseif is_replacing then
table.insert(current_replace, line)
table.insert(current_new_lines, line)
end
end
-- Handle streaming mode: if we're still in replace mode at the end, include the partial block
if is_streaming and is_replacing and #current_search > 0 then
if #current_search > #current_replace then current_search = vim.list_slice(current_search, 1, #current_replace) end
if is_streaming and is_replacing and #current_old_lines > 0 then
if #current_old_lines > #current_new_lines then
current_old_lines = vim.list_slice(current_old_lines, 1, #current_new_lines)
end
table.insert(
rough_diff_blocks,
{ search = table.concat(current_search, "\n"), replace = table.concat(current_replace, "\n") }
{ old_lines = current_old_lines, new_lines = current_new_lines, is_replacing = true }
)
end
@@ -175,14 +174,15 @@ function M.func(opts, on_log, on_complete, session_ctx)
local sidebar = require("avante").get()
if not sidebar then return false, "Avante sidebar not found" end
local function parse_rough_diff_block(rough_diff_block, current_lines)
local old_lines = vim.split(rough_diff_block.search, "\n")
local new_lines = vim.split(rough_diff_block.replace, "\n")
--- add line numbers to rough_diff_block
local function complete_rough_diff_block(rough_diff_block)
local old_lines = rough_diff_block.old_lines
local new_lines = rough_diff_block.new_lines
local start_line, end_line
for i = 1, #current_lines - #old_lines + 1 do
for i = 1, #original_lines - #old_lines + 1 do
local match = true
for j = 1, #old_lines do
if Utils.remove_indentation(current_lines[i + j - 1]) ~= Utils.remove_indentation(old_lines[j]) then
if Utils.remove_indentation(original_lines[i + j - 1]) ~= Utils.remove_indentation(old_lines[j]) then
match = false
break
end
@@ -194,36 +194,55 @@ function M.func(opts, on_log, on_complete, session_ctx)
end
end
if start_line == nil or end_line == nil then
return "Failed to find the old string:\n" .. rough_diff_block.search
local old_string = table.concat(old_lines, "\n")
return "Failed to find the old string:\n" .. old_string
end
local old_str = rough_diff_block.search
local new_str = rough_diff_block.replace
local original_indentation = Utils.get_indentation(current_lines[start_line])
local original_indentation = Utils.get_indentation(original_lines[start_line])
if original_indentation ~= Utils.get_indentation(old_lines[1]) then
old_lines = vim.tbl_map(function(line) return original_indentation .. line end, old_lines)
new_lines = vim.tbl_map(function(line) return original_indentation .. line end, new_lines)
old_str = table.concat(old_lines, "\n")
new_str = table.concat(new_lines, "\n")
end
rough_diff_block.old_lines = old_lines
rough_diff_block.new_lines = new_lines
rough_diff_block.search = old_str
rough_diff_block.replace = new_str
rough_diff_block.start_line = start_line
rough_diff_block.end_line = end_line
return nil
end
session_ctx.rough_diff_blocks_to_diff_blocks_cache_map = session_ctx.rough_diff_blocks_to_diff_blocks_cache_map or {}
local rough_diff_blocks_to_diff_blocks_cache =
session_ctx.rough_diff_blocks_to_diff_blocks_cache_map[opts.tool_use_id]
if not rough_diff_blocks_to_diff_blocks_cache then
rough_diff_blocks_to_diff_blocks_cache = {}
session_ctx.rough_diff_blocks_to_diff_blocks_cache_map[opts.tool_use_id] = rough_diff_blocks_to_diff_blocks_cache
end
local function rough_diff_blocks_to_diff_blocks(rough_diff_blocks_)
local res = {}
local base_line_ = 0
for _, rough_diff_block in ipairs(rough_diff_blocks_) do
for idx, rough_diff_block in ipairs(rough_diff_blocks_) do
local cache_key = string.format("%s:%s", idx, #rough_diff_block.new_lines)
local cached_diff_blocks = rough_diff_blocks_to_diff_blocks_cache[cache_key]
if cached_diff_blocks then
res = vim.list_extend(res, cached_diff_blocks.diff_blocks)
base_line_ = cached_diff_blocks.base_line
goto continue
end
local old_lines = rough_diff_block.old_lines
local new_lines = rough_diff_block.new_lines
if rough_diff_block.is_replacing then
new_lines = vim.list_slice(new_lines, 1, #new_lines - 1)
old_lines = vim.list_slice(old_lines, 1, #new_lines)
end
local old_string = table.concat(old_lines, "\n")
local new_string = table.concat(new_lines, "\n")
---@diagnostic disable-next-line: assign-type-mismatch, missing-fields
local patch = vim.diff(rough_diff_block.search, rough_diff_block.replace, { ---@type integer[][]
local patch = vim.diff(old_string, new_string, { ---@type integer[][]
algorithm = "histogram",
result_type = "indices",
ctxlen = vim.o.scrolloff,
})
local diff_blocks_ = {}
for _, hunk in ipairs(patch) do
local start_a, count_a, start_b, count_b = unpack(hunk)
local diff_block = {}
@@ -243,9 +262,7 @@ function M.func(opts, on_log, on_complete, session_ctx)
diff_block.start_line = base_line_ + rough_diff_block.start_line + start_a
end
diff_block.end_line = base_line_ + rough_diff_block.start_line + start_a + math.max(count_a, 1) - 2
diff_block.search = table.concat(diff_block.old_lines, "\n")
diff_block.replace = table.concat(diff_block.new_lines, "\n")
table.insert(res, diff_block)
table.insert(diff_blocks_, diff_block)
end
local distance = 0
@@ -257,12 +274,18 @@ function M.func(opts, on_log, on_complete, session_ctx)
local old_distance = #rough_diff_block.new_lines - #rough_diff_block.old_lines
base_line_ = base_line_ + distance - old_distance
rough_diff_blocks_to_diff_blocks_cache[cache_key] = { diff_blocks = diff_blocks_, base_line = base_line_ }
res = vim.list_extend(res, diff_blocks_)
::continue::
end
return res
end
for _, rough_diff_block in ipairs(rough_diff_blocks) do
local error = parse_rough_diff_block(rough_diff_block, original_lines)
local error = complete_rough_diff_block(rough_diff_block)
if error then
on_complete(false, error)
return
@@ -554,10 +577,46 @@ function M.func(opts, on_log, on_complete, session_ctx)
session_ctx.virt_lines_map[opts.tool_use_id] = virt_lines_map
end
session_ctx.last_orig_diff_end_line_map = session_ctx.last_orig_diff_end_line_map or {}
local last_orig_diff_end_line = session_ctx.last_orig_diff_end_line_map[opts.tool_use_id]
if not last_orig_diff_end_line then
last_orig_diff_end_line = 1
session_ctx.last_orig_diff_end_line_map[opts.tool_use_id] = last_orig_diff_end_line
end
session_ctx.last_resp_diff_end_line_map = session_ctx.last_resp_diff_end_line_map or {}
local last_resp_diff_end_line = session_ctx.last_resp_diff_end_line_map[opts.tool_use_id]
if not last_resp_diff_end_line then
last_resp_diff_end_line = 1
session_ctx.last_resp_diff_end_line_map[opts.tool_use_id] = last_resp_diff_end_line
end
session_ctx.prev_diff_blocks_map = session_ctx.prev_diff_blocks_map or {}
local prev_diff_blocks = session_ctx.prev_diff_blocks_map[opts.tool_use_id]
if not prev_diff_blocks then
prev_diff_blocks = {}
session_ctx.prev_diff_blocks_map[opts.tool_use_id] = prev_diff_blocks
end
local function get_unstable_diff_blocks(diff_blocks_)
local new_diff_blocks = {}
for _, diff_block in ipairs(diff_blocks_) do
local has = vim.iter(prev_diff_blocks):find(function(prev_diff_block)
if prev_diff_block.start_line ~= diff_block.start_line then return false end
if prev_diff_block.end_line ~= diff_block.end_line then return false end
if #prev_diff_block.old_lines ~= #diff_block.old_lines then return false end
if #prev_diff_block.new_lines ~= #diff_block.new_lines then return false end
return true
end)
if has == nil then table.insert(new_diff_blocks, diff_block) end
end
return new_diff_blocks
end
local function highlight_streaming_diff_blocks()
vim.api.nvim_buf_clear_namespace(bufnr, NAMESPACE, 0, -1)
local unstable_diff_blocks = get_unstable_diff_blocks(diff_blocks)
session_ctx.prev_diff_blocks_map[opts.tool_use_id] = diff_blocks
local max_col = vim.o.columns
for _, diff_block in ipairs(diff_blocks) do
for _, diff_block in ipairs(unstable_diff_blocks) do
local new_lines = diff_block.new_lines
local start_line = diff_block.start_line
if #diff_block.old_lines > 0 then
vim.api.nvim_buf_set_extmark(bufnr, NAMESPACE, start_line - 1, 0, {
@@ -567,9 +626,9 @@ function M.func(opts, on_log, on_complete, session_ctx)
end_row = start_line + #diff_block.old_lines - 1,
})
end
if #diff_block.new_lines == 0 then goto continue end
if #new_lines == 0 then goto continue end
local virt_lines = vim
.iter(diff_block.new_lines)
.iter(new_lines)
:map(function(line)
--- append spaces to the end of the line
local line_ = line .. string.rep(" ", max_col - #line)
@@ -582,11 +641,15 @@ function M.func(opts, on_log, on_complete, session_ctx)
else
extmark_line = math.max(0, start_line - 1 + #diff_block.old_lines)
end
vim.api.nvim_buf_set_extmark(bufnr, NAMESPACE, extmark_line, 0, {
-- Utils.debug("extmark_line", extmark_line, "idx", idx, "start_line", diff_block.start_line, "old_lines", table.concat(diff_block.old_lines, "\n"))
local old_extmark_id = extmark_id_map[start_line]
if old_extmark_id then vim.api.nvim_buf_del_extmark(bufnr, NAMESPACE, old_extmark_id) end
local extmark_id = vim.api.nvim_buf_set_extmark(bufnr, NAMESPACE, extmark_line, 0, {
virt_lines = virt_lines,
hl_eol = true,
hl_mode = "combine",
})
extmark_id_map[start_line] = extmark_id
::continue::
end
end

View File

@@ -226,7 +226,13 @@ function M:add_text_message(ctx, text, state, opts)
if llm_tool_names == nil then llm_tool_names = LlmTools.get_tool_names() end
if ctx.content == nil then ctx.content = "" end
ctx.content = ctx.content .. text
local content = ctx.content:gsub("<tool_code>", ""):gsub("</tool_code>", "")
local content = ctx.content
:gsub("<tool_code>", "")
:gsub("</tool_code>", "")
:gsub("<tool_call>", "")
:gsub("</tool_call>", "")
:gsub("<tool_use>", "")
:gsub("</tool_use>", "")
ctx.content = content
local msg = HistoryMessage:new({
role = "assistant",
@@ -278,7 +284,8 @@ function M:add_text_message(ctx, text, state, opts)
end
end
if next(input) ~= nil then
local tool_use_id = Utils.uuid()
local msg_uuid = ctx.content_uuid .. "-" .. idx
local tool_use_id = msg_uuid
local msg_ = HistoryMessage:new({
role = "assistant",
content = {
@@ -291,7 +298,7 @@ function M:add_text_message(ctx, text, state, opts)
},
}, {
state = state,
uuid = ctx.content_uuid .. "-" .. idx,
uuid = msg_uuid,
})
msgs[#msgs + 1] = msg_
ctx.tool_use_list = ctx.tool_use_list or {}