diff --git a/lua/avante/llm_tools.lua b/lua/avante/llm_tools.lua index 0ae1323..d4988a9 100644 --- a/lua/avante/llm_tools.lua +++ b/lua/avante/llm_tools.lua @@ -34,9 +34,7 @@ local function has_permission_to_access(abs_path) end ---@param opts { rel_path: string, max_depth?: integer } ----@param on_log? fun(log: string): nil ----@return string files ----@return string|nil error +---@type AvanteLLMToolFunc function M.list_files(opts, on_log) local abs_path = get_abs_path(opts.rel_path) if not has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end @@ -56,9 +54,7 @@ function M.list_files(opts, on_log) end ---@param opts { rel_path: string, keyword: string } ----@param on_log? fun(log: string): nil ----@return string files ----@return string|nil error +---@type AvanteLLMToolFunc function M.search_files(opts, on_log) local abs_path = get_abs_path(opts.rel_path) if not has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end @@ -75,9 +71,7 @@ function M.search_files(opts, on_log) end ---@param opts { rel_path: string, keyword: string } ----@param on_log? fun(log: string): nil ----@return string result ----@return string|nil error +---@type AvanteLLMToolFunc function M.search(opts, on_log) local abs_path = get_abs_path(opts.rel_path) if not has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end @@ -114,9 +108,7 @@ function M.search(opts, on_log) end ---@param opts { rel_path: string } ----@param on_log? fun(log: string): nil ----@return string definitions ----@return string|nil error +---@type AvanteLLMToolFunc function M.read_file_toplevel_symbols(opts, on_log) local RepoMap = require("avante.repo_map") local abs_path = get_abs_path(opts.rel_path) @@ -133,9 +125,7 @@ function M.read_file_toplevel_symbols(opts, on_log) end ---@param opts { rel_path: string } ----@param on_log? fun(log: string): nil ----@return string content ----@return string|nil error +---@type AvanteLLMToolFunc function M.read_file(opts, on_log) local abs_path = get_abs_path(opts.rel_path) if not has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end @@ -148,9 +138,7 @@ function M.read_file(opts, on_log) end ---@param opts { rel_path: string } ----@param on_log? fun(log: string): nil ----@return boolean success ----@return string|nil error +---@type AvanteLLMToolFunc function M.create_file(opts, on_log) local abs_path = get_abs_path(opts.rel_path) if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end @@ -169,9 +157,7 @@ function M.create_file(opts, on_log) end ---@param opts { rel_path: string, new_rel_path: string } ----@param on_log? fun(log: string): nil ----@return boolean success ----@return string|nil error +---@type AvanteLLMToolFunc function M.rename_file(opts, on_log) local abs_path = get_abs_path(opts.rel_path) if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end @@ -189,9 +175,7 @@ function M.rename_file(opts, on_log) end ---@param opts { rel_path: string, new_rel_path: string } ----@param on_log? fun(log: string): nil ----@return boolean success ----@return string|nil error +---@type AvanteLLMToolFunc function M.copy_file(opts, on_log) local abs_path = get_abs_path(opts.rel_path) if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end @@ -206,9 +190,7 @@ function M.copy_file(opts, on_log) end ---@param opts { rel_path: string } ----@param on_log? fun(log: string): nil ----@return boolean success ----@return string|nil error +---@type AvanteLLMToolFunc function M.delete_file(opts, on_log) local abs_path = get_abs_path(opts.rel_path) if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end @@ -221,9 +203,7 @@ function M.delete_file(opts, on_log) end ---@param opts { rel_path: string } ----@param on_log? fun(log: string): nil ----@return boolean success ----@return string|nil error +---@type AvanteLLMToolFunc function M.create_dir(opts, on_log) local abs_path = get_abs_path(opts.rel_path) if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end @@ -237,9 +217,7 @@ function M.create_dir(opts, on_log) end ---@param opts { rel_path: string, new_rel_path: string } ----@param on_log? fun(log: string): nil ----@return boolean success ----@return string|nil error +---@type AvanteLLMToolFunc function M.rename_dir(opts, on_log) local abs_path = get_abs_path(opts.rel_path) if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end @@ -257,9 +235,7 @@ function M.rename_dir(opts, on_log) end ---@param opts { rel_path: string } ----@param on_log? fun(log: string): nil ----@return boolean success ----@return string|nil error +---@type AvanteLLMToolFunc function M.delete_dir(opts, on_log) local abs_path = get_abs_path(opts.rel_path) if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end @@ -274,9 +250,7 @@ function M.delete_dir(opts, on_log) end ---@param opts { rel_path: string, command: string } ----@param on_log? fun(log: string): nil ----@return string|boolean result ----@return string|nil error +---@type AvanteLLMToolFunc function M.run_command(opts, on_log) local abs_path = get_abs_path(opts.rel_path) if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end @@ -300,9 +274,7 @@ function M.run_command(opts, on_log) end ---@param opts { query: string } ----@param on_log? fun(log: string): nil ----@return string|nil result ----@return string|nil error +---@type AvanteLLMToolFunc function M.web_search(opts, on_log) local provider_type = Config.web_search_engine.provider if provider_type == nil then return nil, "Search engine provider is not set" end @@ -405,9 +377,7 @@ function M.web_search(opts, on_log) end ---@param opts { url: string } ----@param on_log? fun(log: string): nil ----@return string|nil result ----@return string|nil error +---@type AvanteLLMToolFunc function M.fetch(opts, on_log) if on_log then on_log("url: " .. opts.url) end local Html2Md = require("avante.html2md") @@ -417,9 +387,7 @@ function M.fetch(opts, on_log) end ---@param opts { scope?: string } ----@param on_log? fun(log: string): nil ----@return string|nil result ----@return string|nil error +---@type AvanteLLMToolFunc function M.git_diff(opts, on_log) local git_cmd = vim.fn.exepath("git") if git_cmd == "" then return nil, "Git command not found" end @@ -449,9 +417,7 @@ function M.git_diff(opts, on_log) end ---@param opts { message: string, scope?: string } ----@param on_log? fun(log: string): nil ----@return boolean success ----@return string|nil error +---@type AvanteLLMToolFunc function M.git_commit(opts, on_log) local git_cmd = vim.fn.exepath("git") if git_cmd == "" then return false, "Git command not found" end @@ -539,9 +505,7 @@ function M.git_commit(opts, on_log) end ---@param opts { query: string } ----@param on_log? fun(log: string): nil ----@return string|nil result ----@return string|nil error +---@type AvanteLLMToolFunc function M.rag_search(opts, on_log) if not Config.rag_service.enabled then return nil, "Rag service is not enabled" end if not opts.query then return nil, "No query provided" end @@ -555,9 +519,7 @@ function M.rag_search(opts, on_log) end ---@param opts { code: string, rel_path: string, container_image?: string } ----@param on_log? fun(log: string): nil ----@return string|nil result ----@return string|nil error +---@type AvanteLLMToolFunc function M.python(opts, on_log) local abs_path = get_abs_path(opts.rel_path) if not has_permission_to_access(abs_path) then return nil, "No permission to access path: " .. abs_path end @@ -1147,22 +1109,28 @@ M._tools = { ---@return string | nil error function M.process_tool_use(tools, tool_use, on_log) Utils.debug("use tool", tool_use.name, tool_use.input_json) - local tool = vim.iter(tools):find(function(tool) return tool.name == tool_use.name end) + ---@type AvanteLLMTool? + local tool = vim.iter(tools):find(function(tool) return tool.name == tool_use.name end) ---@param tool AvanteLLMTool if tool == nil then return end local input_json = vim.json.decode(tool_use.input_json) local func = tool.func or M[tool.name] - if on_log then on_log(tool_use.name, "running tool") end + if on_log then on_log(tool.name, "running tool") end local result, error = func(input_json, function(log) - if on_log then on_log(tool_use.name, log) end + if on_log then on_log(tool.name, log) end end) - if on_log then on_log(tool_use.name, "tool finished") end + if on_log then on_log(tool.name, "tool finished") end -- Utils.debug("result", result) -- Utils.debug("error", error) if error ~= nil then - if on_log then on_log(tool_use.name, "Error: " .. error) end + if on_log then on_log(tool.name, "Error: " .. error) end end - if result ~= nil and type(result) ~= "string" then result = vim.json.encode(result) end - return result, error + local result_str ---@type string? + if type(result) == "string" then + result_str = result + elseif result ~= nil then + result_str = vim.json.encode(result) + end + return result_str, error end ---@param tool_use AvanteLLMToolUse diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua index 9696ac5..2001483 100644 --- a/lua/avante/sidebar.lua +++ b/lua/avante/sidebar.lua @@ -2292,8 +2292,7 @@ function Sidebar:create_input_container(opts) name = "add_file_to_context", description = "Add a file to the context", ---@param input { rel_path: string } - ---@return string | nil result - ---@return string | nil error + ---@type AvanteLLMToolFunc func = function(input) self.file_selector:add_selected_file(input.rel_path) return "Added file to context", nil @@ -2309,8 +2308,7 @@ function Sidebar:create_input_container(opts) name = "remove_file_from_context", description = "Remove a file from the context", ---@param input { rel_path: string } - ---@return string | nil result - ---@return string | nil error + ---@type AvanteLLMToolFunc func = function(input) self.file_selector:remove_selected_file(input.rel_path) return "Removed file from context", nil diff --git a/lua/avante/types.lua b/lua/avante/types.lua index 9034e63..79ae650 100644 --- a/lua/avante/types.lua +++ b/lua/avante/types.lua @@ -323,10 +323,12 @@ vim.g.avante_login = vim.g.avante_login ---@field on_stop AvanteLLMStopCallback ---@field on_tool_log? function(tool_name: string, log: string): nil --- +---@alias AvanteLLMToolFunc fun(input: any, on_log?: (fun(log: string): nil) | nil): (boolean | string | nil, string | nil) +--- ---@class AvanteLLMTool ---@field name string ---@field description string ----@field func? fun(input: any): (string | nil, string | nil) +---@field func? AvanteLLMToolFunc ---@field param AvanteLLMToolParam ---@field returns AvanteLLMToolReturn[] ---@field enabled? fun(): boolean