fix: fuzzy match (#2221)

This commit is contained in:
yetone
2025-06-13 16:46:55 +08:00
committed by GitHub
parent f766b42d85
commit fdf4716ec0
4 changed files with 99 additions and 105 deletions

View File

@@ -199,21 +199,7 @@ function M.func(opts, on_log, on_complete, session_ctx)
local function complete_rough_diff_block(rough_diff_block)
local old_lines = rough_diff_block.old_lines
local new_lines = rough_diff_block.new_lines
local start_line, end_line
for i = 1, #original_lines - #old_lines + 1 do
local match = true
for j = 1, #old_lines do
if Utils.remove_indentation(original_lines[i + j - 1]) ~= Utils.remove_indentation(old_lines[j]) then
match = false
break
end
end
if match then
start_line = i
end_line = i + #old_lines - 1
break
end
end
local start_line, end_line = Utils.fuzzy_match(original_lines, old_lines)
if start_line == nil or end_line == nil then
local old_string = table.concat(old_lines, "\n")
return "Failed to find the old string:\n" .. old_string

View File

@@ -368,32 +368,21 @@ local function transform_result_content(result_content, prev_filepath)
local prev_line = result_lines[i - 1]
if prev_line and prev_line:match("^%s*```$") then search_end = i - 1 end
local start_line = 0
local end_line = 0
local match_filetype = nil
local filepath = current_filepath or prev_filepath or ""
if filepath == "" then goto continue end
local file_content = Utils.read_file_from_buf_or_disk(filepath) or {}
local file_content_lines = Utils.read_file_from_buf_or_disk(filepath) or {}
local file_type = Utils.get_filetype(filepath)
if start_line ~= 0 or end_line ~= 0 then break end
for j = 1, #file_content - (search_end - search_start) + 1 do
local match = true
for k = 0, search_end - search_start - 1 do
if
Utils.remove_indentation(file_content[j + k]) ~= Utils.remove_indentation(result_lines[search_start + k])
then
match = false
break
end
end
if match then
start_line = j
end_line = j + (search_end - search_start) - 1
match_filetype = file_type
break
end
local search_lines = vim.list_slice(result_lines, search_start, search_end - 1)
local start_line, end_line = Utils.fuzzy_match(file_content_lines, search_lines)
if start_line ~= nil and end_line ~= nil then
match_filetype = file_type
else
start_line = 0
end_line = 0
end
-- when the filetype isn't detected, fallback to matching based on filepath.
@@ -634,61 +623,6 @@ local function extract_code_snippets_map(response_content)
return snippets_map
end
--- Builds a new snippets map in which, per file, each snippet starts after the previous one ends.
--- Snippets are sorted in place by start line. A snippet that begins inside the previously
--- accepted range is trimmed: its leading lines are dropped for as long as they equal the
--- corresponding file lines (compared through Utils.remove_indentation). If not even the first
--- overlapping line matches, the snippet is dropped and an error is reported.
---@param snippets_map table<string, AvanteCodeSnippet[]>
---@return table<string, AvanteCodeSnippet[]>
local function ensure_snippets_no_overlap(snippets_map)
  local result = {}
  for filepath, snippets in pairs(snippets_map) do
    table.sort(snippets, function(lhs, rhs) return lhs.range[1] < rhs.range[1] end)

    local file_exists = Utils.file.exists(filepath)
    -- Fall back to an empty list when the file is missing or cannot be read.
    local original_lines = file_exists and Utils.read_file_from_buf_or_disk(filepath) or {}
    -- A lone snippet for a file that doesn't exist has nothing to overlap with.
    local skip_overlap_check = not file_exists and #snippets <= 1

    local kept = {}
    local last_end_line = 0
    for _, snippet in ipairs(snippets) do
      local first = snippet.range[1]
      if first > last_end_line or skip_overlap_check then
        kept[#kept + 1] = snippet
        last_end_line = snippet.range[2]
      else
        local snippet_lines = vim.split(snippet.content, "\n")
        local overlap_end = math.min(snippet.range[2], last_end_line)
        -- Advance past every overlapping line that is identical to the file content.
        local cursor = first
        while
          cursor <= overlap_end
          and Utils.remove_indentation(original_lines[cursor])
            == Utils.remove_indentation(snippet_lines[cursor - first + 1])
        do
          cursor = cursor + 1
        end
        if cursor > first then
          -- `cursor` is now the first line that is not part of the matched overlap.
          snippet.content = table.concat(vim.list_slice(snippet_lines, cursor - first + 1), "\n")
          snippet.range[1] = cursor
          kept[#kept + 1] = snippet
          last_end_line = snippet.range[2]
        else
          Utils.error("Failed to ensure snippets no overlap", { once = true, title = "Avante" })
        end
      end
    end
    result[filepath] = kept
  end
  return result
end
local function insert_conflict_contents(bufnr, snippets)
-- sort snippets by start_line
table.sort(snippets, function(a, b) return a.range[1] < b.range[1] end)
@@ -900,7 +834,6 @@ end
function Sidebar:apply(current_cursor)
local response, response_start_line = self:get_content_between_separators()
local all_snippets_map = extract_code_snippets_map(response)
all_snippets_map = ensure_snippets_no_overlap(all_snippets_map)
local selected_snippets_map = {}
if current_cursor then
if self.result_container and self.result_container.winid then

View File

@@ -566,17 +566,69 @@ function M.is_type(type_name, v)
end
-- luacheck: pop
---@param code string
---@param text string
---@return string
function M.get_indentation(code)
if not code then return "" end
return code:match("^%s*") or ""
--- Returns the leading whitespace of `text`, or "" when `text` is nil/false.
function M.get_indentation(text)
  if text then
    local leading = text:match("^%s*")
    return leading or ""
  end
  return ""
end
--- remove indentation from code: spaces or tabs
function M.remove_indentation(code)
if not code then return code end
return code:gsub("%s*", "")
--- Removes every whitespace character from `text` (leading, inner and trailing alike).
--- A nil/false input is returned unchanged.
---@param text string | nil
---@return string | nil
function M.trim_space(text)
  if not text then return text end
  -- string.gsub returns (result, count); bind the result to a local so only the
  -- string is returned and the count doesn't leak into callers' argument lists.
  -- `%s+` avoids the pointless empty matches that `%s*` produces.
  local stripped = text:gsub("%s+", "")
  return stripped
end
--- Finds the first contiguous window of `original_lines` whose lines all match `target_lines`.
--- Lines are compared pairwise, in order, with `compare_fn`; a window matches when every
--- comparison is truthy. Windows are tried from the top of `original_lines` downwards.
---@param original_lines string[]
---@param target_lines string[]
---@param compare_fn fun(line_a: string, line_b: string): boolean
---@return integer | nil start_line 1-based index of the first matched line, nil when no window matches
---@return integer | nil end_line 1-based index of the last matched line, nil when no window matches
function M.try_find_match(original_lines, target_lines, compare_fn)
  local target_count = #target_lines
  -- Last position at which a full window still fits inside `original_lines`.
  local last_start = #original_lines - target_count + 1
  for first = 1, last_start do
    local offset = 1
    while offset <= target_count and compare_fn(original_lines[first + offset - 1], target_lines[offset]) do
      offset = offset + 1
    end
    -- Running past the end of the target means every line of the window matched.
    if offset > target_count then return first, first + target_count - 1 end
  end
  return nil, nil
end
--- Locates `target_lines` inside `original_lines`, trying progressively looser comparisons.
--- The comparators are tried in order and the first one that produces a match wins:
---   1. exact string equality;
---   2. equality after passing both lines through `M.trim` with `{ suffix = " \t" }`
---      (NOTE(review): `M.trim` is defined elsewhere; confirm what that suffix option strips);
---   3. equality after `M.trim_space` on both lines.
---@param original_lines string[]
---@param target_lines string[]
---@return integer | nil start_line
---@return integer | nil end_line
function M.fuzzy_match(original_lines, target_lines)
  local comparators = {
    -- exact match
    function(line_a, line_b) return line_a == line_b end,
    -- suffix-trimmed match
    function(line_a, line_b) return M.trim(line_a, { suffix = " \t" }) == M.trim(line_b, { suffix = " \t" }) end,
    -- trim_space match
    function(line_a, line_b) return M.trim_space(line_a) == M.trim_space(line_b) end,
  }

  local start_line, end_line
  for _, compare in ipairs(comparators) do
    start_line, end_line = M.try_find_match(original_lines, target_lines, compare)
    if start_line ~= nil and end_line ~= nil then break end
  end
  -- Either the first successful match, or whatever the last comparator produced.
  return start_line, end_line
end
function M.relative_path(absolute)

View File

@@ -65,16 +65,16 @@ describe("Utils", function()
end)
end)
describe("remove_indentation", function()
describe("trime_space", function()
it("should remove indentation correctly", function()
assert.equals("test", Utils.remove_indentation(" test"))
assert.equals("test", Utils.remove_indentation("\ttest"))
assert.equals("test", Utils.remove_indentation("test"))
assert.equals("test", Utils.trim_space(" test"))
assert.equals("test", Utils.trim_space("\ttest"))
assert.equals("test", Utils.trim_space("test"))
end)
it("should handle empty or nil input", function()
assert.equals("", Utils.remove_indentation(""))
assert.equals(nil, Utils.remove_indentation(nil))
assert.equals("", Utils.trim_space(""))
assert.equals(nil, Utils.trim_space(nil))
end)
end)
@@ -208,4 +208,27 @@ describe("Utils", function()
assert.equals(5, result)
end)
end)
describe("fuzzy_match", function()
  -- Shared haystack; none of the cases below mutate it.
  local lines = { "test", "test2", "test3", "test4" }

  -- Every case expects the target to be found at lines 2..3 of the haystack.
  local function assert_found_at_2_to_3(target)
    local start_line, end_line = Utils.fuzzy_match(lines, target)
    assert.equals(2, start_line)
    assert.equals(3, end_line)
  end

  it("should match exact lines", function() assert_found_at_2_to_3({ "test2", "test3" }) end)

  it("should match lines with suffix", function() assert_found_at_2_to_3({ "test2 \t", "test3" }) end)

  it("should match lines with space", function() assert_found_at_2_to_3({ "test2 ", " test3" }) end)
end)
end)