chore: make tool func type more accurate (#1369)
This commit is contained in:
@@ -34,9 +34,7 @@ local function has_permission_to_access(abs_path)
|
||||
end
|
||||
|
||||
---@param opts { rel_path: string, max_depth?: integer }
|
||||
---@param on_log? fun(log: string): nil
|
||||
---@return string files
|
||||
---@return string|nil error
|
||||
---@type AvanteLLMToolFunc
|
||||
function M.list_files(opts, on_log)
|
||||
local abs_path = get_abs_path(opts.rel_path)
|
||||
if not has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end
|
||||
@@ -56,9 +54,7 @@ function M.list_files(opts, on_log)
|
||||
end
|
||||
|
||||
---@param opts { rel_path: string, keyword: string }
|
||||
---@param on_log? fun(log: string): nil
|
||||
---@return string files
|
||||
---@return string|nil error
|
||||
---@type AvanteLLMToolFunc
|
||||
function M.search_files(opts, on_log)
|
||||
local abs_path = get_abs_path(opts.rel_path)
|
||||
if not has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end
|
||||
@@ -75,9 +71,7 @@ function M.search_files(opts, on_log)
|
||||
end
|
||||
|
||||
---@param opts { rel_path: string, keyword: string }
|
||||
---@param on_log? fun(log: string): nil
|
||||
---@return string result
|
||||
---@return string|nil error
|
||||
---@type AvanteLLMToolFunc
|
||||
function M.search(opts, on_log)
|
||||
local abs_path = get_abs_path(opts.rel_path)
|
||||
if not has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end
|
||||
@@ -114,9 +108,7 @@ function M.search(opts, on_log)
|
||||
end
|
||||
|
||||
---@param opts { rel_path: string }
|
||||
---@param on_log? fun(log: string): nil
|
||||
---@return string definitions
|
||||
---@return string|nil error
|
||||
---@type AvanteLLMToolFunc
|
||||
function M.read_file_toplevel_symbols(opts, on_log)
|
||||
local RepoMap = require("avante.repo_map")
|
||||
local abs_path = get_abs_path(opts.rel_path)
|
||||
@@ -133,9 +125,7 @@ function M.read_file_toplevel_symbols(opts, on_log)
|
||||
end
|
||||
|
||||
---@param opts { rel_path: string }
|
||||
---@param on_log? fun(log: string): nil
|
||||
---@return string content
|
||||
---@return string|nil error
|
||||
---@type AvanteLLMToolFunc
|
||||
function M.read_file(opts, on_log)
|
||||
local abs_path = get_abs_path(opts.rel_path)
|
||||
if not has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end
|
||||
@@ -148,9 +138,7 @@ function M.read_file(opts, on_log)
|
||||
end
|
||||
|
||||
---@param opts { rel_path: string }
|
||||
---@param on_log? fun(log: string): nil
|
||||
---@return boolean success
|
||||
---@return string|nil error
|
||||
---@type AvanteLLMToolFunc
|
||||
function M.create_file(opts, on_log)
|
||||
local abs_path = get_abs_path(opts.rel_path)
|
||||
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||
@@ -169,9 +157,7 @@ function M.create_file(opts, on_log)
|
||||
end
|
||||
|
||||
---@param opts { rel_path: string, new_rel_path: string }
|
||||
---@param on_log? fun(log: string): nil
|
||||
---@return boolean success
|
||||
---@return string|nil error
|
||||
---@type AvanteLLMToolFunc
|
||||
function M.rename_file(opts, on_log)
|
||||
local abs_path = get_abs_path(opts.rel_path)
|
||||
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||
@@ -189,9 +175,7 @@ function M.rename_file(opts, on_log)
|
||||
end
|
||||
|
||||
---@param opts { rel_path: string, new_rel_path: string }
|
||||
---@param on_log? fun(log: string): nil
|
||||
---@return boolean success
|
||||
---@return string|nil error
|
||||
---@type AvanteLLMToolFunc
|
||||
function M.copy_file(opts, on_log)
|
||||
local abs_path = get_abs_path(opts.rel_path)
|
||||
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||
@@ -206,9 +190,7 @@ function M.copy_file(opts, on_log)
|
||||
end
|
||||
|
||||
---@param opts { rel_path: string }
|
||||
---@param on_log? fun(log: string): nil
|
||||
---@return boolean success
|
||||
---@return string|nil error
|
||||
---@type AvanteLLMToolFunc
|
||||
function M.delete_file(opts, on_log)
|
||||
local abs_path = get_abs_path(opts.rel_path)
|
||||
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||
@@ -221,9 +203,7 @@ function M.delete_file(opts, on_log)
|
||||
end
|
||||
|
||||
---@param opts { rel_path: string }
|
||||
---@param on_log? fun(log: string): nil
|
||||
---@return boolean success
|
||||
---@return string|nil error
|
||||
---@type AvanteLLMToolFunc
|
||||
function M.create_dir(opts, on_log)
|
||||
local abs_path = get_abs_path(opts.rel_path)
|
||||
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||
@@ -237,9 +217,7 @@ function M.create_dir(opts, on_log)
|
||||
end
|
||||
|
||||
---@param opts { rel_path: string, new_rel_path: string }
|
||||
---@param on_log? fun(log: string): nil
|
||||
---@return boolean success
|
||||
---@return string|nil error
|
||||
---@type AvanteLLMToolFunc
|
||||
function M.rename_dir(opts, on_log)
|
||||
local abs_path = get_abs_path(opts.rel_path)
|
||||
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||
@@ -257,9 +235,7 @@ function M.rename_dir(opts, on_log)
|
||||
end
|
||||
|
||||
---@param opts { rel_path: string }
|
||||
---@param on_log? fun(log: string): nil
|
||||
---@return boolean success
|
||||
---@return string|nil error
|
||||
---@type AvanteLLMToolFunc
|
||||
function M.delete_dir(opts, on_log)
|
||||
local abs_path = get_abs_path(opts.rel_path)
|
||||
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||
@@ -274,9 +250,7 @@ function M.delete_dir(opts, on_log)
|
||||
end
|
||||
|
||||
---@param opts { rel_path: string, command: string }
|
||||
---@param on_log? fun(log: string): nil
|
||||
---@return string|boolean result
|
||||
---@return string|nil error
|
||||
---@type AvanteLLMToolFunc
|
||||
function M.run_command(opts, on_log)
|
||||
local abs_path = get_abs_path(opts.rel_path)
|
||||
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||
@@ -300,9 +274,7 @@ function M.run_command(opts, on_log)
|
||||
end
|
||||
|
||||
---@param opts { query: string }
|
||||
---@param on_log? fun(log: string): nil
|
||||
---@return string|nil result
|
||||
---@return string|nil error
|
||||
---@type AvanteLLMToolFunc
|
||||
function M.web_search(opts, on_log)
|
||||
local provider_type = Config.web_search_engine.provider
|
||||
if provider_type == nil then return nil, "Search engine provider is not set" end
|
||||
@@ -405,9 +377,7 @@ function M.web_search(opts, on_log)
|
||||
end
|
||||
|
||||
---@param opts { url: string }
|
||||
---@param on_log? fun(log: string): nil
|
||||
---@return string|nil result
|
||||
---@return string|nil error
|
||||
---@type AvanteLLMToolFunc
|
||||
function M.fetch(opts, on_log)
|
||||
if on_log then on_log("url: " .. opts.url) end
|
||||
local Html2Md = require("avante.html2md")
|
||||
@@ -417,9 +387,7 @@ function M.fetch(opts, on_log)
|
||||
end
|
||||
|
||||
---@param opts { scope?: string }
|
||||
---@param on_log? fun(log: string): nil
|
||||
---@return string|nil result
|
||||
---@return string|nil error
|
||||
---@type AvanteLLMToolFunc
|
||||
function M.git_diff(opts, on_log)
|
||||
local git_cmd = vim.fn.exepath("git")
|
||||
if git_cmd == "" then return nil, "Git command not found" end
|
||||
@@ -449,9 +417,7 @@ function M.git_diff(opts, on_log)
|
||||
end
|
||||
|
||||
---@param opts { message: string, scope?: string }
|
||||
---@param on_log? fun(log: string): nil
|
||||
---@return boolean success
|
||||
---@return string|nil error
|
||||
---@type AvanteLLMToolFunc
|
||||
function M.git_commit(opts, on_log)
|
||||
local git_cmd = vim.fn.exepath("git")
|
||||
if git_cmd == "" then return false, "Git command not found" end
|
||||
@@ -539,9 +505,7 @@ function M.git_commit(opts, on_log)
|
||||
end
|
||||
|
||||
---@param opts { query: string }
|
||||
---@param on_log? fun(log: string): nil
|
||||
---@return string|nil result
|
||||
---@return string|nil error
|
||||
---@type AvanteLLMToolFunc
|
||||
function M.rag_search(opts, on_log)
|
||||
if not Config.rag_service.enabled then return nil, "Rag service is not enabled" end
|
||||
if not opts.query then return nil, "No query provided" end
|
||||
@@ -555,9 +519,7 @@ function M.rag_search(opts, on_log)
|
||||
end
|
||||
|
||||
---@param opts { code: string, rel_path: string, container_image?: string }
|
||||
---@param on_log? fun(log: string): nil
|
||||
---@return string|nil result
|
||||
---@return string|nil error
|
||||
---@type AvanteLLMToolFunc
|
||||
function M.python(opts, on_log)
|
||||
local abs_path = get_abs_path(opts.rel_path)
|
||||
if not has_permission_to_access(abs_path) then return nil, "No permission to access path: " .. abs_path end
|
||||
@@ -1147,22 +1109,28 @@ M._tools = {
|
||||
---@return string | nil error
|
||||
function M.process_tool_use(tools, tool_use, on_log)
|
||||
Utils.debug("use tool", tool_use.name, tool_use.input_json)
|
||||
local tool = vim.iter(tools):find(function(tool) return tool.name == tool_use.name end)
|
||||
---@type AvanteLLMTool?
|
||||
local tool = vim.iter(tools):find(function(tool) return tool.name == tool_use.name end) ---@param tool AvanteLLMTool
|
||||
if tool == nil then return end
|
||||
local input_json = vim.json.decode(tool_use.input_json)
|
||||
local func = tool.func or M[tool.name]
|
||||
if on_log then on_log(tool_use.name, "running tool") end
|
||||
if on_log then on_log(tool.name, "running tool") end
|
||||
local result, error = func(input_json, function(log)
|
||||
if on_log then on_log(tool_use.name, log) end
|
||||
if on_log then on_log(tool.name, log) end
|
||||
end)
|
||||
if on_log then on_log(tool_use.name, "tool finished") end
|
||||
if on_log then on_log(tool.name, "tool finished") end
|
||||
-- Utils.debug("result", result)
|
||||
-- Utils.debug("error", error)
|
||||
if error ~= nil then
|
||||
if on_log then on_log(tool_use.name, "Error: " .. error) end
|
||||
if on_log then on_log(tool.name, "Error: " .. error) end
|
||||
end
|
||||
if result ~= nil and type(result) ~= "string" then result = vim.json.encode(result) end
|
||||
return result, error
|
||||
local result_str ---@type string?
|
||||
if type(result) == "string" then
|
||||
result_str = result
|
||||
elseif result ~= nil then
|
||||
result_str = vim.json.encode(result)
|
||||
end
|
||||
return result_str, error
|
||||
end
|
||||
|
||||
---@param tool_use AvanteLLMToolUse
|
||||
|
||||
@@ -2292,8 +2292,7 @@ function Sidebar:create_input_container(opts)
|
||||
name = "add_file_to_context",
|
||||
description = "Add a file to the context",
|
||||
---@param input { rel_path: string }
|
||||
---@return string | nil result
|
||||
---@return string | nil error
|
||||
---@type AvanteLLMToolFunc
|
||||
func = function(input)
|
||||
self.file_selector:add_selected_file(input.rel_path)
|
||||
return "Added file to context", nil
|
||||
@@ -2309,8 +2308,7 @@ function Sidebar:create_input_container(opts)
|
||||
name = "remove_file_from_context",
|
||||
description = "Remove a file from the context",
|
||||
---@param input { rel_path: string }
|
||||
---@return string | nil result
|
||||
---@return string | nil error
|
||||
---@type AvanteLLMToolFunc
|
||||
func = function(input)
|
||||
self.file_selector:remove_selected_file(input.rel_path)
|
||||
return "Removed file from context", nil
|
||||
|
||||
@@ -323,10 +323,12 @@ vim.g.avante_login = vim.g.avante_login
|
||||
---@field on_stop AvanteLLMStopCallback
|
||||
---@field on_tool_log? function(tool_name: string, log: string): nil
|
||||
---
|
||||
---@alias AvanteLLMToolFunc fun(input: any, on_log?: (fun(log: string): nil) | nil): (boolean | string | nil, string | nil)
|
||||
---
|
||||
---@class AvanteLLMTool
|
||||
---@field name string
|
||||
---@field description string
|
||||
---@field func? fun(input: any): (string | nil, string | nil)
|
||||
---@field func? AvanteLLMToolFunc
|
||||
---@field param AvanteLLMToolParam
|
||||
---@field returns AvanteLLMToolReturn[]
|
||||
---@field enabled? fun(): boolean
|
||||
|
||||
Reference in New Issue
Block a user