diff --git a/lua/avante/llm_tools/replace_in_file.lua b/lua/avante/llm_tools/replace_in_file.lua index b1ad413..fce1128 100644 --- a/lua/avante/llm_tools/replace_in_file.lua +++ b/lua/avante/llm_tools/replace_in_file.lua @@ -3,6 +3,7 @@ local Helpers = require("avante.llm_tools.helpers") local Utils = require("avante.utils") local Highlights = require("avante.highlights") local Config = require("avante.config") +local diff2search_replace = require("avante.utils.diff2search_replace") local PRIORITY = (vim.hl or vim.highlight).priorities.user local NAMESPACE = vim.api.nvim_create_namespace("avante-diff") @@ -105,14 +106,12 @@ M.returns = { ---@param diff string ---@return string local function fix_diff(diff) + diff = diff2search_replace(diff) -- Normalize block headers to the expected ones (fix for some LLMs output) diff = diff:gsub("<<<<<<<%s*SEARCH", "------- SEARCH") diff = diff:gsub(">>>>>>>%s*REPLACE", "+++++++ REPLACE") diff = diff:gsub("-------%s*REPLACE", "+++++++ REPLACE") - local has_search_line = diff:match("^%s*-------* SEARCH") ~= nil - if has_search_line then return diff end - local fixed_diff_lines = {} local lines = vim.split(diff, "\n") local first_line = lines[1] @@ -122,38 +121,45 @@ local function fix_diff(diff) fixed_diff_lines = vim.list_extend(fixed_diff_lines, lines, 2) else table.insert(fixed_diff_lines, "------- SEARCH") - fixed_diff_lines = vim.list_extend(fixed_diff_lines, lines, 1) + if first_line:match("------- SEARCH") then + fixed_diff_lines = vim.list_extend(fixed_diff_lines, lines, 2) + else + fixed_diff_lines = vim.list_extend(fixed_diff_lines, lines, 1) + end end local the_final_diff_lines = {} local has_split_line = false + local replace_block_closed = false for _, line in ipairs(fixed_diff_lines) do if line:match("^-------%s*SEARCH") then has_split_line = false end if line:match("^=======") then has_split_line = true end if line:match("^+++++++%s*REPLACE") then if not has_split_line then table.insert(the_final_diff_lines, "=======") + has_split_line = true goto continue + else + replace_block_closed = true end end table.insert(the_final_diff_lines, line) ::continue:: end + if not replace_block_closed then table.insert(the_final_diff_lines, "+++++++ REPLACE") end return table.concat(the_final_diff_lines, "\n") end --- IMPORTANT: Using "the_diff" instead of "diff" is to avoid LLM streaming generating function parameters in alphabetical order, which would result in generating "path" after "diff", making it impossible to achieve a streaming diff view. ----@type AvanteLLMToolFunc<{ path: string, diff: string, the_diff?: string }> +---@type AvanteLLMToolFunc<{ path: string, the_diff?: string }> function M.func(input, opts) local on_log = opts.on_log local on_complete = opts.on_complete local session_ctx = opts.session_ctx if not on_complete then return false, "on_complete not provided" end - if input.the_diff ~= nil then - input.diff = input.the_diff - input.the_diff = nil + if not input.path or not input.the_diff then + return false, "path and the_diff are required " .. vim.inspect(input) end - if not input.path or not input.diff then return false, "path and diff are required " .. vim.inspect(input) end if on_log then on_log("path: " .. input.path) end local abs_path = Helpers.get_abs_path(input.path) if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end @@ -169,7 +175,7 @@ function M.func(input, opts) return false, "Diff hasn't changed in the last 2 seconds" end end - local streaming_diff_lines_count = Utils.count_lines(input.diff) + local streaming_diff_lines_count = Utils.count_lines(input.the_diff) session_ctx.streaming_diff_lines_count_history = session_ctx.streaming_diff_lines_count_history or {} local prev_streaming_diff_lines_count = session_ctx.streaming_diff_lines_count_history[opts.tool_use_id] if streaming_diff_lines_count == prev_streaming_diff_lines_count then @@ -178,9 +184,9 @@ function M.func(input, opts) session_ctx.streaming_diff_lines_count_history[opts.tool_use_id] = streaming_diff_lines_count end - local diff = fix_diff(input.diff) + local diff = fix_diff(input.the_diff) - if on_log and diff ~= input.diff then on_log("diff fixed") end + if on_log and diff ~= input.the_diff then on_log("diff fixed") end local diff_lines = vim.split(diff, "\n") diff --git a/lua/avante/llm_tools/write_to_file.lua b/lua/avante/llm_tools/write_to_file.lua index ccbef83..62e57e4 100644 --- a/lua/avante/llm_tools/write_to_file.lua +++ b/lua/avante/llm_tools/write_to_file.lua @@ -58,19 +58,15 @@ M.returns = { } --- IMPORTANT: Using "the_content" instead of "content" is to avoid LLM streaming generating function parameters in alphabetical order, which would result in generating "path" after "content", making it impossible to achieve a stream diff view. ----@type AvanteLLMToolFunc<{ path: string, content: string, the_content?: string }> +---@type AvanteLLMToolFunc<{ path: string, the_content?: string }> function M.func(input, opts) - if input.the_content ~= nil then - input.content = input.the_content - input.the_content = nil - end local abs_path = Helpers.get_abs_path(input.path) if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end - if input.content == nil then return false, "content not provided" end - if type(input.content) ~= "string" then input.content = vim.json.encode(input.content) end - if Utils.count_lines(input.content) == 1 then + if input.the_content == nil then return false, "the_content not provided" end + if type(input.the_content) ~= "string" then input.the_content = vim.json.encode(input.the_content) end + if Utils.count_lines(input.the_content) == 1 then Utils.debug("Trimming escapes from content") - input.content = Utils.trim_escapes(input.content) + input.the_content = Utils.trim_escapes(input.the_content) end local old_lines = Utils.read_file_from_buf_or_disk(abs_path) local old_content = table.concat(old_lines or {}, "\n") @@ -78,7 +74,7 @@ function M.func(input, opts) local new_input = { path = input.path, old_str = old_content, - new_str = input.content, + new_str = input.the_content, } return str_replace.func(new_input, opts) end diff --git a/lua/avante/utils/diff2search_replace.lua b/lua/avante/utils/diff2search_replace.lua new file mode 100644 index 0000000..4bd2be5 --- /dev/null +++ b/lua/avante/utils/diff2search_replace.lua @@ -0,0 +1,56 @@ +local function trim(s) return s:gsub("^%s+", ""):gsub("%s+$", "") end + +local function split_lines(text) + local lines = {} + for line in text:gmatch("[^\r\n]+") do + table.insert(lines, line) + end + return lines +end + +local function diff2search_replace(diff_text) + if not diff_text:match("^@@") then return diff_text end + + local blocks = {} + local pos = 1 + local len = #diff_text + + -- 解析每一个 @@ 块 + while pos <= len do + -- 找到下一个 @@ 起始 + local start_at = diff_text:find("@@%s*%-%d+,%d+%s%+", pos) + if not start_at then break end + + -- 找到该块结束位置(下一个 @@ 或文件末尾) + local next_at = diff_text:find("@@%s*%-%d+,%d+%s%+", start_at + 1) + local block_end = next_at and (next_at - 1) or len + local block = diff_text:sub(start_at, block_end) + + -- 去掉首行的 @@ ... @@ 行 + local first_nl = block:find("\n") + if first_nl then block = block:sub(first_nl + 1) end + + local search_lines, replace_lines = {}, {} + for _, line in ipairs(split_lines(block)) do + local first = line:sub(1, 1) + if first == "-" then + table.insert(search_lines, line:sub(2)) + elseif first == "+" then + table.insert(replace_lines, line:sub(2)) + elseif first == " " then + table.insert(search_lines, line:sub(2)) + table.insert(replace_lines, line:sub(2)) + end + end + + local search = table.concat(search_lines, "\n") + local replace = table.concat(replace_lines, "\n") + + table.insert(blocks, "------- SEARCH\n" .. trim(search) .. "\n=======\n" .. trim(replace) .. "\n+++++++ REPLACE") + pos = block_end + 1 + end + + return table.concat(blocks, "\n\n") +end + +return diff2search_replace