From b41556ee217fea5d473a3e7afc5bce9b6cfd6a85 Mon Sep 17 00:00:00 2001 From: yetone Date: Wed, 26 Mar 2025 17:59:59 +0800 Subject: [PATCH] feat: extract text editor tools (#1726) --- lua/avante/llm_tools/create.lua | 78 +++++++++++++++++++++++ lua/avante/llm_tools/helpers.lua | 13 ++++ lua/avante/llm_tools/init.lua | 95 ++-------------------------- lua/avante/llm_tools/insert.lua | 89 ++++++++++++++++++++++++++ lua/avante/llm_tools/str_replace.lua | 12 +--- lua/avante/llm_tools/undo_edit.lua | 66 +++++++++++++++++++ lua/avante/llm_tools/view.lua | 2 +- lua/avante/providers/claude.lua | 3 + 8 files changed, 259 insertions(+), 99 deletions(-) create mode 100644 lua/avante/llm_tools/create.lua create mode 100644 lua/avante/llm_tools/insert.lua create mode 100644 lua/avante/llm_tools/undo_edit.lua diff --git a/lua/avante/llm_tools/create.lua b/lua/avante/llm_tools/create.lua new file mode 100644 index 0000000..5e0b16b --- /dev/null +++ b/lua/avante/llm_tools/create.lua @@ -0,0 +1,78 @@ +local Path = require("plenary.path") +local Utils = require("avante.utils") +local Base = require("avante.llm_tools.base") +local Helpers = require("avante.llm_tools.helpers") + +---@class AvanteLLMTool +local M = setmetatable({}, Base) + +M.name = "create" + +M.description = "The create tool allows you to create a new file with specified content." + +---@type AvanteLLMToolParam +M.param = { + type = "table", + fields = { + { + name = "path", + description = "The path where the new file should be created", + type = "string", + }, + { + name = "file_text", + description = "The content to write to the new file", + type = "string", + }, + }, +} + +---@type AvanteLLMToolReturn[] +M.returns = { + { + name = "success", + description = "Whether the file was created successfully", + type = "boolean", + }, + { + name = "error", + description = "Error message if the file was not created successfully", + type = "string", + optional = true, + }, +} + +---@type AvanteLLMToolFunc<{ path: string, file_text: string }> +function M.func(opts, on_log, on_complete) + if not on_complete then return false, "on_complete not provided" end + if on_log then on_log("path: " .. opts.path) end + if Helpers.already_in_context(opts.path) then + on_complete(nil, "Ooooops! This file is already in the context! Why you are trying to create it again?") + return + end + local abs_path = Helpers.get_abs_path(opts.path) + if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end + if opts.file_text == nil then return false, "file_text not provided" end + if Path:new(abs_path):exists() then return false, "File already exists: " .. abs_path end + local lines = vim.split(opts.file_text, "\n") + local bufnr, err = Helpers.get_bufnr(abs_path) + if err then return false, err end + vim.api.nvim_buf_set_lines(bufnr, 0, -1, false, lines) + Helpers.confirm("Are you sure you want to create this file?", function(ok) + if not ok then + -- close the buffer + vim.api.nvim_buf_delete(bufnr, { force = true }) + on_complete(false, "User canceled") + return + end + -- save the file + local current_winid = vim.api.nvim_get_current_win() + local winid = Utils.get_winid(bufnr) + vim.api.nvim_set_current_win(winid) + vim.cmd("write") + vim.api.nvim_set_current_win(current_winid) + on_complete(true, nil) + end) +end + +return M diff --git a/lua/avante/llm_tools/helpers.lua b/lua/avante/llm_tools/helpers.lua index 3bbdf4b..788e4eb 100644 --- a/lua/avante/llm_tools/helpers.lua +++ b/lua/avante/llm_tools/helpers.lua @@ -72,4 +72,17 @@ function M.already_in_context(path) return false end +---@param abs_path string +---@return integer bufnr +---@return string | nil error +function M.get_bufnr(abs_path) + local sidebar = require("avante").get() + if not sidebar then return 0, "Avante sidebar not found" end + local current_winid = vim.api.nvim_get_current_win() + vim.api.nvim_set_current_win(sidebar.code.winid) + local bufnr = Utils.get_or_create_buffer_with_filepath(abs_path) + vim.api.nvim_set_current_win(current_winid) + return bufnr, nil +end + return M diff --git a/lua/avante/llm_tools/init.lua b/lua/avante/llm_tools/init.lua index fe7b63c..927b57f 100644 --- a/lua/avante/llm_tools/init.lua +++ b/lua/avante/llm_tools/init.lua @@ -3,7 +3,6 @@ local Utils = require("avante.utils") local Path = require("plenary.path") local Config = require("avante.config") local RagService = require("avante.rag_service") -local Highlights = require("avante.highlights") local Helpers = require("avante.llm_tools.helpers") local M = {} @@ -30,15 +29,6 @@ function M.str_replace_editor(opts, on_log, on_complete, session_ctx) if not on_complete then return false, "on_complete not provided" end local abs_path = Helpers.get_abs_path(opts.path) if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end - local sidebar = require("avante").get() - if not sidebar then return false, "Avante sidebar not found" end - local get_bufnr = function() - local current_winid = vim.api.nvim_get_current_win() - vim.api.nvim_set_current_win(sidebar.code.winid) - local bufnr = Utils.get_or_create_buffer_with_filepath(abs_path) - vim.api.nvim_set_current_win(current_winid) - return bufnr - end if opts.command == "view" then local view = require("avante.llm_tools.view") local opts_ = { path = opts.path } @@ -54,85 +44,9 @@ function M.str_replace_editor(opts, on_log, on_complete, session_ctx) if opts.command == "str_replace" then return require("avante.llm_tools.str_replace").func(opts, on_log, on_complete) end - if opts.command == "create" then - if on_log then on_log("path: " .. vim.inspect(opts.path)) end - if opts.file_text == nil then return false, "file_text not provided" end - if Path:new(abs_path):exists() then return false, "File already exists: " .. abs_path end - local lines = vim.split(opts.file_text, "\n") - local bufnr = get_bufnr() - vim.api.nvim_buf_set_lines(bufnr, 0, -1, false, lines) - Helpers.confirm("Are you sure you want to create this file?", function(ok) - if not ok then - -- close the buffer - vim.api.nvim_buf_delete(bufnr, { force = true }) - on_complete(false, "User canceled") - return - end - -- save the file - local current_winid = vim.api.nvim_get_current_win() - local winid = Utils.get_winid(bufnr) - vim.api.nvim_set_current_win(winid) - vim.cmd("write") - vim.api.nvim_set_current_win(current_winid) - on_complete(true, nil) - end) - return - end - if opts.command == "insert" then - if on_log then on_log("path: " .. vim.inspect(opts.path)) end - if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end - if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end - if opts.insert_line == nil then return false, "insert_line not provided" end - if opts.new_str == nil then return false, "new_str not provided" end - local ns_id = vim.api.nvim_create_namespace("avante_insert_diff") - local bufnr = get_bufnr() - local function clear_highlights() vim.api.nvim_buf_clear_namespace(bufnr, ns_id, 0, -1) end - local new_lines = vim.split(opts.new_str, "\n") - local max_col = vim.o.columns - local virt_lines = vim - .iter(new_lines) - :map(function(line) - --- append spaces to the end of the line - local line_ = line .. string.rep(" ", max_col - #line) - return { { line_, Highlights.INCOMING } } - end) - :totable() - vim.api.nvim_buf_set_extmark(bufnr, ns_id, opts.insert_line - 1, 0, { - virt_lines = virt_lines, - hl_eol = true, - hl_mode = "combine", - }) - Helpers.confirm("Are you sure you want to insert these lines?", function(ok) - clear_highlights() - if not ok then - on_complete(false, "User canceled") - return - end - vim.api.nvim_buf_set_lines(bufnr, opts.insert_line - 1, opts.insert_line - 1, false, new_lines) - on_complete(true, nil) - end) - return - end - if opts.command == "undo_edit" then - if on_log then on_log("path: " .. vim.inspect(opts.path)) end - if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end - if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end - local bufnr = get_bufnr() - Helpers.confirm("Are you sure you want to undo edit this file?", function(ok) - if not ok then - on_complete(false, "User canceled") - return - end - local current_winid = vim.api.nvim_get_current_win() - local winid = Utils.get_winid(bufnr) - vim.api.nvim_set_current_win(winid) - -- run undo - vim.cmd("undo") - vim.api.nvim_set_current_win(current_winid) - on_complete(true, nil) - end) - return - end + if opts.command == "create" then return require("avante.llm_tools.create").func(opts, on_log, on_complete) end + if opts.command == "insert" then return require("avante.llm_tools.insert").func(opts, on_log, on_complete) end + if opts.command == "undo_edit" then return require("avante.llm_tools.undo_edit").func(opts, on_log, on_complete) end return false, "Unknown command: " .. opts.command end @@ -797,6 +711,9 @@ M._tools = { }, require("avante.llm_tools.str_replace"), require("avante.llm_tools.view"), + require("avante.llm_tools.create"), + require("avante.llm_tools.insert"), + require("avante.llm_tools.undo_edit"), { name = "read_global_file", description = "Read the contents of a file in the global scope. If the file content is already in the context, do not use this tool.", diff --git a/lua/avante/llm_tools/insert.lua b/lua/avante/llm_tools/insert.lua new file mode 100644 index 0000000..2080009 --- /dev/null +++ b/lua/avante/llm_tools/insert.lua @@ -0,0 +1,89 @@ +local Path = require("plenary.path") +local Base = require("avante.llm_tools.base") +local Helpers = require("avante.llm_tools.helpers") +local Highlights = require("avante.highlights") + +---@class AvanteLLMTool +local M = setmetatable({}, Base) + +M.name = "insert" + +M.description = "The insert tool allows you to insert text at a specific location in a file." + +---@type AvanteLLMToolParam +M.param = { + type = "table", + fields = { + { + name = "path", + description = "The path to the file to modify", + type = "string", + }, + { + name = "insert_line", + description = "The line number after which to insert the text (0 for beginning of file)", + type = "integer", + }, + { + name = "new_str", + description = "The text to insert", + type = "string", + }, + }, +} + +---@type AvanteLLMToolReturn[] +M.returns = { + { + name = "success", + description = "True if the text was inserted successfully, false otherwise", + type = "boolean", + }, + { + name = "error", + description = "Error message if the text was not inserted successfully", + type = "string", + optional = true, + }, +} + +---@type AvanteLLMToolFunc<{ path: string, insert_line: integer, new_str: string }> +function M.func(opts, on_log, on_complete) + if on_log then on_log("path: " .. opts.path) end + local abs_path = Helpers.get_abs_path(opts.path) + if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end + if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end + if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end + if opts.insert_line == nil then return false, "insert_line not provided" end + if opts.new_str == nil then return false, "new_str not provided" end + local ns_id = vim.api.nvim_create_namespace("avante_insert_diff") + local bufnr, err = Helpers.get_bufnr(abs_path) + if err then return false, err end + local function clear_highlights() vim.api.nvim_buf_clear_namespace(bufnr, ns_id, 0, -1) end + local new_lines = vim.split(opts.new_str, "\n") + local max_col = vim.o.columns + local virt_lines = vim + .iter(new_lines) + :map(function(line) + --- append spaces to the end of the line + local line_ = line .. string.rep(" ", max_col - #line) + return { { line_, Highlights.INCOMING } } + end) + :totable() + vim.api.nvim_buf_set_extmark(bufnr, ns_id, opts.insert_line, 0, { + virt_lines = virt_lines, + hl_eol = true, + hl_mode = "combine", + }) + Helpers.confirm("Are you sure you want to insert these lines?", function(ok) + clear_highlights() + if not ok then + on_complete(false, "User canceled") + return + end + vim.api.nvim_buf_set_lines(bufnr, opts.insert_line, opts.insert_line, false, new_lines) + on_complete(true, nil) + end) +end + +return M diff --git a/lua/avante/llm_tools/str_replace.lua b/lua/avante/llm_tools/str_replace.lua index a900aaa..9f47001 100644 --- a/lua/avante/llm_tools/str_replace.lua +++ b/lua/avante/llm_tools/str_replace.lua @@ -10,7 +10,7 @@ local M = setmetatable({}, Base) M.name = "str_replace" M.description = - "The str_replace command allows you to replace a specific string in a file with a new string. This is used for making precise edits." + "The str_replace tool allows you to replace a specific string in a file with a new string. This is used for making precise edits." ---@type AvanteLLMToolParam M.param = { @@ -56,13 +56,6 @@ function M.func(opts, on_log, on_complete) if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end local sidebar = require("avante").get() if not sidebar then return false, "Avante sidebar not found" end - local get_bufnr = function() - local current_winid = vim.api.nvim_get_current_win() - vim.api.nvim_set_current_win(sidebar.code.winid) - local bufnr = Utils.get_or_create_buffer_with_filepath(abs_path) - vim.api.nvim_set_current_win(current_winid) - return bufnr - end if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end local file = io.open(abs_path, "r") @@ -71,7 +64,8 @@ function M.func(opts, on_log, on_complete) if opts.new_str == nil then return false, "new_str not provided" end Utils.debug("old_str", opts.old_str) Utils.debug("new_str", opts.new_str) - local bufnr = get_bufnr() + local bufnr, err = Helpers.get_bufnr(abs_path) + if err then return false, err end local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) local lines_content = table.concat(lines, "\n") local old_lines = vim.split(opts.old_str, "\n") diff --git a/lua/avante/llm_tools/undo_edit.lua b/lua/avante/llm_tools/undo_edit.lua new file mode 100644 index 0000000..cbcbb47 --- /dev/null +++ b/lua/avante/llm_tools/undo_edit.lua @@ -0,0 +1,66 @@ +local Path = require("plenary.path") +local Base = require("avante.llm_tools.base") +local Helpers = require("avante.llm_tools.helpers") +local Utils = require("avante.utils") + +---@class AvanteLLMTool +local M = setmetatable({}, Base) + +M.name = "undo_edit" + +M.description = "The undo_edit tool allows you to revert the last edit made to a file." + +---@type AvanteLLMToolParam +M.param = { + type = "table", + fields = { + { + name = "path", + description = "The path to the file whose last edit should be undone", + type = "string", + }, + }, +} + +---@type AvanteLLMToolReturn[] +M.returns = { + { + name = "success", + description = "True if the edit was undone successfully, false otherwise", + type = "boolean", + }, + { + name = "error", + description = "Error message if the edit was not undone successfully", + type = "string", + optional = true, + }, +} + +---@type AvanteLLMToolFunc<{ path: string }> +function M.func(opts, on_log, on_complete) + if on_log then on_log("path: " .. opts.path) end + local abs_path = Helpers.get_abs_path(opts.path) + if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end + if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end + if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end + local bufnr, err = Helpers.get_bufnr(abs_path) + if err then return false, err end + local current_winid = vim.api.nvim_get_current_win() + local winid = Utils.get_winid(bufnr) + vim.api.nvim_set_current_win(winid) + vim.api.nvim_set_current_win(current_winid) + Helpers.confirm("Are you sure you want to undo edit this file?", function(ok) + if not ok then + on_complete(false, "User canceled") + return + end + vim.api.nvim_set_current_win(winid) + -- run undo + vim.cmd("undo") + vim.api.nvim_set_current_win(current_winid) + on_complete(true, nil) + end) +end + +return M diff --git a/lua/avante/llm_tools/view.lua b/lua/avante/llm_tools/view.lua index 4b98dd7..a049c6c 100644 --- a/lua/avante/llm_tools/view.lua +++ b/lua/avante/llm_tools/view.lua @@ -9,7 +9,7 @@ local M = setmetatable({}, Base) M.name = "view" M.description = - "The view command allows you to examine the contents of a file or list the contents of a directory. It can read the entire file or a specific range of lines. If the file content is already in the context, do not use this tool." + "The view tool allows you to examine the contents of a file or list the contents of a directory. It can read the entire file or a specific range of lines. If the file content is already in the context, do not use this tool." M.enabled = function(opts) if opts.user_input:match("@read_global_file") then return false end diff --git a/lua/avante/providers/claude.lua b/lua/avante/providers/claude.lua index b0256d0..deda9be 100644 --- a/lua/avante/providers/claude.lua +++ b/lua/avante/providers/claude.lua @@ -336,6 +336,9 @@ function M:parse_curl_args(prompt_opts) if tool.name == "create_file" then goto continue end if tool.name == "view" then goto continue end if tool.name == "str_replace" then goto continue end + if tool.name == "create" then goto continue end + if tool.name == "insert" then goto continue end + if tool.name == "undo_edit" then goto continue end end table.insert(tools, self:transform_tool(tool)) ::continue::