fix: diff format (#2510)

This commit is contained in:
yetone
2025-07-22 20:38:57 +08:00
committed by GitHub
parent f4f82a09d7
commit c65604837c
3 changed files with 80 additions and 22 deletions

View File

@@ -3,6 +3,7 @@ local Helpers = require("avante.llm_tools.helpers")
local Utils = require("avante.utils")
local Highlights = require("avante.highlights")
local Config = require("avante.config")
local diff2search_replace = require("avante.utils.diff2search_replace")
local PRIORITY = (vim.hl or vim.highlight).priorities.user
local NAMESPACE = vim.api.nvim_create_namespace("avante-diff")
@@ -105,14 +106,12 @@ M.returns = {
---@param diff string
---@return string
local function fix_diff(diff)
diff = diff2search_replace(diff)
-- Normalize block headers to the expected ones (fix for some LLMs output)
diff = diff:gsub("<<<<<<<%s*SEARCH", "------- SEARCH")
diff = diff:gsub(">>>>>>>%s*REPLACE", "+++++++ REPLACE")
diff = diff:gsub("-------%s*REPLACE", "+++++++ REPLACE")
local has_search_line = diff:match("^%s*-------* SEARCH") ~= nil
if has_search_line then return diff end
local fixed_diff_lines = {}
local lines = vim.split(diff, "\n")
local first_line = lines[1]
@@ -122,38 +121,45 @@ local function fix_diff(diff)
fixed_diff_lines = vim.list_extend(fixed_diff_lines, lines, 2)
else
table.insert(fixed_diff_lines, "------- SEARCH")
fixed_diff_lines = vim.list_extend(fixed_diff_lines, lines, 1)
if first_line:match("------- SEARCH") then
fixed_diff_lines = vim.list_extend(fixed_diff_lines, lines, 2)
else
fixed_diff_lines = vim.list_extend(fixed_diff_lines, lines, 1)
end
end
local the_final_diff_lines = {}
local has_split_line = false
local replace_block_closed = false
for _, line in ipairs(fixed_diff_lines) do
if line:match("^-------%s*SEARCH") then has_split_line = false end
if line:match("^=======") then has_split_line = true end
if line:match("^+++++++%s*REPLACE") then
if not has_split_line then
table.insert(the_final_diff_lines, "=======")
has_split_line = true
goto continue
else
replace_block_closed = true
end
end
table.insert(the_final_diff_lines, line)
::continue::
end
if not replace_block_closed then table.insert(the_final_diff_lines, "+++++++ REPLACE") end
return table.concat(the_final_diff_lines, "\n")
end
--- IMPORTANT: Using "the_diff" instead of "diff" is to avoid LLM streaming generating function parameters in alphabetical order, which would result in generating "path" after "diff", making it impossible to achieve a streaming diff view.
---@type AvanteLLMToolFunc<{ path: string, diff: string, the_diff?: string }>
---@type AvanteLLMToolFunc<{ path: string, the_diff?: string }>
function M.func(input, opts)
local on_log = opts.on_log
local on_complete = opts.on_complete
local session_ctx = opts.session_ctx
if not on_complete then return false, "on_complete not provided" end
if input.the_diff ~= nil then
input.diff = input.the_diff
input.the_diff = nil
if not input.path or not input.the_diff then
return false, "path and the_diff are required " .. vim.inspect(input)
end
if not input.path or not input.diff then return false, "path and diff are required " .. vim.inspect(input) end
if on_log then on_log("path: " .. input.path) end
local abs_path = Helpers.get_abs_path(input.path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
@@ -169,7 +175,7 @@ function M.func(input, opts)
return false, "Diff hasn't changed in the last 2 seconds"
end
end
local streaming_diff_lines_count = Utils.count_lines(input.diff)
local streaming_diff_lines_count = Utils.count_lines(input.the_diff)
session_ctx.streaming_diff_lines_count_history = session_ctx.streaming_diff_lines_count_history or {}
local prev_streaming_diff_lines_count = session_ctx.streaming_diff_lines_count_history[opts.tool_use_id]
if streaming_diff_lines_count == prev_streaming_diff_lines_count then
@@ -178,9 +184,9 @@ function M.func(input, opts)
session_ctx.streaming_diff_lines_count_history[opts.tool_use_id] = streaming_diff_lines_count
end
local diff = fix_diff(input.diff)
local diff = fix_diff(input.the_diff)
if on_log and diff ~= input.diff then on_log("diff fixed") end
if on_log and diff ~= input.the_diff then on_log("diff fixed") end
local diff_lines = vim.split(diff, "\n")

View File

@@ -58,19 +58,15 @@ M.returns = {
}
--- IMPORTANT: Using "the_content" instead of "content" is to avoid LLM streaming generating function parameters in alphabetical order, which would result in generating "path" after "content", making it impossible to achieve a stream diff view.
---@type AvanteLLMToolFunc<{ path: string, content: string, the_content?: string }>
---@type AvanteLLMToolFunc<{ path: string, the_content?: string }>
function M.func(input, opts)
if input.the_content ~= nil then
input.content = input.the_content
input.the_content = nil
end
local abs_path = Helpers.get_abs_path(input.path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if input.content == nil then return false, "content not provided" end
if type(input.content) ~= "string" then input.content = vim.json.encode(input.content) end
if Utils.count_lines(input.content) == 1 then
if input.the_content == nil then return false, "the_content not provided" end
if type(input.the_content) ~= "string" then input.the_content = vim.json.encode(input.the_content) end
if Utils.count_lines(input.the_content) == 1 then
Utils.debug("Trimming escapes from content")
input.content = Utils.trim_escapes(input.content)
input.the_content = Utils.trim_escapes(input.the_content)
end
local old_lines = Utils.read_file_from_buf_or_disk(abs_path)
local old_content = table.concat(old_lines or {}, "\n")
@@ -78,7 +74,7 @@ function M.func(input, opts)
local new_input = {
path = input.path,
old_str = old_content,
new_str = input.content,
new_str = input.the_content,
}
return str_replace.func(new_input, opts)
end

View File

@@ -0,0 +1,56 @@
local function trim(s) return s:gsub("^%s+", ""):gsub("%s+$", "") end
local function split_lines(text)
local lines = {}
for line in text:gmatch("[^\r\n]+") do
table.insert(lines, line)
end
return lines
end
local function diff2search_replace(diff_text)
if not diff_text:match("^@@") then return diff_text end
local blocks = {}
local pos = 1
local len = #diff_text
-- 解析每一个 @@ 块
while pos <= len do
-- 找到下一个 @@ 起始
local start_at = diff_text:find("@@%s*%-%d+,%d+%s%+", pos)
if not start_at then break end
-- 找到该块结束位置(下一个 @@ 或文件末尾)
local next_at = diff_text:find("@@%s*%-%d+,%d+%s%+", start_at + 1)
local block_end = next_at and (next_at - 1) or len
local block = diff_text:sub(start_at, block_end)
-- 去掉首行的 @@ ... @@ 行
local first_nl = block:find("\n")
if first_nl then block = block:sub(first_nl + 1) end
local search_lines, replace_lines = {}, {}
for _, line in ipairs(split_lines(block)) do
local first = line:sub(1, 1)
if first == "-" then
table.insert(search_lines, line:sub(2))
elseif first == "+" then
table.insert(replace_lines, line:sub(2))
elseif first == " " then
table.insert(search_lines, line:sub(2))
table.insert(replace_lines, line:sub(2))
end
end
local search = table.concat(search_lines, "\n")
local replace = table.concat(replace_lines, "\n")
table.insert(blocks, "------- SEARCH\n" .. trim(search) .. "\n=======\n" .. trim(replace) .. "\n+++++++ REPLACE")
pos = block_end + 1
end
return table.concat(blocks, "\n\n")
end
return diff2search_replace