feat: extract str_replace tool (#1710)

This commit is contained in:
yetone
2025-03-25 18:43:57 +08:00
committed by GitHub
parent 8c4244b940
commit a2aec079c9
7 changed files with 186 additions and 125 deletions

View File

@@ -517,14 +517,6 @@ function M.setup(opts)
vim.validate({ vendors = { M._options.vendors, "table", true } }) vim.validate({ vendors = { M._options.vendors, "table", true } })
M.provider_names = vim.list_extend(M.provider_names, vim.tbl_keys(M._options.vendors)) M.provider_names = vim.list_extend(M.provider_names, vim.tbl_keys(M._options.vendors))
end end
if M._options.behaviour.enable_claude_text_editor_tool_mode and M._options.provider ~= "claude" then
Utils.warn(
"Claude text editor tool mode is only supported for claude provider! So it will be disabled!",
{ title = "Avante" }
)
M._options.behaviour.enable_claude_text_editor_tool_mode = false
end
end end
---@param opts table<string, any> ---@param opts table<string, any>

View File

@@ -3,7 +3,6 @@ local Utils = require("avante.utils")
local Path = require("plenary.path") local Path = require("plenary.path")
local Config = require("avante.config") local Config = require("avante.config")
local RagService = require("avante.rag_service") local RagService = require("avante.rag_service")
local Diff = require("avante.diff")
local Highlights = require("avante.highlights") local Highlights = require("avante.highlights")
local Helpers = require("avante.llm_tools.helpers") local Helpers = require("avante.llm_tools.helpers")
@@ -28,7 +27,6 @@ end
---@type AvanteLLMToolFunc<{ command: "view" | "str_replace" | "create" | "insert" | "undo_edit", path: string, old_str?: string, new_str?: string, file_text?: string, insert_line?: integer, new_str?: string, view_range?: integer[] }> ---@type AvanteLLMToolFunc<{ command: "view" | "str_replace" | "create" | "insert" | "undo_edit", path: string, old_str?: string, new_str?: string, file_text?: string, insert_line?: integer, new_str?: string, view_range?: integer[] }>
function M.str_replace_editor(opts, on_log, on_complete, session_ctx) function M.str_replace_editor(opts, on_log, on_complete, session_ctx)
if on_log then on_log("command: " .. opts.command) end if on_log then on_log("command: " .. opts.command) end
if on_log then on_log("path: " .. vim.inspect(opts.path)) end
if not on_complete then return false, "on_complete not provided" end if not on_complete then return false, "on_complete not provided" end
local abs_path = Helpers.get_abs_path(opts.path) local abs_path = Helpers.get_abs_path(opts.path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
@@ -54,115 +52,10 @@ function M.str_replace_editor(opts, on_log, on_complete, session_ctx)
return view(opts_, on_log, on_complete, session_ctx) return view(opts_, on_log, on_complete, session_ctx)
end end
if opts.command == "str_replace" then if opts.command == "str_replace" then
if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end return require("avante.llm_tools.str_replace").func(opts, on_log, on_complete)
if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end
local file = io.open(abs_path, "r")
if not file then return false, "file not found: " .. abs_path end
if opts.old_str == nil then return false, "old_str not provided" end
if opts.new_str == nil then return false, "new_str not provided" end
local bufnr = get_bufnr()
local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
local lines_content = table.concat(lines, "\n")
local old_lines = vim.split(opts.old_str, "\n")
local new_lines = vim.split(opts.new_str, "\n")
local start_line, end_line
for i = 1, #lines - #old_lines + 1 do
local match = true
for j = 1, #old_lines do
if lines[i + j - 1] ~= old_lines[j] then
match = false
break
end
end
if match then
start_line = i
end_line = i + #old_lines - 1
break
end
end
if start_line == nil or end_line == nil then
on_complete(false, "Failed to find the old string: " .. opts.old_str)
return
end
---@diagnostic disable-next-line: assign-type-mismatch, missing-fields
local patch = vim.diff(opts.old_str, opts.new_str, { ---@type integer[][]
algorithm = "histogram",
result_type = "indices",
ctxlen = vim.o.scrolloff,
})
local patch_start_line_content = "<<<<<<< HEAD"
local patch_end_line_content = ">>>>>>> new "
--- add random characters to the end of the line to avoid conflicts
patch_end_line_content = patch_end_line_content .. Utils.random_string(10)
local current_start_a = 1
local patched_new_lines = {}
for _, hunk in ipairs(patch) do
local start_a, count_a, start_b, count_b = unpack(hunk)
if current_start_a < start_a then
if count_a > 0 then
vim.list_extend(patched_new_lines, vim.list_slice(old_lines, current_start_a, start_a - 1))
else
vim.list_extend(patched_new_lines, vim.list_slice(old_lines, current_start_a, start_a))
end
end
table.insert(patched_new_lines, patch_start_line_content)
if count_a > 0 then
vim.list_extend(patched_new_lines, vim.list_slice(old_lines, start_a, start_a + count_a - 1))
end
table.insert(patched_new_lines, "=======")
vim.list_extend(patched_new_lines, vim.list_slice(new_lines, start_b, start_b + count_b - 1))
table.insert(patched_new_lines, patch_end_line_content)
current_start_a = start_a + math.max(count_a, 1)
end
if current_start_a <= #old_lines then
vim.list_extend(patched_new_lines, vim.list_slice(old_lines, current_start_a, #old_lines))
end
vim.api.nvim_buf_set_lines(bufnr, start_line - 1, end_line, false, patched_new_lines)
local current_winid = vim.api.nvim_get_current_win()
vim.api.nvim_set_current_win(sidebar.code.winid)
Diff.add_visited_buffer(bufnr)
Diff.process(bufnr)
if #patch > 0 then
vim.api.nvim_win_set_cursor(sidebar.code.winid, { math.max(patch[1][1] + start_line - 1, 1), 0 })
end
vim.cmd("normal! zz")
vim.api.nvim_set_current_win(current_winid)
local augroup = vim.api.nvim_create_augroup("avante_str_replace_editor", { clear = true })
local confirm = Helpers.confirm("Are you sure you want to apply this modification?", function(ok)
pcall(vim.api.nvim_del_augroup_by_id, augroup)
vim.api.nvim_set_current_win(sidebar.code.winid)
vim.api.nvim_feedkeys(vim.api.nvim_replace_termcodes("<Esc>", true, false, true), "n", true)
vim.cmd("undo")
if not ok then
vim.api.nvim_set_current_win(current_winid)
on_complete(false, "User canceled")
return
end
vim.api.nvim_buf_set_lines(bufnr, start_line - 1, end_line, false, new_lines)
vim.api.nvim_set_current_win(current_winid)
on_complete(true, nil)
end, { focus = false })
vim.api.nvim_set_current_win(sidebar.code.winid)
vim.api.nvim_create_autocmd({ "TextChangedI", "TextChanged" }, {
group = augroup,
buffer = bufnr,
callback = function()
local current_lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
local current_lines_content = table.concat(current_lines, "\n")
if current_lines_content:find(patch_end_line_content) then return end
pcall(vim.api.nvim_del_augroup_by_id, augroup)
if confirm then confirm:close() end
if vim.api.nvim_win_is_valid(current_winid) then vim.api.nvim_set_current_win(current_winid) end
if lines_content == current_lines_content then
on_complete(false, "User canceled")
return
end
on_complete(true, nil)
end,
})
return
end end
if opts.command == "create" then if opts.command == "create" then
if on_log then on_log("path: " .. vim.inspect(opts.path)) end
if opts.file_text == nil then return false, "file_text not provided" end if opts.file_text == nil then return false, "file_text not provided" end
if Path:new(abs_path):exists() then return false, "File already exists: " .. abs_path end if Path:new(abs_path):exists() then return false, "File already exists: " .. abs_path end
local lines = vim.split(opts.file_text, "\n") local lines = vim.split(opts.file_text, "\n")
@@ -186,6 +79,7 @@ function M.str_replace_editor(opts, on_log, on_complete, session_ctx)
return return
end end
if opts.command == "insert" then if opts.command == "insert" then
if on_log then on_log("path: " .. vim.inspect(opts.path)) end
if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end
if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end
if opts.insert_line == nil then return false, "insert_line not provided" end if opts.insert_line == nil then return false, "insert_line not provided" end
@@ -220,6 +114,7 @@ function M.str_replace_editor(opts, on_log, on_complete, session_ctx)
return return
end end
if opts.command == "undo_edit" then if opts.command == "undo_edit" then
if on_log then on_log("path: " .. vim.inspect(opts.path)) end
if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end
if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end
local bufnr = get_bufnr() local bufnr = get_bufnr()
@@ -900,6 +795,7 @@ M._tools = {
}, },
}, },
}, },
require("avante.llm_tools.str_replace"),
require("avante.llm_tools.view"), require("avante.llm_tools.view"),
{ {
name = "read_global_file", name = "read_global_file",

View File

@@ -0,0 +1,176 @@
local Path = require("plenary.path")
local Utils = require("avante.utils")
local Base = require("avante.llm_tools.base")
local Helpers = require("avante.llm_tools.helpers")
local Diff = require("avante.diff")
---@class AvanteLLMTool
local M = setmetatable({}, Base)
M.name = "str_replace"
M.description =
"The str_replace command allows you to replace a specific string in a file with a new string. This is used for making precise edits."
---@type AvanteLLMToolParam
M.param = {
type = "table",
fields = {
{
name = "path",
description = "The path to the file in the current project scope",
type = "string",
},
{
name = "old_str",
description = "The text to replace (must match exactly, including whitespace and indentation)",
type = "string",
},
{
name = "new_str",
description = "The new text to insert in place of the old text",
type = "string",
},
},
}
---@type AvanteLLMToolReturn[]
M.returns = {
{
name = "success",
description = "True if the replacement was successful, false otherwise",
type = "boolean",
},
{
name = "error",
description = "Error message if the replacement failed",
type = "string",
optional = true,
},
}
---@type AvanteLLMToolFunc<{ path: string, old_str: string, new_str: string }>
function M.func(opts, on_log, on_complete)
if on_log then on_log("path: " .. opts.path) end
local abs_path = Helpers.get_abs_path(opts.path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
local sidebar = require("avante").get()
if not sidebar then return false, "Avante sidebar not found" end
local get_bufnr = function()
local current_winid = vim.api.nvim_get_current_win()
vim.api.nvim_set_current_win(sidebar.code.winid)
local bufnr = Utils.get_or_create_buffer_with_filepath(abs_path)
vim.api.nvim_set_current_win(current_winid)
return bufnr
end
if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end
if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end
local file = io.open(abs_path, "r")
if not file then return false, "file not found: " .. abs_path end
if opts.old_str == nil then return false, "old_str not provided" end
if opts.new_str == nil then return false, "new_str not provided" end
Utils.debug("old_str", opts.old_str)
Utils.debug("new_str", opts.new_str)
local bufnr = get_bufnr()
local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
local lines_content = table.concat(lines, "\n")
local old_lines = vim.split(opts.old_str, "\n")
local new_lines = vim.split(opts.new_str, "\n")
local start_line, end_line
for i = 1, #lines - #old_lines + 1 do
local match = true
for j = 1, #old_lines do
if lines[i + j - 1] ~= old_lines[j] then
match = false
break
end
end
if match then
start_line = i
end_line = i + #old_lines - 1
break
end
end
if start_line == nil or end_line == nil then
on_complete(false, "Failed to find the old string: " .. opts.old_str)
return
end
---@diagnostic disable-next-line: assign-type-mismatch, missing-fields
local patch = vim.diff(opts.old_str, opts.new_str, { ---@type integer[][]
algorithm = "histogram",
result_type = "indices",
ctxlen = vim.o.scrolloff,
})
local patch_start_line_content = "<<<<<<< HEAD"
local patch_end_line_content = ">>>>>>> new "
--- add random characters to the end of the line to avoid conflicts
patch_end_line_content = patch_end_line_content .. Utils.random_string(10)
local current_start_a = 1
local patched_new_lines = {}
for _, hunk in ipairs(patch) do
local start_a, count_a, start_b, count_b = unpack(hunk)
if current_start_a <= start_a then
if count_a > 0 then
vim.list_extend(patched_new_lines, vim.list_slice(old_lines, current_start_a, start_a - 1))
else
vim.list_extend(patched_new_lines, vim.list_slice(old_lines, current_start_a, start_a))
end
end
table.insert(patched_new_lines, patch_start_line_content)
if count_a > 0 then
vim.list_extend(patched_new_lines, vim.list_slice(old_lines, start_a, start_a + count_a - 1))
end
table.insert(patched_new_lines, "=======")
vim.list_extend(patched_new_lines, vim.list_slice(new_lines, start_b, start_b + count_b - 1))
table.insert(patched_new_lines, patch_end_line_content)
current_start_a = start_a + math.max(count_a, 1)
end
if current_start_a <= #old_lines then
vim.list_extend(patched_new_lines, vim.list_slice(old_lines, current_start_a, #old_lines))
end
vim.api.nvim_buf_set_lines(bufnr, start_line - 1, end_line, false, patched_new_lines)
local current_winid = vim.api.nvim_get_current_win()
vim.api.nvim_set_current_win(sidebar.code.winid)
Diff.add_visited_buffer(bufnr)
Diff.process(bufnr)
if #patch > 0 then
vim.api.nvim_win_set_cursor(sidebar.code.winid, { math.max(patch[1][1] + start_line - 1, 1), 0 })
end
vim.cmd("normal! zz")
vim.api.nvim_set_current_win(current_winid)
local augroup = vim.api.nvim_create_augroup("avante_str_replace_editor", { clear = true })
local confirm = Helpers.confirm("Are you sure you want to apply this modification?", function(ok)
pcall(vim.api.nvim_del_augroup_by_id, augroup)
vim.api.nvim_set_current_win(sidebar.code.winid)
vim.api.nvim_feedkeys(vim.api.nvim_replace_termcodes("<Esc>", true, false, true), "n", true)
vim.cmd("undo")
if not ok then
vim.api.nvim_set_current_win(current_winid)
on_complete(false, "User canceled")
return
end
vim.api.nvim_buf_set_lines(bufnr, start_line - 1, end_line, false, new_lines)
vim.api.nvim_set_current_win(current_winid)
on_complete(true, nil)
end, { focus = false })
vim.api.nvim_set_current_win(sidebar.code.winid)
vim.api.nvim_create_autocmd({ "TextChangedI", "TextChanged" }, {
group = augroup,
buffer = bufnr,
callback = function()
local current_lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
local current_lines_content = table.concat(current_lines, "\n")
if current_lines_content:find(patch_end_line_content) then return end
pcall(vim.api.nvim_del_augroup_by_id, augroup)
if confirm then confirm:close() end
if vim.api.nvim_win_is_valid(current_winid) then vim.api.nvim_set_current_win(current_winid) end
if lines_content == current_lines_content then
on_complete(false, "User canceled")
return
end
on_complete(true, nil)
end,
})
end
return M

View File

@@ -75,6 +75,7 @@ M.returns = {
---@type AvanteLLMToolFunc<{ path: string, view_range?: { start_line: integer, end_line: integer } }> ---@type AvanteLLMToolFunc<{ path: string, view_range?: { start_line: integer, end_line: integer } }>
function M.func(opts, on_log, on_complete, session_ctx) function M.func(opts, on_log, on_complete, session_ctx)
if not on_complete then return false, "on_complete not provided" end if not on_complete then return false, "on_complete not provided" end
if on_log then on_log("path: " .. opts.path) end
if Helpers.already_in_context(opts.path) then if Helpers.already_in_context(opts.path) then
on_complete(nil, "Ooooops! This file is already in the context! Why you are trying to read it again?") on_complete(nil, "Ooooops! This file is already in the context! Why you are trying to read it again?")
return return
@@ -91,7 +92,6 @@ function M.func(opts, on_log, on_complete, session_ctx)
end end
local abs_path = Helpers.get_abs_path(opts.path) local abs_path = Helpers.get_abs_path(opts.path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if on_log then on_log("path: " .. abs_path) end
if not Path:new(abs_path):exists() then return false, "Path not found: " .. abs_path end if not Path:new(abs_path):exists() then return false, "Path not found: " .. abs_path end
if Path:new(abs_path):is_dir() then if Path:new(abs_path):is_dir() then
local files = vim.fn.glob(abs_path .. "/*", false, true) local files = vim.fn.glob(abs_path .. "/*", false, true)

View File

@@ -335,6 +335,7 @@ function M:parse_curl_args(prompt_opts)
if Config.behaviour.enable_claude_text_editor_tool_mode then if Config.behaviour.enable_claude_text_editor_tool_mode then
if tool.name == "create_file" then goto continue end if tool.name == "create_file" then goto continue end
if tool.name == "view" then goto continue end if tool.name == "view" then goto continue end
if tool.name == "str_replace" then goto continue end
end end
table.insert(tools, self:transform_tool(tool)) table.insert(tools, self:transform_tool(tool))
::continue:: ::continue::

View File

@@ -2551,13 +2551,7 @@ function Sidebar:create_input_container(opts)
local mode = "planning" local mode = "planning"
if Config.behaviour.enable_cursor_planning_mode then mode = "cursor-planning" end if Config.behaviour.enable_cursor_planning_mode then mode = "cursor-planning" end
local provider_config = Config.get_provider_config(Config.provider) if Config.behaviour.enable_claude_text_editor_tool_mode then mode = "claude-text-editor-tool" end
local is_claude_model = provider_config
and provider_config.model
and (provider_config.model:lower():match("claude") or Config.provider:lower():match("claude"))
if Config.behaviour.enable_claude_text_editor_tool_mode and is_claude_model then
mode = "claude-text-editor-tool"
end
---@type AvanteGeneratePromptsOptions ---@type AvanteGeneratePromptsOptions
local prompts_opts = { local prompts_opts = {

View File

@@ -1,4 +1,6 @@
{% extends "base.avanterules" %} {% extends "base.avanterules" %}
{% block extra_prompt %} {% block extra_prompt %}
Always reply to the user in the same language they are using. Always reply to the user in the same language they are using.
Don't just provide code suggestions, use the `str_replace` tool to help users fulfill their needs.
{% endblock %} {% endblock %}