fix: diff format (#2510)
This commit is contained in:
@@ -3,6 +3,7 @@ local Helpers = require("avante.llm_tools.helpers")
|
||||
local Utils = require("avante.utils")
|
||||
local Highlights = require("avante.highlights")
|
||||
local Config = require("avante.config")
|
||||
local diff2search_replace = require("avante.utils.diff2search_replace")
|
||||
|
||||
local PRIORITY = (vim.hl or vim.highlight).priorities.user
|
||||
local NAMESPACE = vim.api.nvim_create_namespace("avante-diff")
|
||||
@@ -105,14 +106,12 @@ M.returns = {
|
||||
---@param diff string
|
||||
---@return string
|
||||
local function fix_diff(diff)
|
||||
diff = diff2search_replace(diff)
|
||||
-- Normalize block headers to the expected ones (fix for some LLMs output)
|
||||
diff = diff:gsub("<<<<<<<%s*SEARCH", "------- SEARCH")
|
||||
diff = diff:gsub(">>>>>>>%s*REPLACE", "+++++++ REPLACE")
|
||||
diff = diff:gsub("-------%s*REPLACE", "+++++++ REPLACE")
|
||||
|
||||
local has_search_line = diff:match("^%s*-------* SEARCH") ~= nil
|
||||
if has_search_line then return diff end
|
||||
|
||||
local fixed_diff_lines = {}
|
||||
local lines = vim.split(diff, "\n")
|
||||
local first_line = lines[1]
|
||||
@@ -122,38 +121,45 @@ local function fix_diff(diff)
|
||||
fixed_diff_lines = vim.list_extend(fixed_diff_lines, lines, 2)
|
||||
else
|
||||
table.insert(fixed_diff_lines, "------- SEARCH")
|
||||
fixed_diff_lines = vim.list_extend(fixed_diff_lines, lines, 1)
|
||||
if first_line:match("------- SEARCH") then
|
||||
fixed_diff_lines = vim.list_extend(fixed_diff_lines, lines, 2)
|
||||
else
|
||||
fixed_diff_lines = vim.list_extend(fixed_diff_lines, lines, 1)
|
||||
end
|
||||
end
|
||||
local the_final_diff_lines = {}
|
||||
local has_split_line = false
|
||||
local replace_block_closed = false
|
||||
for _, line in ipairs(fixed_diff_lines) do
|
||||
if line:match("^-------%s*SEARCH") then has_split_line = false end
|
||||
if line:match("^=======") then has_split_line = true end
|
||||
if line:match("^+++++++%s*REPLACE") then
|
||||
if not has_split_line then
|
||||
table.insert(the_final_diff_lines, "=======")
|
||||
has_split_line = true
|
||||
goto continue
|
||||
else
|
||||
replace_block_closed = true
|
||||
end
|
||||
end
|
||||
table.insert(the_final_diff_lines, line)
|
||||
::continue::
|
||||
end
|
||||
if not replace_block_closed then table.insert(the_final_diff_lines, "+++++++ REPLACE") end
|
||||
return table.concat(the_final_diff_lines, "\n")
|
||||
end
|
||||
|
||||
--- IMPORTANT: Using "the_diff" instead of "diff" is to avoid LLM streaming generating function parameters in alphabetical order, which would result in generating "path" after "diff", making it impossible to achieve a streaming diff view.
|
||||
---@type AvanteLLMToolFunc<{ path: string, diff: string, the_diff?: string }>
|
||||
---@type AvanteLLMToolFunc<{ path: string, the_diff?: string }>
|
||||
function M.func(input, opts)
|
||||
local on_log = opts.on_log
|
||||
local on_complete = opts.on_complete
|
||||
local session_ctx = opts.session_ctx
|
||||
if not on_complete then return false, "on_complete not provided" end
|
||||
|
||||
if input.the_diff ~= nil then
|
||||
input.diff = input.the_diff
|
||||
input.the_diff = nil
|
||||
if not input.path or not input.the_diff then
|
||||
return false, "path and the_diff are required " .. vim.inspect(input)
|
||||
end
|
||||
if not input.path or not input.diff then return false, "path and diff are required " .. vim.inspect(input) end
|
||||
if on_log then on_log("path: " .. input.path) end
|
||||
local abs_path = Helpers.get_abs_path(input.path)
|
||||
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||
@@ -169,7 +175,7 @@ function M.func(input, opts)
|
||||
return false, "Diff hasn't changed in the last 2 seconds"
|
||||
end
|
||||
end
|
||||
local streaming_diff_lines_count = Utils.count_lines(input.diff)
|
||||
local streaming_diff_lines_count = Utils.count_lines(input.the_diff)
|
||||
session_ctx.streaming_diff_lines_count_history = session_ctx.streaming_diff_lines_count_history or {}
|
||||
local prev_streaming_diff_lines_count = session_ctx.streaming_diff_lines_count_history[opts.tool_use_id]
|
||||
if streaming_diff_lines_count == prev_streaming_diff_lines_count then
|
||||
@@ -178,9 +184,9 @@ function M.func(input, opts)
|
||||
session_ctx.streaming_diff_lines_count_history[opts.tool_use_id] = streaming_diff_lines_count
|
||||
end
|
||||
|
||||
local diff = fix_diff(input.diff)
|
||||
local diff = fix_diff(input.the_diff)
|
||||
|
||||
if on_log and diff ~= input.diff then on_log("diff fixed") end
|
||||
if on_log and diff ~= input.the_diff then on_log("diff fixed") end
|
||||
|
||||
local diff_lines = vim.split(diff, "\n")
|
||||
|
||||
|
||||
@@ -58,19 +58,15 @@ M.returns = {
|
||||
}
|
||||
|
||||
--- IMPORTANT: Using "the_content" instead of "content" is to avoid LLM streaming generating function parameters in alphabetical order, which would result in generating "path" after "content", making it impossible to achieve a stream diff view.
|
||||
---@type AvanteLLMToolFunc<{ path: string, content: string, the_content?: string }>
|
||||
---@type AvanteLLMToolFunc<{ path: string, the_content?: string }>
|
||||
function M.func(input, opts)
|
||||
if input.the_content ~= nil then
|
||||
input.content = input.the_content
|
||||
input.the_content = nil
|
||||
end
|
||||
local abs_path = Helpers.get_abs_path(input.path)
|
||||
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||
if input.content == nil then return false, "content not provided" end
|
||||
if type(input.content) ~= "string" then input.content = vim.json.encode(input.content) end
|
||||
if Utils.count_lines(input.content) == 1 then
|
||||
if input.the_content == nil then return false, "the_content not provided" end
|
||||
if type(input.the_content) ~= "string" then input.the_content = vim.json.encode(input.the_content) end
|
||||
if Utils.count_lines(input.the_content) == 1 then
|
||||
Utils.debug("Trimming escapes from content")
|
||||
input.content = Utils.trim_escapes(input.content)
|
||||
input.the_content = Utils.trim_escapes(input.the_content)
|
||||
end
|
||||
local old_lines = Utils.read_file_from_buf_or_disk(abs_path)
|
||||
local old_content = table.concat(old_lines or {}, "\n")
|
||||
@@ -78,7 +74,7 @@ function M.func(input, opts)
|
||||
local new_input = {
|
||||
path = input.path,
|
||||
old_str = old_content,
|
||||
new_str = input.content,
|
||||
new_str = input.the_content,
|
||||
}
|
||||
return str_replace.func(new_input, opts)
|
||||
end
|
||||
|
||||
56
lua/avante/utils/diff2search_replace.lua
Normal file
56
lua/avante/utils/diff2search_replace.lua
Normal file
@@ -0,0 +1,56 @@
|
||||
local function trim(s) return s:gsub("^%s+", ""):gsub("%s+$", "") end
|
||||
|
||||
local function split_lines(text)
|
||||
local lines = {}
|
||||
for line in text:gmatch("[^\r\n]+") do
|
||||
table.insert(lines, line)
|
||||
end
|
||||
return lines
|
||||
end
|
||||
|
||||
local function diff2search_replace(diff_text)
|
||||
if not diff_text:match("^@@") then return diff_text end
|
||||
|
||||
local blocks = {}
|
||||
local pos = 1
|
||||
local len = #diff_text
|
||||
|
||||
-- 解析每一个 @@ 块
|
||||
while pos <= len do
|
||||
-- 找到下一个 @@ 起始
|
||||
local start_at = diff_text:find("@@%s*%-%d+,%d+%s%+", pos)
|
||||
if not start_at then break end
|
||||
|
||||
-- 找到该块结束位置(下一个 @@ 或文件末尾)
|
||||
local next_at = diff_text:find("@@%s*%-%d+,%d+%s%+", start_at + 1)
|
||||
local block_end = next_at and (next_at - 1) or len
|
||||
local block = diff_text:sub(start_at, block_end)
|
||||
|
||||
-- 去掉首行的 @@ ... @@ 行
|
||||
local first_nl = block:find("\n")
|
||||
if first_nl then block = block:sub(first_nl + 1) end
|
||||
|
||||
local search_lines, replace_lines = {}, {}
|
||||
for _, line in ipairs(split_lines(block)) do
|
||||
local first = line:sub(1, 1)
|
||||
if first == "-" then
|
||||
table.insert(search_lines, line:sub(2))
|
||||
elseif first == "+" then
|
||||
table.insert(replace_lines, line:sub(2))
|
||||
elseif first == " " then
|
||||
table.insert(search_lines, line:sub(2))
|
||||
table.insert(replace_lines, line:sub(2))
|
||||
end
|
||||
end
|
||||
|
||||
local search = table.concat(search_lines, "\n")
|
||||
local replace = table.concat(replace_lines, "\n")
|
||||
|
||||
table.insert(blocks, "------- SEARCH\n" .. trim(search) .. "\n=======\n" .. trim(replace) .. "\n+++++++ REPLACE")
|
||||
pos = block_end + 1
|
||||
end
|
||||
|
||||
return table.concat(blocks, "\n\n")
|
||||
end
|
||||
|
||||
return diff2search_replace
|
||||
Reference in New Issue
Block a user