diff --git a/lua/avante/llm_tools/replace_in_file.lua b/lua/avante/llm_tools/replace_in_file.lua
index b91ad58..efd4522 100644
--- a/lua/avante/llm_tools/replace_in_file.lua
+++ b/lua/avante/llm_tools/replace_in_file.lua
@@ -120,38 +120,37 @@ function M.func(opts, on_log, on_complete, session_ctx)
local diff_lines = vim.split(diff, "\n")
local is_searching = false
local is_replacing = false
- local current_search = {}
- local current_replace = {}
+ local current_old_lines = {}
+ local current_new_lines = {}
local rough_diff_blocks = {}
for _, line in ipairs(diff_lines) do
if line:match("^%s*-------* SEARCH") then
is_searching = true
is_replacing = false
- current_search = {}
+ current_old_lines = {}
elseif line:match("^%s*=======*") and is_searching then
is_searching = false
is_replacing = true
- current_replace = {}
+ current_new_lines = {}
elseif line:match("^%s*+++++++* REPLACE") and is_replacing then
is_replacing = false
- table.insert(
- rough_diff_blocks,
- { search = table.concat(current_search, "\n"), replace = table.concat(current_replace, "\n") }
- )
+ table.insert(rough_diff_blocks, { old_lines = current_old_lines, new_lines = current_new_lines })
elseif is_searching then
- table.insert(current_search, line)
+ table.insert(current_old_lines, line)
elseif is_replacing then
- table.insert(current_replace, line)
+ table.insert(current_new_lines, line)
end
end
-- Handle streaming mode: if we're still in replace mode at the end, include the partial block
- if is_streaming and is_replacing and #current_search > 0 then
- if #current_search > #current_replace then current_search = vim.list_slice(current_search, 1, #current_replace) end
+ if is_streaming and is_replacing and #current_old_lines > 0 then
+ if #current_old_lines > #current_new_lines then
+ current_old_lines = vim.list_slice(current_old_lines, 1, #current_new_lines)
+ end
table.insert(
rough_diff_blocks,
- { search = table.concat(current_search, "\n"), replace = table.concat(current_replace, "\n") }
+ { old_lines = current_old_lines, new_lines = current_new_lines, is_replacing = true }
)
end
@@ -175,14 +174,15 @@ function M.func(opts, on_log, on_complete, session_ctx)
local sidebar = require("avante").get()
if not sidebar then return false, "Avante sidebar not found" end
- local function parse_rough_diff_block(rough_diff_block, current_lines)
- local old_lines = vim.split(rough_diff_block.search, "\n")
- local new_lines = vim.split(rough_diff_block.replace, "\n")
+ --- add line numbers to rough_diff_block
+ local function complete_rough_diff_block(rough_diff_block)
+ local old_lines = rough_diff_block.old_lines
+ local new_lines = rough_diff_block.new_lines
local start_line, end_line
- for i = 1, #current_lines - #old_lines + 1 do
+ for i = 1, #original_lines - #old_lines + 1 do
local match = true
for j = 1, #old_lines do
- if Utils.remove_indentation(current_lines[i + j - 1]) ~= Utils.remove_indentation(old_lines[j]) then
+ if Utils.remove_indentation(original_lines[i + j - 1]) ~= Utils.remove_indentation(old_lines[j]) then
match = false
break
end
@@ -194,36 +194,55 @@ function M.func(opts, on_log, on_complete, session_ctx)
end
end
if start_line == nil or end_line == nil then
- return "Failed to find the old string:\n" .. rough_diff_block.search
+ local old_string = table.concat(old_lines, "\n")
+ return "Failed to find the old string:\n" .. old_string
end
- local old_str = rough_diff_block.search
- local new_str = rough_diff_block.replace
- local original_indentation = Utils.get_indentation(current_lines[start_line])
+ local original_indentation = Utils.get_indentation(original_lines[start_line])
if original_indentation ~= Utils.get_indentation(old_lines[1]) then
old_lines = vim.tbl_map(function(line) return original_indentation .. line end, old_lines)
new_lines = vim.tbl_map(function(line) return original_indentation .. line end, new_lines)
- old_str = table.concat(old_lines, "\n")
- new_str = table.concat(new_lines, "\n")
end
rough_diff_block.old_lines = old_lines
rough_diff_block.new_lines = new_lines
- rough_diff_block.search = old_str
- rough_diff_block.replace = new_str
rough_diff_block.start_line = start_line
rough_diff_block.end_line = end_line
return nil
end
+ session_ctx.rough_diff_blocks_to_diff_blocks_cache_map = session_ctx.rough_diff_blocks_to_diff_blocks_cache_map or {}
+ local rough_diff_blocks_to_diff_blocks_cache =
+ session_ctx.rough_diff_blocks_to_diff_blocks_cache_map[opts.tool_use_id]
+ if not rough_diff_blocks_to_diff_blocks_cache then
+ rough_diff_blocks_to_diff_blocks_cache = {}
+ session_ctx.rough_diff_blocks_to_diff_blocks_cache_map[opts.tool_use_id] = rough_diff_blocks_to_diff_blocks_cache
+ end
+
local function rough_diff_blocks_to_diff_blocks(rough_diff_blocks_)
local res = {}
local base_line_ = 0
- for _, rough_diff_block in ipairs(rough_diff_blocks_) do
+ for idx, rough_diff_block in ipairs(rough_diff_blocks_) do
+ local cache_key = string.format("%s:%s", idx, #rough_diff_block.new_lines)
+ local cached_diff_blocks = rough_diff_blocks_to_diff_blocks_cache[cache_key]
+ if cached_diff_blocks then
+ res = vim.list_extend(res, cached_diff_blocks.diff_blocks)
+ base_line_ = cached_diff_blocks.base_line
+ goto continue
+ end
+ local old_lines = rough_diff_block.old_lines
+ local new_lines = rough_diff_block.new_lines
+ if rough_diff_block.is_replacing then
+ new_lines = vim.list_slice(new_lines, 1, #new_lines - 1)
+ old_lines = vim.list_slice(old_lines, 1, #new_lines)
+ end
+ local old_string = table.concat(old_lines, "\n")
+ local new_string = table.concat(new_lines, "\n")
---@diagnostic disable-next-line: assign-type-mismatch, missing-fields
- local patch = vim.diff(rough_diff_block.search, rough_diff_block.replace, { ---@type integer[][]
+ local patch = vim.diff(old_string, new_string, { ---@type integer[][]
algorithm = "histogram",
result_type = "indices",
ctxlen = vim.o.scrolloff,
})
+ local diff_blocks_ = {}
for _, hunk in ipairs(patch) do
local start_a, count_a, start_b, count_b = unpack(hunk)
local diff_block = {}
@@ -243,9 +262,7 @@ function M.func(opts, on_log, on_complete, session_ctx)
diff_block.start_line = base_line_ + rough_diff_block.start_line + start_a
end
diff_block.end_line = base_line_ + rough_diff_block.start_line + start_a + math.max(count_a, 1) - 2
- diff_block.search = table.concat(diff_block.old_lines, "\n")
- diff_block.replace = table.concat(diff_block.new_lines, "\n")
- table.insert(res, diff_block)
+ table.insert(diff_blocks_, diff_block)
end
local distance = 0
@@ -257,12 +274,18 @@ function M.func(opts, on_log, on_complete, session_ctx)
local old_distance = #rough_diff_block.new_lines - #rough_diff_block.old_lines
base_line_ = base_line_ + distance - old_distance
+
+ rough_diff_blocks_to_diff_blocks_cache[cache_key] = { diff_blocks = diff_blocks_, base_line = base_line_ }
+
+ res = vim.list_extend(res, diff_blocks_)
+
+ ::continue::
end
return res
end
for _, rough_diff_block in ipairs(rough_diff_blocks) do
- local error = parse_rough_diff_block(rough_diff_block, original_lines)
+ local error = complete_rough_diff_block(rough_diff_block)
if error then
on_complete(false, error)
return
@@ -554,10 +577,46 @@ function M.func(opts, on_log, on_complete, session_ctx)
session_ctx.virt_lines_map[opts.tool_use_id] = virt_lines_map
end
+ session_ctx.last_orig_diff_end_line_map = session_ctx.last_orig_diff_end_line_map or {}
+ local last_orig_diff_end_line = session_ctx.last_orig_diff_end_line_map[opts.tool_use_id]
+ if not last_orig_diff_end_line then
+ last_orig_diff_end_line = 1
+ session_ctx.last_orig_diff_end_line_map[opts.tool_use_id] = last_orig_diff_end_line
+ end
+ session_ctx.last_resp_diff_end_line_map = session_ctx.last_resp_diff_end_line_map or {}
+ local last_resp_diff_end_line = session_ctx.last_resp_diff_end_line_map[opts.tool_use_id]
+ if not last_resp_diff_end_line then
+ last_resp_diff_end_line = 1
+ session_ctx.last_resp_diff_end_line_map[opts.tool_use_id] = last_resp_diff_end_line
+ end
+ session_ctx.prev_diff_blocks_map = session_ctx.prev_diff_blocks_map or {}
+ local prev_diff_blocks = session_ctx.prev_diff_blocks_map[opts.tool_use_id]
+ if not prev_diff_blocks then
+ prev_diff_blocks = {}
+ session_ctx.prev_diff_blocks_map[opts.tool_use_id] = prev_diff_blocks
+ end
+
+ local function get_unstable_diff_blocks(diff_blocks_)
+ local new_diff_blocks = {}
+ for _, diff_block in ipairs(diff_blocks_) do
+ local has = vim.iter(prev_diff_blocks):find(function(prev_diff_block)
+ if prev_diff_block.start_line ~= diff_block.start_line then return false end
+ if prev_diff_block.end_line ~= diff_block.end_line then return false end
+ if #prev_diff_block.old_lines ~= #diff_block.old_lines then return false end
+ if #prev_diff_block.new_lines ~= #diff_block.new_lines then return false end
+ return true
+ end)
+ if has == nil then table.insert(new_diff_blocks, diff_block) end
+ end
+ return new_diff_blocks
+ end
+
local function highlight_streaming_diff_blocks()
- vim.api.nvim_buf_clear_namespace(bufnr, NAMESPACE, 0, -1)
+ local unstable_diff_blocks = get_unstable_diff_blocks(diff_blocks)
+ session_ctx.prev_diff_blocks_map[opts.tool_use_id] = diff_blocks
local max_col = vim.o.columns
- for _, diff_block in ipairs(diff_blocks) do
+ for _, diff_block in ipairs(unstable_diff_blocks) do
+ local new_lines = diff_block.new_lines
local start_line = diff_block.start_line
if #diff_block.old_lines > 0 then
vim.api.nvim_buf_set_extmark(bufnr, NAMESPACE, start_line - 1, 0, {
@@ -567,9 +626,9 @@ function M.func(opts, on_log, on_complete, session_ctx)
end_row = start_line + #diff_block.old_lines - 1,
})
end
- if #diff_block.new_lines == 0 then goto continue end
+ if #new_lines == 0 then goto continue end
local virt_lines = vim
- .iter(diff_block.new_lines)
+ .iter(new_lines)
:map(function(line)
--- append spaces to the end of the line
local line_ = line .. string.rep(" ", max_col - #line)
@@ -582,11 +641,15 @@ function M.func(opts, on_log, on_complete, session_ctx)
else
extmark_line = math.max(0, start_line - 1 + #diff_block.old_lines)
end
- vim.api.nvim_buf_set_extmark(bufnr, NAMESPACE, extmark_line, 0, {
+ -- Utils.debug("extmark_line", extmark_line, "idx", idx, "start_line", diff_block.start_line, "old_lines", table.concat(diff_block.old_lines, "\n"))
+ local old_extmark_id = extmark_id_map[start_line]
+ if old_extmark_id then vim.api.nvim_buf_del_extmark(bufnr, NAMESPACE, old_extmark_id) end
+ local extmark_id = vim.api.nvim_buf_set_extmark(bufnr, NAMESPACE, extmark_line, 0, {
virt_lines = virt_lines,
hl_eol = true,
hl_mode = "combine",
})
+ extmark_id_map[start_line] = extmark_id
::continue::
end
end
diff --git a/lua/avante/providers/openai.lua b/lua/avante/providers/openai.lua
index 9ef7dd9..5e255da 100644
--- a/lua/avante/providers/openai.lua
+++ b/lua/avante/providers/openai.lua
@@ -226,7 +226,13 @@ function M:add_text_message(ctx, text, state, opts)
if llm_tool_names == nil then llm_tool_names = LlmTools.get_tool_names() end
if ctx.content == nil then ctx.content = "" end
ctx.content = ctx.content .. text
- local content = ctx.content:gsub("", ""):gsub("", "")
+ local content = ctx.content
+ :gsub("", "")
+ :gsub("", "")
+ :gsub("", "")
+ :gsub("", "")
+ :gsub("", "")
+ :gsub("", "")
ctx.content = content
local msg = HistoryMessage:new({
role = "assistant",
@@ -278,7 +284,8 @@ function M:add_text_message(ctx, text, state, opts)
end
end
if next(input) ~= nil then
- local tool_use_id = Utils.uuid()
+ local msg_uuid = ctx.content_uuid .. "-" .. idx
+ local tool_use_id = msg_uuid
local msg_ = HistoryMessage:new({
role = "assistant",
content = {
@@ -291,7 +298,7 @@ function M:add_text_message(ctx, text, state, opts)
},
}, {
state = state,
- uuid = ctx.content_uuid .. "-" .. idx,
+ uuid = msg_uuid,
})
msgs[#msgs + 1] = msg_
ctx.tool_use_list = ctx.tool_use_list or {}