diff --git a/lua/avante/llm_tools/replace_in_file.lua b/lua/avante/llm_tools/replace_in_file.lua index f4a67af..e72bd15 100644 --- a/lua/avante/llm_tools/replace_in_file.lua +++ b/lua/avante/llm_tools/replace_in_file.lua @@ -199,21 +199,7 @@ function M.func(opts, on_log, on_complete, session_ctx) local function complete_rough_diff_block(rough_diff_block) local old_lines = rough_diff_block.old_lines local new_lines = rough_diff_block.new_lines - local start_line, end_line - for i = 1, #original_lines - #old_lines + 1 do - local match = true - for j = 1, #old_lines do - if Utils.remove_indentation(original_lines[i + j - 1]) ~= Utils.remove_indentation(old_lines[j]) then - match = false - break - end - end - if match then - start_line = i - end_line = i + #old_lines - 1 - break - end - end + local start_line, end_line = Utils.fuzzy_match(original_lines, old_lines) if start_line == nil or end_line == nil then local old_string = table.concat(old_lines, "\n") return "Failed to find the old string:\n" .. old_string diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua index d9cb841..7100408 100644 --- a/lua/avante/sidebar.lua +++ b/lua/avante/sidebar.lua @@ -368,32 +368,21 @@ local function transform_result_content(result_content, prev_filepath) local prev_line = result_lines[i - 1] if prev_line and prev_line:match("^%s*```$") then search_end = i - 1 end - local start_line = 0 - local end_line = 0 local match_filetype = nil local filepath = current_filepath or prev_filepath or "" if filepath == "" then goto continue end - local file_content = Utils.read_file_from_buf_or_disk(filepath) or {} + local file_content_lines = Utils.read_file_from_buf_or_disk(filepath) or {} local file_type = Utils.get_filetype(filepath) - if start_line ~= 0 or end_line ~= 0 then break end - for j = 1, #file_content - (search_end - search_start) + 1 do - local match = true - for k = 0, search_end - search_start - 1 do - if - Utils.remove_indentation(file_content[j + k]) ~= Utils.remove_indentation(result_lines[search_start + k]) - then - match = false - break - end - end - if match then - start_line = j - end_line = j + (search_end - search_start) - 1 - match_filetype = file_type - break - end + local search_lines = vim.list_slice(result_lines, search_start, search_end - 1) + local start_line, end_line = Utils.fuzzy_match(file_content_lines, search_lines) + + if start_line ~= nil and end_line ~= nil then + match_filetype = file_type + else + start_line = 0 + end_line = 0 end -- when the filetype isn't detected, fallback to matching based on filepath. @@ -634,61 +623,6 @@ local function extract_code_snippets_map(response_content) return snippets_map end ----@param snippets_map table ----@return table -local function ensure_snippets_no_overlap(snippets_map) - local new_snippets_map = {} - - for filepath, snippets in pairs(snippets_map) do - table.sort(snippets, function(a, b) return a.range[1] < b.range[1] end) - - local original_lines = {} - local file_exists = Utils.file.exists(filepath) - if file_exists then - local original_lines_ = Utils.read_file_from_buf_or_disk(filepath) - if original_lines_ then original_lines = original_lines_ end - end - - local new_snippets = {} - local last_end_line = 0 - for _, snippet in ipairs(snippets) do - if snippet.range[1] > last_end_line then - table.insert(new_snippets, snippet) - last_end_line = snippet.range[2] - elseif not file_exists and #snippets <= 1 then - -- if the file doesn't exist, and we only have 1 snippet, then we don't have to check for overlaps. - table.insert(new_snippets, snippet) - last_end_line = snippet.range[2] - else - local snippet_lines = vim.split(snippet.content, "\n") - -- Trim the overlapping part - local new_start_line = nil - for i = snippet.range[1], math.min(snippet.range[2], last_end_line) do - if - Utils.remove_indentation(original_lines[i]) - == Utils.remove_indentation(snippet_lines[i - snippet.range[1] + 1]) - then - new_start_line = i + 1 - else - break - end - end - if new_start_line ~= nil then - snippet.content = table.concat(vim.list_slice(snippet_lines, new_start_line - snippet.range[1] + 1), "\n") - snippet.range[1] = new_start_line - table.insert(new_snippets, snippet) - last_end_line = snippet.range[2] - else - Utils.error("Failed to ensure snippets no overlap", { once = true, title = "Avante" }) - end - end - end - new_snippets_map[filepath] = new_snippets - end - - return new_snippets_map -end - local function insert_conflict_contents(bufnr, snippets) -- sort snippets by start_line table.sort(snippets, function(a, b) return a.range[1] < b.range[1] end) @@ -900,7 +834,6 @@ end function Sidebar:apply(current_cursor) local response, response_start_line = self:get_content_between_separators() local all_snippets_map = extract_code_snippets_map(response) - all_snippets_map = ensure_snippets_no_overlap(all_snippets_map) local selected_snippets_map = {} if current_cursor then if self.result_container and self.result_container.winid then diff --git a/lua/avante/utils/init.lua b/lua/avante/utils/init.lua index 4a13f84..0df5c44 100644 --- a/lua/avante/utils/init.lua +++ b/lua/avante/utils/init.lua @@ -566,17 +566,69 @@ function M.is_type(type_name, v) end -- luacheck: pop ----@param code string +---@param text string ---@return string -function M.get_indentation(code) - if not code then return "" end - return code:match("^%s*") or "" +function M.get_indentation(text) + if not text then return "" end + return text:match("^%s*") or "" end ---- remove indentation from code: spaces or tabs -function M.remove_indentation(code) - if not code then return code end - return code:gsub("%s*", "") +function M.trim_space(text) + if not text then return text end + return text:gsub("%s*", "") +end + +---@param original_lines string[] +---@param target_lines string[] +---@param compare_fn fun(line_a: string, line_b: string): boolean +---@return integer | nil start_line +---@return integer | nil end_line +function M.try_find_match(original_lines, target_lines, compare_fn) + local start_line, end_line + for i = 1, #original_lines - #target_lines + 1 do + local match = true + for j = 1, #target_lines do + if not compare_fn(original_lines[i + j - 1], target_lines[j]) then + match = false + break + end + end + if match then + start_line = i + end_line = i + #target_lines - 1 + break + end + end + return start_line, end_line +end + +---@param original_lines string[] +---@param target_lines string[] +---@return integer | nil start_line +---@return integer | nil end_line +function M.fuzzy_match(original_lines, target_lines) + local start_line, end_line + ---exact match + start_line, end_line = M.try_find_match( + original_lines, + target_lines, + function(line_a, line_b) return line_a == line_b end + ) + if start_line ~= nil and end_line ~= nil then return start_line, end_line end + ---fuzzy match + start_line, end_line = M.try_find_match( + original_lines, + target_lines, + function(line_a, line_b) return M.trim(line_a, { suffix = " \t" }) == M.trim(line_b, { suffix = " \t" }) end + ) + if start_line ~= nil and end_line ~= nil then return start_line, end_line end + ---trim_space match + start_line, end_line = M.try_find_match( + original_lines, + target_lines, + function(line_a, line_b) return M.trim_space(line_a) == M.trim_space(line_b) end + ) + return start_line, end_line end function M.relative_path(absolute) diff --git a/tests/utils/init_spec.lua b/tests/utils/init_spec.lua index 71caf4e..1df6c9b 100644 --- a/tests/utils/init_spec.lua +++ b/tests/utils/init_spec.lua @@ -65,16 +65,16 @@ describe("Utils", function() end) end) - describe("remove_indentation", function() + describe("trime_space", function() it("should remove indentation correctly", function() - assert.equals("test", Utils.remove_indentation(" test")) - assert.equals("test", Utils.remove_indentation("\ttest")) - assert.equals("test", Utils.remove_indentation("test")) + assert.equals("test", Utils.trim_space(" test")) + assert.equals("test", Utils.trim_space("\ttest")) + assert.equals("test", Utils.trim_space("test")) end) it("should handle empty or nil input", function() - assert.equals("", Utils.remove_indentation("")) - assert.equals(nil, Utils.remove_indentation(nil)) + assert.equals("", Utils.trim_space("")) + assert.equals(nil, Utils.trim_space(nil)) end) end) @@ -208,4 +208,27 @@ describe("Utils", function() assert.equals(5, result) end) end) + + describe("fuzzy_match", function() + it("should match exact lines", function() + local lines = { "test", "test2", "test3", "test4" } + local start_line, end_line = Utils.fuzzy_match(lines, { "test2", "test3" }) + assert.equals(2, start_line) + assert.equals(3, end_line) + end) + + it("should match lines with suffix", function() + local lines = { "test", "test2", "test3", "test4" } + local start_line, end_line = Utils.fuzzy_match(lines, { "test2 \t", "test3" }) + assert.equals(2, start_line) + assert.equals(3, end_line) + end) + + it("should match lines with space", function() + local lines = { "test", "test2", "test3", "test4" } + local start_line, end_line = Utils.fuzzy_match(lines, { "test2 ", " test3" }) + assert.equals(2, start_line) + assert.equals(3, end_line) + end) + end) end)