From a2aec079c9e430200d687a8f4284afc1db33a497 Mon Sep 17 00:00:00 2001 From: yetone Date: Tue, 25 Mar 2025 18:43:57 +0800 Subject: [PATCH] feat: extract str_replace tool (#1710) --- lua/avante/config.lua | 8 - lua/avante/llm_tools/init.lua | 114 +----------- lua/avante/llm_tools/str_replace.lua | 176 ++++++++++++++++++ lua/avante/llm_tools/view.lua | 2 +- lua/avante/providers/claude.lua | 1 + lua/avante/sidebar.lua | 8 +- .../claude-text-editor-tool.avanterules | 2 + 7 files changed, 186 insertions(+), 125 deletions(-) create mode 100644 lua/avante/llm_tools/str_replace.lua diff --git a/lua/avante/config.lua b/lua/avante/config.lua index 99f9e6b..8473311 100644 --- a/lua/avante/config.lua +++ b/lua/avante/config.lua @@ -517,14 +517,6 @@ function M.setup(opts) vim.validate({ vendors = { M._options.vendors, "table", true } }) M.provider_names = vim.list_extend(M.provider_names, vim.tbl_keys(M._options.vendors)) end - - if M._options.behaviour.enable_claude_text_editor_tool_mode and M._options.provider ~= "claude" then - Utils.warn( - "Claude text editor tool mode is only supported for claude provider! So it will be disabled!", - { title = "Avante" } - ) - M._options.behaviour.enable_claude_text_editor_tool_mode = false - end end ---@param opts table diff --git a/lua/avante/llm_tools/init.lua b/lua/avante/llm_tools/init.lua index 76aa37c..fe7b63c 100644 --- a/lua/avante/llm_tools/init.lua +++ b/lua/avante/llm_tools/init.lua @@ -3,7 +3,6 @@ local Utils = require("avante.utils") local Path = require("plenary.path") local Config = require("avante.config") local RagService = require("avante.rag_service") -local Diff = require("avante.diff") local Highlights = require("avante.highlights") local Helpers = require("avante.llm_tools.helpers") @@ -28,7 +27,6 @@ end ---@type AvanteLLMToolFunc<{ command: "view" | "str_replace" | "create" | "insert" | "undo_edit", path: string, old_str?: string, new_str?: string, file_text?: string, insert_line?: integer, new_str?: string, view_range?: integer[] }> function M.str_replace_editor(opts, on_log, on_complete, session_ctx) if on_log then on_log("command: " .. opts.command) end - if on_log then on_log("path: " .. vim.inspect(opts.path)) end if not on_complete then return false, "on_complete not provided" end local abs_path = Helpers.get_abs_path(opts.path) if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end @@ -54,115 +52,10 @@ function M.str_replace_editor(opts, on_log, on_complete, session_ctx) return view(opts_, on_log, on_complete, session_ctx) end if opts.command == "str_replace" then - if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end - if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end - local file = io.open(abs_path, "r") - if not file then return false, "file not found: " .. abs_path end - if opts.old_str == nil then return false, "old_str not provided" end - if opts.new_str == nil then return false, "new_str not provided" end - local bufnr = get_bufnr() - local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) - local lines_content = table.concat(lines, "\n") - local old_lines = vim.split(opts.old_str, "\n") - local new_lines = vim.split(opts.new_str, "\n") - local start_line, end_line - for i = 1, #lines - #old_lines + 1 do - local match = true - for j = 1, #old_lines do - if lines[i + j - 1] ~= old_lines[j] then - match = false - break - end - end - if match then - start_line = i - end_line = i + #old_lines - 1 - break - end - end - if start_line == nil or end_line == nil then - on_complete(false, "Failed to find the old string: " .. opts.old_str) - return - end - ---@diagnostic disable-next-line: assign-type-mismatch, missing-fields - local patch = vim.diff(opts.old_str, opts.new_str, { ---@type integer[][] - algorithm = "histogram", - result_type = "indices", - ctxlen = vim.o.scrolloff, - }) - local patch_start_line_content = "<<<<<<< HEAD" - local patch_end_line_content = ">>>>>>> new " - --- add random characters to the end of the line to avoid conflicts - patch_end_line_content = patch_end_line_content .. Utils.random_string(10) - local current_start_a = 1 - local patched_new_lines = {} - for _, hunk in ipairs(patch) do - local start_a, count_a, start_b, count_b = unpack(hunk) - if current_start_a < start_a then - if count_a > 0 then - vim.list_extend(patched_new_lines, vim.list_slice(old_lines, current_start_a, start_a - 1)) - else - vim.list_extend(patched_new_lines, vim.list_slice(old_lines, current_start_a, start_a)) - end - end - table.insert(patched_new_lines, patch_start_line_content) - if count_a > 0 then - vim.list_extend(patched_new_lines, vim.list_slice(old_lines, start_a, start_a + count_a - 1)) - end - table.insert(patched_new_lines, "=======") - vim.list_extend(patched_new_lines, vim.list_slice(new_lines, start_b, start_b + count_b - 1)) - table.insert(patched_new_lines, patch_end_line_content) - current_start_a = start_a + math.max(count_a, 1) - end - if current_start_a <= #old_lines then - vim.list_extend(patched_new_lines, vim.list_slice(old_lines, current_start_a, #old_lines)) - end - vim.api.nvim_buf_set_lines(bufnr, start_line - 1, end_line, false, patched_new_lines) - local current_winid = vim.api.nvim_get_current_win() - vim.api.nvim_set_current_win(sidebar.code.winid) - Diff.add_visited_buffer(bufnr) - Diff.process(bufnr) - if #patch > 0 then - vim.api.nvim_win_set_cursor(sidebar.code.winid, { math.max(patch[1][1] + start_line - 1, 1), 0 }) - end - vim.cmd("normal! zz") - vim.api.nvim_set_current_win(current_winid) - local augroup = vim.api.nvim_create_augroup("avante_str_replace_editor", { clear = true }) - local confirm = Helpers.confirm("Are you sure you want to apply this modification?", function(ok) - pcall(vim.api.nvim_del_augroup_by_id, augroup) - vim.api.nvim_set_current_win(sidebar.code.winid) - vim.api.nvim_feedkeys(vim.api.nvim_replace_termcodes("", true, false, true), "n", true) - vim.cmd("undo") - if not ok then - vim.api.nvim_set_current_win(current_winid) - on_complete(false, "User canceled") - return - end - vim.api.nvim_buf_set_lines(bufnr, start_line - 1, end_line, false, new_lines) - vim.api.nvim_set_current_win(current_winid) - on_complete(true, nil) - end, { focus = false }) - vim.api.nvim_set_current_win(sidebar.code.winid) - vim.api.nvim_create_autocmd({ "TextChangedI", "TextChanged" }, { - group = augroup, - buffer = bufnr, - callback = function() - local current_lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) - local current_lines_content = table.concat(current_lines, "\n") - if current_lines_content:find(patch_end_line_content) then return end - pcall(vim.api.nvim_del_augroup_by_id, augroup) - if confirm then confirm:close() end - if vim.api.nvim_win_is_valid(current_winid) then vim.api.nvim_set_current_win(current_winid) end - if lines_content == current_lines_content then - on_complete(false, "User canceled") - return - end - on_complete(true, nil) - end, - }) - return + return require("avante.llm_tools.str_replace").func(opts, on_log, on_complete) end if opts.command == "create" then + if on_log then on_log("path: " .. vim.inspect(opts.path)) end if opts.file_text == nil then return false, "file_text not provided" end if Path:new(abs_path):exists() then return false, "File already exists: " .. abs_path end local lines = vim.split(opts.file_text, "\n") @@ -186,6 +79,7 @@ function M.str_replace_editor(opts, on_log, on_complete, session_ctx) return end if opts.command == "insert" then + if on_log then on_log("path: " .. vim.inspect(opts.path)) end if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end if opts.insert_line == nil then return false, "insert_line not provided" end @@ -220,6 +114,7 @@ function M.str_replace_editor(opts, on_log, on_complete, session_ctx) return end if opts.command == "undo_edit" then + if on_log then on_log("path: " .. vim.inspect(opts.path)) end if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end local bufnr = get_bufnr() @@ -900,6 +795,7 @@ M._tools = { }, }, }, + require("avante.llm_tools.str_replace"), require("avante.llm_tools.view"), { name = "read_global_file", diff --git a/lua/avante/llm_tools/str_replace.lua b/lua/avante/llm_tools/str_replace.lua new file mode 100644 index 0000000..a900aaa --- /dev/null +++ b/lua/avante/llm_tools/str_replace.lua @@ -0,0 +1,176 @@ +local Path = require("plenary.path") +local Utils = require("avante.utils") +local Base = require("avante.llm_tools.base") +local Helpers = require("avante.llm_tools.helpers") +local Diff = require("avante.diff") + +---@class AvanteLLMTool +local M = setmetatable({}, Base) + +M.name = "str_replace" + +M.description = + "The str_replace command allows you to replace a specific string in a file with a new string. This is used for making precise edits." + +---@type AvanteLLMToolParam +M.param = { + type = "table", + fields = { + { + name = "path", + description = "The path to the file in the current project scope", + type = "string", + }, + { + name = "old_str", + description = "The text to replace (must match exactly, including whitespace and indentation)", + type = "string", + }, + { + name = "new_str", + description = "The new text to insert in place of the old text", + type = "string", + }, + }, +} + +---@type AvanteLLMToolReturn[] +M.returns = { + { + name = "success", + description = "True if the replacement was successful, false otherwise", + type = "boolean", + }, + { + name = "error", + description = "Error message if the replacement failed", + type = "string", + optional = true, + }, +} + +---@type AvanteLLMToolFunc<{ path: string, old_str: string, new_str: string }> +function M.func(opts, on_log, on_complete) + if on_log then on_log("path: " .. opts.path) end + local abs_path = Helpers.get_abs_path(opts.path) + if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end + local sidebar = require("avante").get() + if not sidebar then return false, "Avante sidebar not found" end + local get_bufnr = function() + local current_winid = vim.api.nvim_get_current_win() + vim.api.nvim_set_current_win(sidebar.code.winid) + local bufnr = Utils.get_or_create_buffer_with_filepath(abs_path) + vim.api.nvim_set_current_win(current_winid) + return bufnr + end + if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end + if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end + local file = io.open(abs_path, "r") + if not file then return false, "file not found: " .. abs_path end + if opts.old_str == nil then return false, "old_str not provided" end + if opts.new_str == nil then return false, "new_str not provided" end + Utils.debug("old_str", opts.old_str) + Utils.debug("new_str", opts.new_str) + local bufnr = get_bufnr() + local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) + local lines_content = table.concat(lines, "\n") + local old_lines = vim.split(opts.old_str, "\n") + local new_lines = vim.split(opts.new_str, "\n") + local start_line, end_line + for i = 1, #lines - #old_lines + 1 do + local match = true + for j = 1, #old_lines do + if lines[i + j - 1] ~= old_lines[j] then + match = false + break + end + end + if match then + start_line = i + end_line = i + #old_lines - 1 + break + end + end + if start_line == nil or end_line == nil then + on_complete(false, "Failed to find the old string: " .. opts.old_str) + return + end + ---@diagnostic disable-next-line: assign-type-mismatch, missing-fields + local patch = vim.diff(opts.old_str, opts.new_str, { ---@type integer[][] + algorithm = "histogram", + result_type = "indices", + ctxlen = vim.o.scrolloff, + }) + local patch_start_line_content = "<<<<<<< HEAD" + local patch_end_line_content = ">>>>>>> new " + --- add random characters to the end of the line to avoid conflicts + patch_end_line_content = patch_end_line_content .. Utils.random_string(10) + local current_start_a = 1 + local patched_new_lines = {} + for _, hunk in ipairs(patch) do + local start_a, count_a, start_b, count_b = unpack(hunk) + if current_start_a <= start_a then + if count_a > 0 then + vim.list_extend(patched_new_lines, vim.list_slice(old_lines, current_start_a, start_a - 1)) + else + vim.list_extend(patched_new_lines, vim.list_slice(old_lines, current_start_a, start_a)) + end + end + table.insert(patched_new_lines, patch_start_line_content) + if count_a > 0 then + vim.list_extend(patched_new_lines, vim.list_slice(old_lines, start_a, start_a + count_a - 1)) + end + table.insert(patched_new_lines, "=======") + vim.list_extend(patched_new_lines, vim.list_slice(new_lines, start_b, start_b + count_b - 1)) + table.insert(patched_new_lines, patch_end_line_content) + current_start_a = start_a + math.max(count_a, 1) + end + if current_start_a <= #old_lines then + vim.list_extend(patched_new_lines, vim.list_slice(old_lines, current_start_a, #old_lines)) + end + vim.api.nvim_buf_set_lines(bufnr, start_line - 1, end_line, false, patched_new_lines) + local current_winid = vim.api.nvim_get_current_win() + vim.api.nvim_set_current_win(sidebar.code.winid) + Diff.add_visited_buffer(bufnr) + Diff.process(bufnr) + if #patch > 0 then + vim.api.nvim_win_set_cursor(sidebar.code.winid, { math.max(patch[1][1] + start_line - 1, 1), 0 }) + end + vim.cmd("normal! zz") + vim.api.nvim_set_current_win(current_winid) + local augroup = vim.api.nvim_create_augroup("avante_str_replace_editor", { clear = true }) + local confirm = Helpers.confirm("Are you sure you want to apply this modification?", function(ok) + pcall(vim.api.nvim_del_augroup_by_id, augroup) + vim.api.nvim_set_current_win(sidebar.code.winid) + vim.api.nvim_feedkeys(vim.api.nvim_replace_termcodes("", true, false, true), "n", true) + vim.cmd("undo") + if not ok then + vim.api.nvim_set_current_win(current_winid) + on_complete(false, "User canceled") + return + end + vim.api.nvim_buf_set_lines(bufnr, start_line - 1, end_line, false, new_lines) + vim.api.nvim_set_current_win(current_winid) + on_complete(true, nil) + end, { focus = false }) + vim.api.nvim_set_current_win(sidebar.code.winid) + vim.api.nvim_create_autocmd({ "TextChangedI", "TextChanged" }, { + group = augroup, + buffer = bufnr, + callback = function() + local current_lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) + local current_lines_content = table.concat(current_lines, "\n") + if current_lines_content:find(patch_end_line_content) then return end + pcall(vim.api.nvim_del_augroup_by_id, augroup) + if confirm then confirm:close() end + if vim.api.nvim_win_is_valid(current_winid) then vim.api.nvim_set_current_win(current_winid) end + if lines_content == current_lines_content then + on_complete(false, "User canceled") + return + end + on_complete(true, nil) + end, + }) +end + +return M diff --git a/lua/avante/llm_tools/view.lua b/lua/avante/llm_tools/view.lua index 15f0532..4b98dd7 100644 --- a/lua/avante/llm_tools/view.lua +++ b/lua/avante/llm_tools/view.lua @@ -75,6 +75,7 @@ M.returns = { ---@type AvanteLLMToolFunc<{ path: string, view_range?: { start_line: integer, end_line: integer } }> function M.func(opts, on_log, on_complete, session_ctx) if not on_complete then return false, "on_complete not provided" end + if on_log then on_log("path: " .. opts.path) end if Helpers.already_in_context(opts.path) then on_complete(nil, "Ooooops! This file is already in the context! Why you are trying to read it again?") return @@ -91,7 +92,6 @@ function M.func(opts, on_log, on_complete, session_ctx) end local abs_path = Helpers.get_abs_path(opts.path) if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end - if on_log then on_log("path: " .. abs_path) end if not Path:new(abs_path):exists() then return false, "Path not found: " .. abs_path end if Path:new(abs_path):is_dir() then local files = vim.fn.glob(abs_path .. "/*", false, true) diff --git a/lua/avante/providers/claude.lua b/lua/avante/providers/claude.lua index 7830d57..b0256d0 100644 --- a/lua/avante/providers/claude.lua +++ b/lua/avante/providers/claude.lua @@ -335,6 +335,7 @@ function M:parse_curl_args(prompt_opts) if Config.behaviour.enable_claude_text_editor_tool_mode then if tool.name == "create_file" then goto continue end if tool.name == "view" then goto continue end + if tool.name == "str_replace" then goto continue end end table.insert(tools, self:transform_tool(tool)) ::continue:: diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua index 04917cd..a11baaa 100644 --- a/lua/avante/sidebar.lua +++ b/lua/avante/sidebar.lua @@ -2551,13 +2551,7 @@ function Sidebar:create_input_container(opts) local mode = "planning" if Config.behaviour.enable_cursor_planning_mode then mode = "cursor-planning" end - local provider_config = Config.get_provider_config(Config.provider) - local is_claude_model = provider_config - and provider_config.model - and (provider_config.model:lower():match("claude") or Config.provider:lower():match("claude")) - if Config.behaviour.enable_claude_text_editor_tool_mode and is_claude_model then - mode = "claude-text-editor-tool" - end + if Config.behaviour.enable_claude_text_editor_tool_mode then mode = "claude-text-editor-tool" end ---@type AvanteGeneratePromptsOptions local prompts_opts = { diff --git a/lua/avante/templates/claude-text-editor-tool.avanterules b/lua/avante/templates/claude-text-editor-tool.avanterules index 28c93ba..a95c245 100644 --- a/lua/avante/templates/claude-text-editor-tool.avanterules +++ b/lua/avante/templates/claude-text-editor-tool.avanterules @@ -1,4 +1,6 @@ {% extends "base.avanterules" %} {% block extra_prompt %} Always reply to the user in the same language they are using. + +Don't just provide code suggestions, use the `str_replace` tool to help users fulfill their needs. {% endblock %}