diff --git a/lua/avante/llm_tools/replace_in_file.lua b/lua/avante/llm_tools/replace_in_file.lua index 4520ce6..b9a92c3 100644 --- a/lua/avante/llm_tools/replace_in_file.lua +++ b/lua/avante/llm_tools/replace_in_file.lua @@ -3,7 +3,6 @@ local Helpers = require("avante.llm_tools.helpers") local Utils = require("avante.utils") local Highlights = require("avante.highlights") local Config = require("avante.config") -local diff2search_replace = require("avante.utils.diff2search_replace") local PRIORITY = (vim.hl or vim.highlight).priorities.user local NAMESPACE = vim.api.nvim_create_namespace("avante-diff") @@ -102,55 +101,6 @@ M.returns = { }, } ---- Some models (e.g., gpt-4o) cannot correctly return diff content and often miss the SEARCH line, so this needs to be manually fixed in such cases. ----@param diff string ----@return string -local function fix_diff(diff) - diff = diff2search_replace(diff) - -- Normalize block headers to the expected ones (fix for some LLMs output) - diff = diff:gsub("<<<<<<<%s*SEARCH", "------- SEARCH") - diff = diff:gsub(">>>>>>>%s*REPLACE", "+++++++ REPLACE") - diff = diff:gsub("-------%s*REPLACE", "+++++++ REPLACE") - diff = diff:gsub("------- ", "------- SEARCH\n") - diff = diff:gsub("======= ", "======= \n") - - local fixed_diff_lines = {} - local lines = vim.split(diff, "\n") - local first_line = lines[1] - if first_line and first_line:match("^%s*```") then - table.insert(fixed_diff_lines, first_line) - table.insert(fixed_diff_lines, "------- SEARCH") - fixed_diff_lines = vim.list_extend(fixed_diff_lines, lines, 2) - else - table.insert(fixed_diff_lines, "------- SEARCH") - if first_line:match("------- SEARCH") then - fixed_diff_lines = vim.list_extend(fixed_diff_lines, lines, 2) - else - fixed_diff_lines = vim.list_extend(fixed_diff_lines, lines, 1) - end - end - local the_final_diff_lines = {} - local has_split_line = false - local replace_block_closed = false - for _, line in ipairs(fixed_diff_lines) do - if line:match("^-------%s*SEARCH") then has_split_line = false end - if line:match("^=======") then has_split_line = true end - if line:match("^+++++++%s*REPLACE") then - if not has_split_line then - table.insert(the_final_diff_lines, "=======") - has_split_line = true - goto continue - else - replace_block_closed = true - end - end - table.insert(the_final_diff_lines, line) - ::continue:: - end - if not replace_block_closed then table.insert(the_final_diff_lines, "+++++++ REPLACE") end - return table.concat(the_final_diff_lines, "\n") -end - --- IMPORTANT: Using "the_diff" instead of "diff" is to avoid LLM streaming generating function parameters in alphabetical order, which would result in generating "path" after "diff", making it impossible to achieve a streaming diff view. ---@type AvanteLLMToolFunc<{ path: string, the_diff?: string }> function M.func(input, opts) @@ -186,7 +136,7 @@ function M.func(input, opts) session_ctx.streaming_diff_lines_count_history[opts.tool_use_id] = streaming_diff_lines_count end - local diff = fix_diff(input.the_diff) + local diff = Utils.fix_diff(input.the_diff) if on_log and diff ~= input.the_diff then on_log("diff fixed") end diff --git a/lua/avante/utils/init.lua b/lua/avante/utils/init.lua index b06effd..55cf868 100644 --- a/lua/avante/utils/init.lua +++ b/lua/avante/utils/init.lua @@ -3,6 +3,7 @@ local fn = vim.fn local lsp = vim.lsp local LRUCache = require("avante.utils.lru_cache") +local diff2search_replace = require("avante.utils.diff2search_replace") ---@class avante.utils: LazyUtilCore ---@field tokens avante.utils.tokens @@ -1621,4 +1622,61 @@ function M.call_once(func) end end +--- Some models (e.g., gpt-4o) cannot correctly return diff content and often miss the SEARCH line, so this needs to be manually fixed in such cases. +---@param diff string +---@return string +function M.fix_diff(diff) + diff = diff2search_replace(diff) + -- Normalize block headers to the expected ones (fix for some LLMs output) + diff = diff:gsub("<<<<<<<%s*SEARCH", "------- SEARCH") + diff = diff:gsub(">>>>>>>%s*REPLACE", "+++++++ REPLACE") + diff = diff:gsub("-------%s*REPLACE", "+++++++ REPLACE") + diff = diff:gsub("------- ", "------- SEARCH\n") + diff = diff:gsub("======= ", "=======\n") + + local fixed_diff_lines = {} + local lines = vim.split(diff, "\n") + local first_line = lines[1] + if first_line and first_line:match("^%s*```") then + table.insert(fixed_diff_lines, first_line) + table.insert(fixed_diff_lines, "------- SEARCH") + fixed_diff_lines = vim.list_extend(fixed_diff_lines, lines, 2) + else + table.insert(fixed_diff_lines, "------- SEARCH") + if first_line:match("------- SEARCH") then + fixed_diff_lines = vim.list_extend(fixed_diff_lines, lines, 2) + else + fixed_diff_lines = vim.list_extend(fixed_diff_lines, lines, 1) + end + end + local the_final_diff_lines = {} + local has_split_line = false + local replace_block_closed = false + local should_delete_following_lines = false + for _, line in ipairs(fixed_diff_lines) do + if should_delete_following_lines then goto continue end + if line:match("^-------%s*SEARCH") then has_split_line = false end + if line:match("^=======") then + if has_split_line then + should_delete_following_lines = true + goto continue + end + has_split_line = true + end + if line:match("^+++++++%s*REPLACE") then + if not has_split_line then + table.insert(the_final_diff_lines, "=======") + has_split_line = true + goto continue + else + replace_block_closed = true + end + end + table.insert(the_final_diff_lines, line) + ::continue:: + end + if not replace_block_closed then table.insert(the_final_diff_lines, "+++++++ REPLACE") end + return table.concat(the_final_diff_lines, "\n") +end + return M diff --git a/tests/utils/fix_diff_spec.lua b/tests/utils/fix_diff_spec.lua new file mode 100644 index 0000000..bca742a --- /dev/null +++ b/tests/utils/fix_diff_spec.lua @@ -0,0 +1,718 @@ +local Utils = require("avante.utils") + +describe("Utils.fix_diff", function() + it("should not break normal diff", function() + local diff = [[------- SEARCH + setShowLogs(false)} title="Project PRD Logs" size="xl"> +
+
+ {logs.split('\n').join('\n\n')} +
{logsLoading && }
+
+
+ {logs.length > 0 && ( +
+ +
+ )} +
+======= + setShowLogs(false)} title="Project PRD Logs" size="xl"> +
+
+
+ {logs.split('\n').join('\n\n')} +
+
{logsLoading && }
+
{ + if (el) { + el.scrollIntoView({ behavior: 'smooth', block: 'end' }); + } + }} /> +
+
+ {logs.length > 0 && ( +
+ +
+ )} + ++++++++ REPLACE +]] + + local fixed_diff = Utils.fix_diff(diff) + assert.equals(diff, fixed_diff) + end) + + it("should not break normal multiple diff", function() + local diff = [[------- SEARCH + setShowLogs(false)} title="Project PRD Logs" size="xl"> +
+
+ {logs.split('\n').join('\n\n')} +
{logsLoading && }
+
+
+ {logs.length > 0 && ( +
+ +
+ )} +
+======= + setShowLogs(false)} title="Project PRD Logs" size="xl"> +
+
+
+ {logs.split('\n').join('\n\n')} +
+
{logsLoading && }
+
{ + if (el) { + el.scrollIntoView({ behavior: 'smooth', block: 'end' }); + } + }} /> +
+
+ {logs.length > 0 && ( +
+ +
+ )} + ++++++++ REPLACE + +------- SEARCH + setShowLogs(false)} title="Project PRD Logs" size="xl"> +
+======= + setShowLogs(false)} title="Project PRD Logs" size="xl aaa"> +
++++++++ REPLACE +]] + + local fixed_diff = Utils.fix_diff(diff) + assert.equals(diff, fixed_diff) + end) + + it("should fix duplicated REPLACE delimiters", function() + local diff = [[------- SEARCH + setShowLogs(false)} title="Project PRD Logs" size="xl"> +
+
+ {logs.split('\n').join('\n\n')} +
{logsLoading && }
+
+
+ {logs.length > 0 && ( +
+ +
+ )} +
+------- REPLACE + setShowLogs(false)} title="Project PRD Logs" size="xl"> +
+
+
+ {logs.split('\n').join('\n\n')} +
+
{logsLoading && }
+
{ + if (el) { + el.scrollIntoView({ behavior: 'smooth', block: 'end' }); + } + }} /> +
+
+ {logs.length > 0 && ( +
+ +
+ )} + +------- REPLACE +]] + + local expected_diff = [[------- SEARCH + setShowLogs(false)} title="Project PRD Logs" size="xl"> +
+
+ {logs.split('\n').join('\n\n')} +
{logsLoading && }
+
+
+ {logs.length > 0 && ( +
+ +
+ )} +
+======= + setShowLogs(false)} title="Project PRD Logs" size="xl"> +
+
+
+ {logs.split('\n').join('\n\n')} +
+
{logsLoading && }
+
{ + if (el) { + el.scrollIntoView({ behavior: 'smooth', block: 'end' }); + } + }} /> +
+
+ {logs.length > 0 && ( +
+ +
+ )} + ++++++++ REPLACE +]] + + local fixed_diff = Utils.fix_diff(diff) + assert.equals(expected_diff, fixed_diff) + end) + + it("should fix the delimiter is on the same line as the content", function() + local diff = [[------- // Fetch initial stages when project changes + useEffect(() => { + if (!subscribedProject) return; + + const fetchStages = async () => { + try { + const response = await fetch(`/api/projects/${subscribedProject}/stages`); + if (response.ok) { + const stagesData = await response.json(); + setStages(stagesData); + } + } catch (error) { + console.error('Failed to fetch stages:', error); + } + }; + + fetchStages(); + }, [subscribedProject, forceUpdateCounter]); +======= // Fetch initial stages when project changes + useEffect(() => { + if (!subscribedProject) return; + + const fetchStages = async () => { + try { + // Use the correct API endpoint for stages by project UUID + const response = await fetch(`/api/stages?project_uuid=${subscribedProject}`); + if (response.ok) { + const stagesData = await response.json(); + setStages(stagesData); + } + } catch (error) { + console.error('Failed to fetch stages:', error); + } + }; + + fetchStages(); + }, [subscribedProject, forceUpdateCounter]); ++++++++ REPLACE +]] + + local expected_diff = [[------- SEARCH + // Fetch initial stages when project changes + useEffect(() => { + if (!subscribedProject) return; + + const fetchStages = async () => { + try { + const response = await fetch(`/api/projects/${subscribedProject}/stages`); + if (response.ok) { + const stagesData = await response.json(); + setStages(stagesData); + } + } catch (error) { + console.error('Failed to fetch stages:', error); + } + }; + + fetchStages(); + }, [subscribedProject, forceUpdateCounter]); +======= + // Fetch initial stages when project changes + useEffect(() => { + if (!subscribedProject) return; + + const fetchStages = async () => { + try { + // Use the correct API endpoint for stages by project UUID + const response = await fetch(`/api/stages?project_uuid=${subscribedProject}`); + if (response.ok) { + const stagesData = await response.json(); + setStages(stagesData); + } + } catch (error) { + console.error('Failed to fetch stages:', error); + } + }; + + fetchStages(); + }, [subscribedProject, forceUpdateCounter]); ++++++++ REPLACE +]] + + local fixed_diff = Utils.fix_diff(diff) + assert.equals(expected_diff, fixed_diff) + end) + + it("should fix unified diff", function() + local diff = [[--- lua/avante/sidebar.lua ++++ lua/avante/sidebar.lua +@@ -3099,7 +3099,7 @@ + function Sidebar:create_todos_container() + local history = Path.history.load(self.code.bufnr) + if not history or not history.todos or #history.todos == 0 then +- if self.containers.todos then self.containers.todos:unmount() end ++ if self.containers.todos and Utils.is_valid_container(self.containers.todos) then self.containers.todos:unmount() end + self.containers.todos = nil + self:adjust_layout() + return +@@ -3121,7 +3121,7 @@ + }), + position = "bottom", + size = { +- height = 3, ++ height = math.min(3, math.max(1, vim.o.lines - 5)), + }, + }) + self.containers.todos:mount() +@@ -3151,11 +3151,15 @@ + self:render_header( + self.containers.todos.winid, + todos_buf, +- Utils.icon(" ") .. "Todos" .. " (" .. done_count .. "/" .. total_count .. ")", ++ Utils.icon(" ") .. "Todos" .. " (" .. done_count .. "/" .. total_count .. ")", + Highlights.SUBTITLE, + Highlights.REVERSED_SUBTITLE + ) +- self:adjust_layout() ++ ++ local ok, err = pcall(function() ++ self:adjust_layout() ++ end) ++ if not ok then Utils.debug("Failed to adjust layout after todos creation:", err) end + end + + function Sidebar:adjust_layout() +]] + + local expected_diff = [[------- SEARCH +function Sidebar:create_todos_container() + local history = Path.history.load(self.code.bufnr) + if not history or not history.todos or #history.todos == 0 then + if self.containers.todos then self.containers.todos:unmount() end + self.containers.todos = nil + self:adjust_layout() + return +======= +function Sidebar:create_todos_container() + local history = Path.history.load(self.code.bufnr) + if not history or not history.todos or #history.todos == 0 then + if self.containers.todos and Utils.is_valid_container(self.containers.todos) then self.containers.todos:unmount() end + self.containers.todos = nil + self:adjust_layout() + return ++++++++ REPLACE + +------- SEARCH +}), + position = "bottom", + size = { + height = 3, + }, + }) + self.containers.todos:mount() +======= +}), + position = "bottom", + size = { + height = math.min(3, math.max(1, vim.o.lines - 5)), + }, + }) + self.containers.todos:mount() ++++++++ REPLACE + +------- SEARCH +self:render_header( + self.containers.todos.winid, + todos_buf, + Utils.icon(" ") .. "Todos" .. " (" .. done_count .. "/" .. total_count .. ")", + Highlights.SUBTITLE, + Highlights.REVERSED_SUBTITLE + ) + self:adjust_layout() +end +function Sidebar:adjust_layout() +======= +self:render_header( + self.containers.todos.winid, + todos_buf, + Utils.icon(" ") .. "Todos" .. " (" .. done_count .. "/" .. total_count .. ")", + Highlights.SUBTITLE, + Highlights.REVERSED_SUBTITLE + ) + + local ok, err = pcall(function() + self:adjust_layout() + end) + if not ok then Utils.debug("Failed to adjust layout after todos creation:", err) end +end +function Sidebar:adjust_layout() ++++++++ REPLACE]] + + local fixed_diff = Utils.fix_diff(diff) + assert.equals(expected_diff, fixed_diff) + end) + + it("should fix unified diff 2", function() + local diff = [[ +@@ -3099,7 +3099,7 @@ + function Sidebar:create_todos_container() + local history = Path.history.load(self.code.bufnr) + if not history or not history.todos or #history.todos == 0 then +- if self.containers.todos then self.containers.todos:unmount() end ++ if self.containers.todos and Utils.is_valid_container(self.containers.todos) then self.containers.todos:unmount() end + self.containers.todos = nil + self:adjust_layout() + return +@@ -3121,7 +3121,7 @@ + }), + position = "bottom", + size = { +- height = 3, ++ height = math.min(3, math.max(1, vim.o.lines - 5)), + }, + }) + self.containers.todos:mount() +@@ -3151,11 +3151,15 @@ + self:render_header( + self.containers.todos.winid, + todos_buf, +- Utils.icon(" ") .. "Todos" .. " (" .. done_count .. "/" .. total_count .. ")", ++ Utils.icon(" ") .. "Todos" .. " (" .. done_count .. "/" .. total_count .. ")", + Highlights.SUBTITLE, + Highlights.REVERSED_SUBTITLE + ) +- self:adjust_layout() ++ ++ local ok, err = pcall(function() ++ self:adjust_layout() ++ end) ++ if not ok then Utils.debug("Failed to adjust layout after todos creation:", err) end + end + + function Sidebar:adjust_layout() +]] + local expected_diff = [[------- SEARCH +function Sidebar:create_todos_container() + local history = Path.history.load(self.code.bufnr) + if not history or not history.todos or #history.todos == 0 then + if self.containers.todos then self.containers.todos:unmount() end + self.containers.todos = nil + self:adjust_layout() + return +======= +function Sidebar:create_todos_container() + local history = Path.history.load(self.code.bufnr) + if not history or not history.todos or #history.todos == 0 then + if self.containers.todos and Utils.is_valid_container(self.containers.todos) then self.containers.todos:unmount() end + self.containers.todos = nil + self:adjust_layout() + return ++++++++ REPLACE + +------- SEARCH +}), + position = "bottom", + size = { + height = 3, + }, + }) + self.containers.todos:mount() +======= +}), + position = "bottom", + size = { + height = math.min(3, math.max(1, vim.o.lines - 5)), + }, + }) + self.containers.todos:mount() ++++++++ REPLACE + +------- SEARCH +self:render_header( + self.containers.todos.winid, + todos_buf, + Utils.icon(" ") .. "Todos" .. " (" .. done_count .. "/" .. total_count .. ")", + Highlights.SUBTITLE, + Highlights.REVERSED_SUBTITLE + ) + self:adjust_layout() +end +function Sidebar:adjust_layout() +======= +self:render_header( + self.containers.todos.winid, + todos_buf, + Utils.icon(" ") .. "Todos" .. " (" .. done_count .. "/" .. total_count .. ")", + Highlights.SUBTITLE, + Highlights.REVERSED_SUBTITLE + ) + + local ok, err = pcall(function() + self:adjust_layout() + end) + if not ok then Utils.debug("Failed to adjust layout after todos creation:", err) end +end +function Sidebar:adjust_layout() ++++++++ REPLACE]] + + local fixed_diff = Utils.fix_diff(diff) + assert.equals(expected_diff, fixed_diff) + end) + + it("should fix duplicated replace blocks", function() + local diff = [[------- SEARCH + useEffect(() => { + if (!isExpanded || !textContentRef.current) { + setShowFixedCollapseButton(false); + return; + } + + const observer = new IntersectionObserver( + ([entry]) => { + setShowFixedCollapseButton(!entry.isIntersecting); + }, + { + root: null, + rootMargin: '0px', + threshold: 1.0, + } + ); + + const collapseButton = collapseButtonRef.current; + if (collapseButton) { + observer.observe(collapseButton); + } + + return () => { + if (collapseButton) { + observer.unobserve(collapseButton); + } + }; + }, [isExpanded, textContentRef.current]); +======= + useEffect(() => { + if (!isExpanded || !textContentRef.current) { + setShowFixedCollapseButton(false); + return; + } + + // Check initial visibility of the collapse button + const checkInitialVisibility = () => { + const collapseButton = collapseButtonRef.current; + if (collapseButton) { + const rect = collapseButton.getBoundingClientRect(); + const isVisible = rect.top >= 0 && rect.bottom <= window.innerHeight; + setShowFixedCollapseButton(!isVisible); + } + }; + + // Small delay to ensure DOM is updated after expansion + const timeoutId = setTimeout(checkInitialVisibility, 100); + + const observer = new IntersectionObserver( + ([entry]) => { + setShowFixedCollapseButton(!entry.isIntersecting); + }, + { + root: null, + rootMargin: '0px', + threshold: [0, 1.0], // Check both when it starts to leave and when fully visible + } + ); + + const collapseButton = collapseButtonRef.current; + if (collapseButton) { + observer.observe(collapseButton); + } + + return () => { + clearTimeout(timeoutId); + if (collapseButton) { + observer.unobserve(collapseButton); + } + }; + }, [isExpanded, textContentRef.current]); +======= + useEffect(() => { + if (!isExpanded || !textContentRef.current) { + setShowFixedCollapseButton(false); + return; + } + + // Check initial visibility of the collapse button + const checkInitialVisibility = () => { + const collapseButton = collapseButtonRef.current; + if (collapseButton) { + const rect = collapseButton.getBoundingClientRect(); + const isVisible = rect.top >= 0 && rect.bottom <= window.innerHeight; + setShowFixedCollapseButton(!isVisible); + } + }; + + // Small delay to ensure DOM is updated after expansion + const timeoutId = setTimeout(checkInitialVisibility, 100); + + const observer = new IntersectionObserver( + ([entry]) => { + setShowFixedCollapseButton(!entry.isIntersecting); + }, + { + root: null, + rootMargin: '0px', + threshold: [0, 1.0], // Check both when it starts to leave and when fully visible + } + ); + + const collapseButton = collapseButtonRef.current; + if (collapseButton) { + observer.observe(collapseButton); + } + + return () => { + clearTimeout(timeoutId); + if (collapseButton) { + observer.unobserve(collapseButton); + } + }; + }, [isExpanded, textContentRef.current]); ++++++++ REPLACE +]] + + local expected_diff = [[------- SEARCH + useEffect(() => { + if (!isExpanded || !textContentRef.current) { + setShowFixedCollapseButton(false); + return; + } + + const observer = new IntersectionObserver( + ([entry]) => { + setShowFixedCollapseButton(!entry.isIntersecting); + }, + { + root: null, + rootMargin: '0px', + threshold: 1.0, + } + ); + + const collapseButton = collapseButtonRef.current; + if (collapseButton) { + observer.observe(collapseButton); + } + + return () => { + if (collapseButton) { + observer.unobserve(collapseButton); + } + }; + }, [isExpanded, textContentRef.current]); +======= + useEffect(() => { + if (!isExpanded || !textContentRef.current) { + setShowFixedCollapseButton(false); + return; + } + + // Check initial visibility of the collapse button + const checkInitialVisibility = () => { + const collapseButton = collapseButtonRef.current; + if (collapseButton) { + const rect = collapseButton.getBoundingClientRect(); + const isVisible = rect.top >= 0 && rect.bottom <= window.innerHeight; + setShowFixedCollapseButton(!isVisible); + } + }; + + // Small delay to ensure DOM is updated after expansion + const timeoutId = setTimeout(checkInitialVisibility, 100); + + const observer = new IntersectionObserver( + ([entry]) => { + setShowFixedCollapseButton(!entry.isIntersecting); + }, + { + root: null, + rootMargin: '0px', + threshold: [0, 1.0], // Check both when it starts to leave and when fully visible + } + ); + + const collapseButton = collapseButtonRef.current; + if (collapseButton) { + observer.observe(collapseButton); + } + + return () => { + clearTimeout(timeoutId); + if (collapseButton) { + observer.unobserve(collapseButton); + } + }; + }, [isExpanded, textContentRef.current]); ++++++++ REPLACE]] + + local fixed_diff = Utils.fix_diff(diff) + assert.equals(expected_diff, fixed_diff) + end) +end)