fix: diff format (#2510)
This commit is contained in:
@@ -3,6 +3,7 @@ local Helpers = require("avante.llm_tools.helpers")
|
|||||||
local Utils = require("avante.utils")
|
local Utils = require("avante.utils")
|
||||||
local Highlights = require("avante.highlights")
|
local Highlights = require("avante.highlights")
|
||||||
local Config = require("avante.config")
|
local Config = require("avante.config")
|
||||||
|
local diff2search_replace = require("avante.utils.diff2search_replace")
|
||||||
|
|
||||||
local PRIORITY = (vim.hl or vim.highlight).priorities.user
|
local PRIORITY = (vim.hl or vim.highlight).priorities.user
|
||||||
local NAMESPACE = vim.api.nvim_create_namespace("avante-diff")
|
local NAMESPACE = vim.api.nvim_create_namespace("avante-diff")
|
||||||
@@ -105,14 +106,12 @@ M.returns = {
|
|||||||
---@param diff string
|
---@param diff string
|
||||||
---@return string
|
---@return string
|
||||||
local function fix_diff(diff)
|
local function fix_diff(diff)
|
||||||
|
diff = diff2search_replace(diff)
|
||||||
-- Normalize block headers to the expected ones (fix for some LLMs output)
|
-- Normalize block headers to the expected ones (fix for some LLMs output)
|
||||||
diff = diff:gsub("<<<<<<<%s*SEARCH", "------- SEARCH")
|
diff = diff:gsub("<<<<<<<%s*SEARCH", "------- SEARCH")
|
||||||
diff = diff:gsub(">>>>>>>%s*REPLACE", "+++++++ REPLACE")
|
diff = diff:gsub(">>>>>>>%s*REPLACE", "+++++++ REPLACE")
|
||||||
diff = diff:gsub("-------%s*REPLACE", "+++++++ REPLACE")
|
diff = diff:gsub("-------%s*REPLACE", "+++++++ REPLACE")
|
||||||
|
|
||||||
local has_search_line = diff:match("^%s*-------* SEARCH") ~= nil
|
|
||||||
if has_search_line then return diff end
|
|
||||||
|
|
||||||
local fixed_diff_lines = {}
|
local fixed_diff_lines = {}
|
||||||
local lines = vim.split(diff, "\n")
|
local lines = vim.split(diff, "\n")
|
||||||
local first_line = lines[1]
|
local first_line = lines[1]
|
||||||
@@ -122,38 +121,45 @@ local function fix_diff(diff)
|
|||||||
fixed_diff_lines = vim.list_extend(fixed_diff_lines, lines, 2)
|
fixed_diff_lines = vim.list_extend(fixed_diff_lines, lines, 2)
|
||||||
else
|
else
|
||||||
table.insert(fixed_diff_lines, "------- SEARCH")
|
table.insert(fixed_diff_lines, "------- SEARCH")
|
||||||
fixed_diff_lines = vim.list_extend(fixed_diff_lines, lines, 1)
|
if first_line:match("------- SEARCH") then
|
||||||
|
fixed_diff_lines = vim.list_extend(fixed_diff_lines, lines, 2)
|
||||||
|
else
|
||||||
|
fixed_diff_lines = vim.list_extend(fixed_diff_lines, lines, 1)
|
||||||
|
end
|
||||||
end
|
end
|
||||||
local the_final_diff_lines = {}
|
local the_final_diff_lines = {}
|
||||||
local has_split_line = false
|
local has_split_line = false
|
||||||
|
local replace_block_closed = false
|
||||||
for _, line in ipairs(fixed_diff_lines) do
|
for _, line in ipairs(fixed_diff_lines) do
|
||||||
if line:match("^-------%s*SEARCH") then has_split_line = false end
|
if line:match("^-------%s*SEARCH") then has_split_line = false end
|
||||||
if line:match("^=======") then has_split_line = true end
|
if line:match("^=======") then has_split_line = true end
|
||||||
if line:match("^+++++++%s*REPLACE") then
|
if line:match("^+++++++%s*REPLACE") then
|
||||||
if not has_split_line then
|
if not has_split_line then
|
||||||
table.insert(the_final_diff_lines, "=======")
|
table.insert(the_final_diff_lines, "=======")
|
||||||
|
has_split_line = true
|
||||||
goto continue
|
goto continue
|
||||||
|
else
|
||||||
|
replace_block_closed = true
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
table.insert(the_final_diff_lines, line)
|
table.insert(the_final_diff_lines, line)
|
||||||
::continue::
|
::continue::
|
||||||
end
|
end
|
||||||
|
if not replace_block_closed then table.insert(the_final_diff_lines, "+++++++ REPLACE") end
|
||||||
return table.concat(the_final_diff_lines, "\n")
|
return table.concat(the_final_diff_lines, "\n")
|
||||||
end
|
end
|
||||||
|
|
||||||
--- IMPORTANT: Using "the_diff" instead of "diff" is to avoid LLM streaming generating function parameters in alphabetical order, which would result in generating "path" after "diff", making it impossible to achieve a streaming diff view.
|
--- IMPORTANT: Using "the_diff" instead of "diff" is to avoid LLM streaming generating function parameters in alphabetical order, which would result in generating "path" after "diff", making it impossible to achieve a streaming diff view.
|
||||||
---@type AvanteLLMToolFunc<{ path: string, diff: string, the_diff?: string }>
|
---@type AvanteLLMToolFunc<{ path: string, the_diff?: string }>
|
||||||
function M.func(input, opts)
|
function M.func(input, opts)
|
||||||
local on_log = opts.on_log
|
local on_log = opts.on_log
|
||||||
local on_complete = opts.on_complete
|
local on_complete = opts.on_complete
|
||||||
local session_ctx = opts.session_ctx
|
local session_ctx = opts.session_ctx
|
||||||
if not on_complete then return false, "on_complete not provided" end
|
if not on_complete then return false, "on_complete not provided" end
|
||||||
|
|
||||||
if input.the_diff ~= nil then
|
if not input.path or not input.the_diff then
|
||||||
input.diff = input.the_diff
|
return false, "path and the_diff are required " .. vim.inspect(input)
|
||||||
input.the_diff = nil
|
|
||||||
end
|
end
|
||||||
if not input.path or not input.diff then return false, "path and diff are required " .. vim.inspect(input) end
|
|
||||||
if on_log then on_log("path: " .. input.path) end
|
if on_log then on_log("path: " .. input.path) end
|
||||||
local abs_path = Helpers.get_abs_path(input.path)
|
local abs_path = Helpers.get_abs_path(input.path)
|
||||||
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||||
@@ -169,7 +175,7 @@ function M.func(input, opts)
|
|||||||
return false, "Diff hasn't changed in the last 2 seconds"
|
return false, "Diff hasn't changed in the last 2 seconds"
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
local streaming_diff_lines_count = Utils.count_lines(input.diff)
|
local streaming_diff_lines_count = Utils.count_lines(input.the_diff)
|
||||||
session_ctx.streaming_diff_lines_count_history = session_ctx.streaming_diff_lines_count_history or {}
|
session_ctx.streaming_diff_lines_count_history = session_ctx.streaming_diff_lines_count_history or {}
|
||||||
local prev_streaming_diff_lines_count = session_ctx.streaming_diff_lines_count_history[opts.tool_use_id]
|
local prev_streaming_diff_lines_count = session_ctx.streaming_diff_lines_count_history[opts.tool_use_id]
|
||||||
if streaming_diff_lines_count == prev_streaming_diff_lines_count then
|
if streaming_diff_lines_count == prev_streaming_diff_lines_count then
|
||||||
@@ -178,9 +184,9 @@ function M.func(input, opts)
|
|||||||
session_ctx.streaming_diff_lines_count_history[opts.tool_use_id] = streaming_diff_lines_count
|
session_ctx.streaming_diff_lines_count_history[opts.tool_use_id] = streaming_diff_lines_count
|
||||||
end
|
end
|
||||||
|
|
||||||
local diff = fix_diff(input.diff)
|
local diff = fix_diff(input.the_diff)
|
||||||
|
|
||||||
if on_log and diff ~= input.diff then on_log("diff fixed") end
|
if on_log and diff ~= input.the_diff then on_log("diff fixed") end
|
||||||
|
|
||||||
local diff_lines = vim.split(diff, "\n")
|
local diff_lines = vim.split(diff, "\n")
|
||||||
|
|
||||||
|
|||||||
@@ -58,19 +58,15 @@ M.returns = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
--- IMPORTANT: Using "the_content" instead of "content" is to avoid LLM streaming generating function parameters in alphabetical order, which would result in generating "path" after "content", making it impossible to achieve a stream diff view.
|
--- IMPORTANT: Using "the_content" instead of "content" is to avoid LLM streaming generating function parameters in alphabetical order, which would result in generating "path" after "content", making it impossible to achieve a stream diff view.
|
||||||
---@type AvanteLLMToolFunc<{ path: string, content: string, the_content?: string }>
|
---@type AvanteLLMToolFunc<{ path: string, the_content?: string }>
|
||||||
function M.func(input, opts)
|
function M.func(input, opts)
|
||||||
if input.the_content ~= nil then
|
|
||||||
input.content = input.the_content
|
|
||||||
input.the_content = nil
|
|
||||||
end
|
|
||||||
local abs_path = Helpers.get_abs_path(input.path)
|
local abs_path = Helpers.get_abs_path(input.path)
|
||||||
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||||
if input.content == nil then return false, "content not provided" end
|
if input.the_content == nil then return false, "the_content not provided" end
|
||||||
if type(input.content) ~= "string" then input.content = vim.json.encode(input.content) end
|
if type(input.the_content) ~= "string" then input.the_content = vim.json.encode(input.the_content) end
|
||||||
if Utils.count_lines(input.content) == 1 then
|
if Utils.count_lines(input.the_content) == 1 then
|
||||||
Utils.debug("Trimming escapes from content")
|
Utils.debug("Trimming escapes from content")
|
||||||
input.content = Utils.trim_escapes(input.content)
|
input.the_content = Utils.trim_escapes(input.the_content)
|
||||||
end
|
end
|
||||||
local old_lines = Utils.read_file_from_buf_or_disk(abs_path)
|
local old_lines = Utils.read_file_from_buf_or_disk(abs_path)
|
||||||
local old_content = table.concat(old_lines or {}, "\n")
|
local old_content = table.concat(old_lines or {}, "\n")
|
||||||
@@ -78,7 +74,7 @@ function M.func(input, opts)
|
|||||||
local new_input = {
|
local new_input = {
|
||||||
path = input.path,
|
path = input.path,
|
||||||
old_str = old_content,
|
old_str = old_content,
|
||||||
new_str = input.content,
|
new_str = input.the_content,
|
||||||
}
|
}
|
||||||
return str_replace.func(new_input, opts)
|
return str_replace.func(new_input, opts)
|
||||||
end
|
end
|
||||||
|
|||||||
56
lua/avante/utils/diff2search_replace.lua
Normal file
56
lua/avante/utils/diff2search_replace.lua
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
local function trim(s) return s:gsub("^%s+", ""):gsub("%s+$", "") end
|
||||||
|
|
||||||
|
local function split_lines(text)
|
||||||
|
local lines = {}
|
||||||
|
for line in text:gmatch("[^\r\n]+") do
|
||||||
|
table.insert(lines, line)
|
||||||
|
end
|
||||||
|
return lines
|
||||||
|
end
|
||||||
|
|
||||||
|
local function diff2search_replace(diff_text)
|
||||||
|
if not diff_text:match("^@@") then return diff_text end
|
||||||
|
|
||||||
|
local blocks = {}
|
||||||
|
local pos = 1
|
||||||
|
local len = #diff_text
|
||||||
|
|
||||||
|
-- 解析每一个 @@ 块
|
||||||
|
while pos <= len do
|
||||||
|
-- 找到下一个 @@ 起始
|
||||||
|
local start_at = diff_text:find("@@%s*%-%d+,%d+%s%+", pos)
|
||||||
|
if not start_at then break end
|
||||||
|
|
||||||
|
-- 找到该块结束位置(下一个 @@ 或文件末尾)
|
||||||
|
local next_at = diff_text:find("@@%s*%-%d+,%d+%s%+", start_at + 1)
|
||||||
|
local block_end = next_at and (next_at - 1) or len
|
||||||
|
local block = diff_text:sub(start_at, block_end)
|
||||||
|
|
||||||
|
-- 去掉首行的 @@ ... @@ 行
|
||||||
|
local first_nl = block:find("\n")
|
||||||
|
if first_nl then block = block:sub(first_nl + 1) end
|
||||||
|
|
||||||
|
local search_lines, replace_lines = {}, {}
|
||||||
|
for _, line in ipairs(split_lines(block)) do
|
||||||
|
local first = line:sub(1, 1)
|
||||||
|
if first == "-" then
|
||||||
|
table.insert(search_lines, line:sub(2))
|
||||||
|
elseif first == "+" then
|
||||||
|
table.insert(replace_lines, line:sub(2))
|
||||||
|
elseif first == " " then
|
||||||
|
table.insert(search_lines, line:sub(2))
|
||||||
|
table.insert(replace_lines, line:sub(2))
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
local search = table.concat(search_lines, "\n")
|
||||||
|
local replace = table.concat(replace_lines, "\n")
|
||||||
|
|
||||||
|
table.insert(blocks, "------- SEARCH\n" .. trim(search) .. "\n=======\n" .. trim(replace) .. "\n+++++++ REPLACE")
|
||||||
|
pos = block_end + 1
|
||||||
|
end
|
||||||
|
|
||||||
|
return table.concat(blocks, "\n\n")
|
||||||
|
end
|
||||||
|
|
||||||
|
return diff2search_replace
|
||||||
Reference in New Issue
Block a user