diff --git a/lua/avante/llm_tools/replace_in_file.lua b/lua/avante/llm_tools/replace_in_file.lua index b91ad58..efd4522 100644 --- a/lua/avante/llm_tools/replace_in_file.lua +++ b/lua/avante/llm_tools/replace_in_file.lua @@ -120,38 +120,37 @@ function M.func(opts, on_log, on_complete, session_ctx) local diff_lines = vim.split(diff, "\n") local is_searching = false local is_replacing = false - local current_search = {} - local current_replace = {} + local current_old_lines = {} + local current_new_lines = {} local rough_diff_blocks = {} for _, line in ipairs(diff_lines) do if line:match("^%s*-------* SEARCH") then is_searching = true is_replacing = false - current_search = {} + current_old_lines = {} elseif line:match("^%s*=======*") and is_searching then is_searching = false is_replacing = true - current_replace = {} + current_new_lines = {} elseif line:match("^%s*+++++++* REPLACE") and is_replacing then is_replacing = false - table.insert( - rough_diff_blocks, - { search = table.concat(current_search, "\n"), replace = table.concat(current_replace, "\n") } - ) + table.insert(rough_diff_blocks, { old_lines = current_old_lines, new_lines = current_new_lines }) elseif is_searching then - table.insert(current_search, line) + table.insert(current_old_lines, line) elseif is_replacing then - table.insert(current_replace, line) + table.insert(current_new_lines, line) end end -- Handle streaming mode: if we're still in replace mode at the end, include the partial block - if is_streaming and is_replacing and #current_search > 0 then - if #current_search > #current_replace then current_search = vim.list_slice(current_search, 1, #current_replace) end + if is_streaming and is_replacing and #current_old_lines > 0 then + if #current_old_lines > #current_new_lines then + current_old_lines = vim.list_slice(current_old_lines, 1, #current_new_lines) + end table.insert( rough_diff_blocks, - { search = table.concat(current_search, "\n"), replace = table.concat(current_replace, "\n") } + { old_lines = current_old_lines, new_lines = current_new_lines, is_replacing = true } ) end @@ -175,14 +174,15 @@ function M.func(opts, on_log, on_complete, session_ctx) local sidebar = require("avante").get() if not sidebar then return false, "Avante sidebar not found" end - local function parse_rough_diff_block(rough_diff_block, current_lines) - local old_lines = vim.split(rough_diff_block.search, "\n") - local new_lines = vim.split(rough_diff_block.replace, "\n") + --- add line numbers to rough_diff_block + local function complete_rough_diff_block(rough_diff_block) + local old_lines = rough_diff_block.old_lines + local new_lines = rough_diff_block.new_lines local start_line, end_line - for i = 1, #current_lines - #old_lines + 1 do + for i = 1, #original_lines - #old_lines + 1 do local match = true for j = 1, #old_lines do - if Utils.remove_indentation(current_lines[i + j - 1]) ~= Utils.remove_indentation(old_lines[j]) then + if Utils.remove_indentation(original_lines[i + j - 1]) ~= Utils.remove_indentation(old_lines[j]) then match = false break end @@ -194,36 +194,55 @@ function M.func(opts, on_log, on_complete, session_ctx) end end if start_line == nil or end_line == nil then - return "Failed to find the old string:\n" .. rough_diff_block.search + local old_string = table.concat(old_lines, "\n") + return "Failed to find the old string:\n" .. old_string end - local old_str = rough_diff_block.search - local new_str = rough_diff_block.replace - local original_indentation = Utils.get_indentation(current_lines[start_line]) + local original_indentation = Utils.get_indentation(original_lines[start_line]) if original_indentation ~= Utils.get_indentation(old_lines[1]) then old_lines = vim.tbl_map(function(line) return original_indentation .. line end, old_lines) new_lines = vim.tbl_map(function(line) return original_indentation .. line end, new_lines) - old_str = table.concat(old_lines, "\n") - new_str = table.concat(new_lines, "\n") end rough_diff_block.old_lines = old_lines rough_diff_block.new_lines = new_lines - rough_diff_block.search = old_str - rough_diff_block.replace = new_str rough_diff_block.start_line = start_line rough_diff_block.end_line = end_line return nil end + session_ctx.rough_diff_blocks_to_diff_blocks_cache_map = session_ctx.rough_diff_blocks_to_diff_blocks_cache_map or {} + local rough_diff_blocks_to_diff_blocks_cache = + session_ctx.rough_diff_blocks_to_diff_blocks_cache_map[opts.tool_use_id] + if not rough_diff_blocks_to_diff_blocks_cache then + rough_diff_blocks_to_diff_blocks_cache = {} + session_ctx.rough_diff_blocks_to_diff_blocks_cache_map[opts.tool_use_id] = rough_diff_blocks_to_diff_blocks_cache + end + local function rough_diff_blocks_to_diff_blocks(rough_diff_blocks_) local res = {} local base_line_ = 0 - for _, rough_diff_block in ipairs(rough_diff_blocks_) do + for idx, rough_diff_block in ipairs(rough_diff_blocks_) do + local cache_key = string.format("%s:%s", idx, #rough_diff_block.new_lines) + local cached_diff_blocks = rough_diff_blocks_to_diff_blocks_cache[cache_key] + if cached_diff_blocks then + res = vim.list_extend(res, cached_diff_blocks.diff_blocks) + base_line_ = cached_diff_blocks.base_line + goto continue + end + local old_lines = rough_diff_block.old_lines + local new_lines = rough_diff_block.new_lines + if rough_diff_block.is_replacing then + new_lines = vim.list_slice(new_lines, 1, #new_lines - 1) + old_lines = vim.list_slice(old_lines, 1, #new_lines) + end + local old_string = table.concat(old_lines, "\n") + local new_string = table.concat(new_lines, "\n") ---@diagnostic disable-next-line: assign-type-mismatch, missing-fields - local patch = vim.diff(rough_diff_block.search, rough_diff_block.replace, { ---@type integer[][] + local patch = vim.diff(old_string, new_string, { ---@type integer[][] algorithm = "histogram", result_type = "indices", ctxlen = vim.o.scrolloff, }) + local diff_blocks_ = {} for _, hunk in ipairs(patch) do local start_a, count_a, start_b, count_b = unpack(hunk) local diff_block = {} @@ -243,9 +262,7 @@ function M.func(opts, on_log, on_complete, session_ctx) diff_block.start_line = base_line_ + rough_diff_block.start_line + start_a end diff_block.end_line = base_line_ + rough_diff_block.start_line + start_a + math.max(count_a, 1) - 2 - diff_block.search = table.concat(diff_block.old_lines, "\n") - diff_block.replace = table.concat(diff_block.new_lines, "\n") - table.insert(res, diff_block) + table.insert(diff_blocks_, diff_block) end local distance = 0 @@ -257,12 +274,18 @@ function M.func(opts, on_log, on_complete, session_ctx) local old_distance = #rough_diff_block.new_lines - #rough_diff_block.old_lines base_line_ = base_line_ + distance - old_distance + + rough_diff_blocks_to_diff_blocks_cache[cache_key] = { diff_blocks = diff_blocks_, base_line = base_line_ } + + res = vim.list_extend(res, diff_blocks_) + + ::continue:: end return res end for _, rough_diff_block in ipairs(rough_diff_blocks) do - local error = parse_rough_diff_block(rough_diff_block, original_lines) + local error = complete_rough_diff_block(rough_diff_block) if error then on_complete(false, error) return @@ -554,10 +577,46 @@ function M.func(opts, on_log, on_complete, session_ctx) session_ctx.virt_lines_map[opts.tool_use_id] = virt_lines_map end + session_ctx.last_orig_diff_end_line_map = session_ctx.last_orig_diff_end_line_map or {} + local last_orig_diff_end_line = session_ctx.last_orig_diff_end_line_map[opts.tool_use_id] + if not last_orig_diff_end_line then + last_orig_diff_end_line = 1 + session_ctx.last_orig_diff_end_line_map[opts.tool_use_id] = last_orig_diff_end_line + end + session_ctx.last_resp_diff_end_line_map = session_ctx.last_resp_diff_end_line_map or {} + local last_resp_diff_end_line = session_ctx.last_resp_diff_end_line_map[opts.tool_use_id] + if not last_resp_diff_end_line then + last_resp_diff_end_line = 1 + session_ctx.last_resp_diff_end_line_map[opts.tool_use_id] = last_resp_diff_end_line + end + session_ctx.prev_diff_blocks_map = session_ctx.prev_diff_blocks_map or {} + local prev_diff_blocks = session_ctx.prev_diff_blocks_map[opts.tool_use_id] + if not prev_diff_blocks then + prev_diff_blocks = {} + session_ctx.prev_diff_blocks_map[opts.tool_use_id] = prev_diff_blocks + end + + local function get_unstable_diff_blocks(diff_blocks_) + local new_diff_blocks = {} + for _, diff_block in ipairs(diff_blocks_) do + local has = vim.iter(prev_diff_blocks):find(function(prev_diff_block) + if prev_diff_block.start_line ~= diff_block.start_line then return false end + if prev_diff_block.end_line ~= diff_block.end_line then return false end + if #prev_diff_block.old_lines ~= #diff_block.old_lines then return false end + if #prev_diff_block.new_lines ~= #diff_block.new_lines then return false end + return true + end) + if has == nil then table.insert(new_diff_blocks, diff_block) end + end + return new_diff_blocks + end + local function highlight_streaming_diff_blocks() - vim.api.nvim_buf_clear_namespace(bufnr, NAMESPACE, 0, -1) + local unstable_diff_blocks = get_unstable_diff_blocks(diff_blocks) + session_ctx.prev_diff_blocks_map[opts.tool_use_id] = diff_blocks local max_col = vim.o.columns - for _, diff_block in ipairs(diff_blocks) do + for _, diff_block in ipairs(unstable_diff_blocks) do + local new_lines = diff_block.new_lines local start_line = diff_block.start_line if #diff_block.old_lines > 0 then vim.api.nvim_buf_set_extmark(bufnr, NAMESPACE, start_line - 1, 0, { @@ -567,9 +626,9 @@ function M.func(opts, on_log, on_complete, session_ctx) end_row = start_line + #diff_block.old_lines - 1, }) end - if #diff_block.new_lines == 0 then goto continue end + if #new_lines == 0 then goto continue end local virt_lines = vim - .iter(diff_block.new_lines) + .iter(new_lines) :map(function(line) --- append spaces to the end of the line local line_ = line .. string.rep(" ", max_col - #line) @@ -582,11 +641,15 @@ function M.func(opts, on_log, on_complete, session_ctx) else extmark_line = math.max(0, start_line - 1 + #diff_block.old_lines) end - vim.api.nvim_buf_set_extmark(bufnr, NAMESPACE, extmark_line, 0, { + -- Utils.debug("extmark_line", extmark_line, "idx", idx, "start_line", diff_block.start_line, "old_lines", table.concat(diff_block.old_lines, "\n")) + local old_extmark_id = extmark_id_map[start_line] + if old_extmark_id then vim.api.nvim_buf_del_extmark(bufnr, NAMESPACE, old_extmark_id) end + local extmark_id = vim.api.nvim_buf_set_extmark(bufnr, NAMESPACE, extmark_line, 0, { virt_lines = virt_lines, hl_eol = true, hl_mode = "combine", }) + extmark_id_map[start_line] = extmark_id ::continue:: end end diff --git a/lua/avante/providers/openai.lua b/lua/avante/providers/openai.lua index 9ef7dd9..5e255da 100644 --- a/lua/avante/providers/openai.lua +++ b/lua/avante/providers/openai.lua @@ -226,7 +226,13 @@ function M:add_text_message(ctx, text, state, opts) if llm_tool_names == nil then llm_tool_names = LlmTools.get_tool_names() end if ctx.content == nil then ctx.content = "" end ctx.content = ctx.content .. text - local content = ctx.content:gsub("", ""):gsub("", "") + local content = ctx.content + :gsub("", "") + :gsub("", "") + :gsub("", "") + :gsub("", "") + :gsub("", "") + :gsub("", "") ctx.content = content local msg = HistoryMessage:new({ role = "assistant", @@ -278,7 +284,8 @@ function M:add_text_message(ctx, text, state, opts) end end if next(input) ~= nil then - local tool_use_id = Utils.uuid() + local msg_uuid = ctx.content_uuid .. "-" .. idx + local tool_use_id = msg_uuid local msg_ = HistoryMessage:new({ role = "assistant", content = { @@ -291,7 +298,7 @@ function M:add_text_message(ctx, text, state, opts) }, }, { state = state, - uuid = ctx.content_uuid .. "-" .. idx, + uuid = msg_uuid, }) msgs[#msgs + 1] = msg_ ctx.tool_use_list = ctx.tool_use_list or {}