optimize: streaming diff performance (#2145)
This commit is contained in:
@@ -120,38 +120,37 @@ function M.func(opts, on_log, on_complete, session_ctx)
|
||||
local diff_lines = vim.split(diff, "\n")
|
||||
local is_searching = false
|
||||
local is_replacing = false
|
||||
local current_search = {}
|
||||
local current_replace = {}
|
||||
local current_old_lines = {}
|
||||
local current_new_lines = {}
|
||||
local rough_diff_blocks = {}
|
||||
|
||||
for _, line in ipairs(diff_lines) do
|
||||
if line:match("^%s*-------* SEARCH") then
|
||||
is_searching = true
|
||||
is_replacing = false
|
||||
current_search = {}
|
||||
current_old_lines = {}
|
||||
elseif line:match("^%s*=======*") and is_searching then
|
||||
is_searching = false
|
||||
is_replacing = true
|
||||
current_replace = {}
|
||||
current_new_lines = {}
|
||||
elseif line:match("^%s*+++++++* REPLACE") and is_replacing then
|
||||
is_replacing = false
|
||||
table.insert(
|
||||
rough_diff_blocks,
|
||||
{ search = table.concat(current_search, "\n"), replace = table.concat(current_replace, "\n") }
|
||||
)
|
||||
table.insert(rough_diff_blocks, { old_lines = current_old_lines, new_lines = current_new_lines })
|
||||
elseif is_searching then
|
||||
table.insert(current_search, line)
|
||||
table.insert(current_old_lines, line)
|
||||
elseif is_replacing then
|
||||
table.insert(current_replace, line)
|
||||
table.insert(current_new_lines, line)
|
||||
end
|
||||
end
|
||||
|
||||
-- Handle streaming mode: if we're still in replace mode at the end, include the partial block
|
||||
if is_streaming and is_replacing and #current_search > 0 then
|
||||
if #current_search > #current_replace then current_search = vim.list_slice(current_search, 1, #current_replace) end
|
||||
if is_streaming and is_replacing and #current_old_lines > 0 then
|
||||
if #current_old_lines > #current_new_lines then
|
||||
current_old_lines = vim.list_slice(current_old_lines, 1, #current_new_lines)
|
||||
end
|
||||
table.insert(
|
||||
rough_diff_blocks,
|
||||
{ search = table.concat(current_search, "\n"), replace = table.concat(current_replace, "\n") }
|
||||
{ old_lines = current_old_lines, new_lines = current_new_lines, is_replacing = true }
|
||||
)
|
||||
end
|
||||
|
||||
@@ -175,14 +174,15 @@ function M.func(opts, on_log, on_complete, session_ctx)
|
||||
local sidebar = require("avante").get()
|
||||
if not sidebar then return false, "Avante sidebar not found" end
|
||||
|
||||
local function parse_rough_diff_block(rough_diff_block, current_lines)
|
||||
local old_lines = vim.split(rough_diff_block.search, "\n")
|
||||
local new_lines = vim.split(rough_diff_block.replace, "\n")
|
||||
--- add line numbers to rough_diff_block
|
||||
local function complete_rough_diff_block(rough_diff_block)
|
||||
local old_lines = rough_diff_block.old_lines
|
||||
local new_lines = rough_diff_block.new_lines
|
||||
local start_line, end_line
|
||||
for i = 1, #current_lines - #old_lines + 1 do
|
||||
for i = 1, #original_lines - #old_lines + 1 do
|
||||
local match = true
|
||||
for j = 1, #old_lines do
|
||||
if Utils.remove_indentation(current_lines[i + j - 1]) ~= Utils.remove_indentation(old_lines[j]) then
|
||||
if Utils.remove_indentation(original_lines[i + j - 1]) ~= Utils.remove_indentation(old_lines[j]) then
|
||||
match = false
|
||||
break
|
||||
end
|
||||
@@ -194,36 +194,55 @@ function M.func(opts, on_log, on_complete, session_ctx)
|
||||
end
|
||||
end
|
||||
if start_line == nil or end_line == nil then
|
||||
return "Failed to find the old string:\n" .. rough_diff_block.search
|
||||
local old_string = table.concat(old_lines, "\n")
|
||||
return "Failed to find the old string:\n" .. old_string
|
||||
end
|
||||
local old_str = rough_diff_block.search
|
||||
local new_str = rough_diff_block.replace
|
||||
local original_indentation = Utils.get_indentation(current_lines[start_line])
|
||||
local original_indentation = Utils.get_indentation(original_lines[start_line])
|
||||
if original_indentation ~= Utils.get_indentation(old_lines[1]) then
|
||||
old_lines = vim.tbl_map(function(line) return original_indentation .. line end, old_lines)
|
||||
new_lines = vim.tbl_map(function(line) return original_indentation .. line end, new_lines)
|
||||
old_str = table.concat(old_lines, "\n")
|
||||
new_str = table.concat(new_lines, "\n")
|
||||
end
|
||||
rough_diff_block.old_lines = old_lines
|
||||
rough_diff_block.new_lines = new_lines
|
||||
rough_diff_block.search = old_str
|
||||
rough_diff_block.replace = new_str
|
||||
rough_diff_block.start_line = start_line
|
||||
rough_diff_block.end_line = end_line
|
||||
return nil
|
||||
end
|
||||
|
||||
session_ctx.rough_diff_blocks_to_diff_blocks_cache_map = session_ctx.rough_diff_blocks_to_diff_blocks_cache_map or {}
|
||||
local rough_diff_blocks_to_diff_blocks_cache =
|
||||
session_ctx.rough_diff_blocks_to_diff_blocks_cache_map[opts.tool_use_id]
|
||||
if not rough_diff_blocks_to_diff_blocks_cache then
|
||||
rough_diff_blocks_to_diff_blocks_cache = {}
|
||||
session_ctx.rough_diff_blocks_to_diff_blocks_cache_map[opts.tool_use_id] = rough_diff_blocks_to_diff_blocks_cache
|
||||
end
|
||||
|
||||
local function rough_diff_blocks_to_diff_blocks(rough_diff_blocks_)
|
||||
local res = {}
|
||||
local base_line_ = 0
|
||||
for _, rough_diff_block in ipairs(rough_diff_blocks_) do
|
||||
for idx, rough_diff_block in ipairs(rough_diff_blocks_) do
|
||||
local cache_key = string.format("%s:%s", idx, #rough_diff_block.new_lines)
|
||||
local cached_diff_blocks = rough_diff_blocks_to_diff_blocks_cache[cache_key]
|
||||
if cached_diff_blocks then
|
||||
res = vim.list_extend(res, cached_diff_blocks.diff_blocks)
|
||||
base_line_ = cached_diff_blocks.base_line
|
||||
goto continue
|
||||
end
|
||||
local old_lines = rough_diff_block.old_lines
|
||||
local new_lines = rough_diff_block.new_lines
|
||||
if rough_diff_block.is_replacing then
|
||||
new_lines = vim.list_slice(new_lines, 1, #new_lines - 1)
|
||||
old_lines = vim.list_slice(old_lines, 1, #new_lines)
|
||||
end
|
||||
local old_string = table.concat(old_lines, "\n")
|
||||
local new_string = table.concat(new_lines, "\n")
|
||||
---@diagnostic disable-next-line: assign-type-mismatch, missing-fields
|
||||
local patch = vim.diff(rough_diff_block.search, rough_diff_block.replace, { ---@type integer[][]
|
||||
local patch = vim.diff(old_string, new_string, { ---@type integer[][]
|
||||
algorithm = "histogram",
|
||||
result_type = "indices",
|
||||
ctxlen = vim.o.scrolloff,
|
||||
})
|
||||
local diff_blocks_ = {}
|
||||
for _, hunk in ipairs(patch) do
|
||||
local start_a, count_a, start_b, count_b = unpack(hunk)
|
||||
local diff_block = {}
|
||||
@@ -243,9 +262,7 @@ function M.func(opts, on_log, on_complete, session_ctx)
|
||||
diff_block.start_line = base_line_ + rough_diff_block.start_line + start_a
|
||||
end
|
||||
diff_block.end_line = base_line_ + rough_diff_block.start_line + start_a + math.max(count_a, 1) - 2
|
||||
diff_block.search = table.concat(diff_block.old_lines, "\n")
|
||||
diff_block.replace = table.concat(diff_block.new_lines, "\n")
|
||||
table.insert(res, diff_block)
|
||||
table.insert(diff_blocks_, diff_block)
|
||||
end
|
||||
|
||||
local distance = 0
|
||||
@@ -257,12 +274,18 @@ function M.func(opts, on_log, on_complete, session_ctx)
|
||||
local old_distance = #rough_diff_block.new_lines - #rough_diff_block.old_lines
|
||||
|
||||
base_line_ = base_line_ + distance - old_distance
|
||||
|
||||
rough_diff_blocks_to_diff_blocks_cache[cache_key] = { diff_blocks = diff_blocks_, base_line = base_line_ }
|
||||
|
||||
res = vim.list_extend(res, diff_blocks_)
|
||||
|
||||
::continue::
|
||||
end
|
||||
return res
|
||||
end
|
||||
|
||||
for _, rough_diff_block in ipairs(rough_diff_blocks) do
|
||||
local error = parse_rough_diff_block(rough_diff_block, original_lines)
|
||||
local error = complete_rough_diff_block(rough_diff_block)
|
||||
if error then
|
||||
on_complete(false, error)
|
||||
return
|
||||
@@ -554,10 +577,46 @@ function M.func(opts, on_log, on_complete, session_ctx)
|
||||
session_ctx.virt_lines_map[opts.tool_use_id] = virt_lines_map
|
||||
end
|
||||
|
||||
session_ctx.last_orig_diff_end_line_map = session_ctx.last_orig_diff_end_line_map or {}
|
||||
local last_orig_diff_end_line = session_ctx.last_orig_diff_end_line_map[opts.tool_use_id]
|
||||
if not last_orig_diff_end_line then
|
||||
last_orig_diff_end_line = 1
|
||||
session_ctx.last_orig_diff_end_line_map[opts.tool_use_id] = last_orig_diff_end_line
|
||||
end
|
||||
session_ctx.last_resp_diff_end_line_map = session_ctx.last_resp_diff_end_line_map or {}
|
||||
local last_resp_diff_end_line = session_ctx.last_resp_diff_end_line_map[opts.tool_use_id]
|
||||
if not last_resp_diff_end_line then
|
||||
last_resp_diff_end_line = 1
|
||||
session_ctx.last_resp_diff_end_line_map[opts.tool_use_id] = last_resp_diff_end_line
|
||||
end
|
||||
session_ctx.prev_diff_blocks_map = session_ctx.prev_diff_blocks_map or {}
|
||||
local prev_diff_blocks = session_ctx.prev_diff_blocks_map[opts.tool_use_id]
|
||||
if not prev_diff_blocks then
|
||||
prev_diff_blocks = {}
|
||||
session_ctx.prev_diff_blocks_map[opts.tool_use_id] = prev_diff_blocks
|
||||
end
|
||||
|
||||
local function get_unstable_diff_blocks(diff_blocks_)
|
||||
local new_diff_blocks = {}
|
||||
for _, diff_block in ipairs(diff_blocks_) do
|
||||
local has = vim.iter(prev_diff_blocks):find(function(prev_diff_block)
|
||||
if prev_diff_block.start_line ~= diff_block.start_line then return false end
|
||||
if prev_diff_block.end_line ~= diff_block.end_line then return false end
|
||||
if #prev_diff_block.old_lines ~= #diff_block.old_lines then return false end
|
||||
if #prev_diff_block.new_lines ~= #diff_block.new_lines then return false end
|
||||
return true
|
||||
end)
|
||||
if has == nil then table.insert(new_diff_blocks, diff_block) end
|
||||
end
|
||||
return new_diff_blocks
|
||||
end
|
||||
|
||||
local function highlight_streaming_diff_blocks()
|
||||
vim.api.nvim_buf_clear_namespace(bufnr, NAMESPACE, 0, -1)
|
||||
local unstable_diff_blocks = get_unstable_diff_blocks(diff_blocks)
|
||||
session_ctx.prev_diff_blocks_map[opts.tool_use_id] = diff_blocks
|
||||
local max_col = vim.o.columns
|
||||
for _, diff_block in ipairs(diff_blocks) do
|
||||
for _, diff_block in ipairs(unstable_diff_blocks) do
|
||||
local new_lines = diff_block.new_lines
|
||||
local start_line = diff_block.start_line
|
||||
if #diff_block.old_lines > 0 then
|
||||
vim.api.nvim_buf_set_extmark(bufnr, NAMESPACE, start_line - 1, 0, {
|
||||
@@ -567,9 +626,9 @@ function M.func(opts, on_log, on_complete, session_ctx)
|
||||
end_row = start_line + #diff_block.old_lines - 1,
|
||||
})
|
||||
end
|
||||
if #diff_block.new_lines == 0 then goto continue end
|
||||
if #new_lines == 0 then goto continue end
|
||||
local virt_lines = vim
|
||||
.iter(diff_block.new_lines)
|
||||
.iter(new_lines)
|
||||
:map(function(line)
|
||||
--- append spaces to the end of the line
|
||||
local line_ = line .. string.rep(" ", max_col - #line)
|
||||
@@ -582,11 +641,15 @@ function M.func(opts, on_log, on_complete, session_ctx)
|
||||
else
|
||||
extmark_line = math.max(0, start_line - 1 + #diff_block.old_lines)
|
||||
end
|
||||
vim.api.nvim_buf_set_extmark(bufnr, NAMESPACE, extmark_line, 0, {
|
||||
-- Utils.debug("extmark_line", extmark_line, "idx", idx, "start_line", diff_block.start_line, "old_lines", table.concat(diff_block.old_lines, "\n"))
|
||||
local old_extmark_id = extmark_id_map[start_line]
|
||||
if old_extmark_id then vim.api.nvim_buf_del_extmark(bufnr, NAMESPACE, old_extmark_id) end
|
||||
local extmark_id = vim.api.nvim_buf_set_extmark(bufnr, NAMESPACE, extmark_line, 0, {
|
||||
virt_lines = virt_lines,
|
||||
hl_eol = true,
|
||||
hl_mode = "combine",
|
||||
})
|
||||
extmark_id_map[start_line] = extmark_id
|
||||
::continue::
|
||||
end
|
||||
end
|
||||
|
||||
@@ -226,7 +226,13 @@ function M:add_text_message(ctx, text, state, opts)
|
||||
if llm_tool_names == nil then llm_tool_names = LlmTools.get_tool_names() end
|
||||
if ctx.content == nil then ctx.content = "" end
|
||||
ctx.content = ctx.content .. text
|
||||
local content = ctx.content:gsub("<tool_code>", ""):gsub("</tool_code>", "")
|
||||
local content = ctx.content
|
||||
:gsub("<tool_code>", "")
|
||||
:gsub("</tool_code>", "")
|
||||
:gsub("<tool_call>", "")
|
||||
:gsub("</tool_call>", "")
|
||||
:gsub("<tool_use>", "")
|
||||
:gsub("</tool_use>", "")
|
||||
ctx.content = content
|
||||
local msg = HistoryMessage:new({
|
||||
role = "assistant",
|
||||
@@ -278,7 +284,8 @@ function M:add_text_message(ctx, text, state, opts)
|
||||
end
|
||||
end
|
||||
if next(input) ~= nil then
|
||||
local tool_use_id = Utils.uuid()
|
||||
local msg_uuid = ctx.content_uuid .. "-" .. idx
|
||||
local tool_use_id = msg_uuid
|
||||
local msg_ = HistoryMessage:new({
|
||||
role = "assistant",
|
||||
content = {
|
||||
@@ -291,7 +298,7 @@ function M:add_text_message(ctx, text, state, opts)
|
||||
},
|
||||
}, {
|
||||
state = state,
|
||||
uuid = ctx.content_uuid .. "-" .. idx,
|
||||
uuid = msg_uuid,
|
||||
})
|
||||
msgs[#msgs + 1] = msg_
|
||||
ctx.tool_use_list = ctx.tool_use_list or {}
|
||||
|
||||
Reference in New Issue
Block a user