chore: make tool func type more accurate (#1369)

This commit is contained in:
Peter Cardenas
2025-02-24 05:49:06 -08:00
committed by GitHub
parent c2188e1afd
commit fe496a9573
3 changed files with 37 additions and 69 deletions

View File

@@ -34,9 +34,7 @@ local function has_permission_to_access(abs_path)
end
---@param opts { rel_path: string, max_depth?: integer }
---@param on_log? fun(log: string): nil
---@return string files
---@return string|nil error
---@type AvanteLLMToolFunc
function M.list_files(opts, on_log)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end
@@ -56,9 +54,7 @@ function M.list_files(opts, on_log)
end
---@param opts { rel_path: string, keyword: string }
---@param on_log? fun(log: string): nil
---@return string files
---@return string|nil error
---@type AvanteLLMToolFunc
function M.search_files(opts, on_log)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end
@@ -75,9 +71,7 @@ function M.search_files(opts, on_log)
end
---@param opts { rel_path: string, keyword: string }
---@param on_log? fun(log: string): nil
---@return string result
---@return string|nil error
---@type AvanteLLMToolFunc
function M.search(opts, on_log)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end
@@ -114,9 +108,7 @@ function M.search(opts, on_log)
end
---@param opts { rel_path: string }
---@param on_log? fun(log: string): nil
---@return string definitions
---@return string|nil error
---@type AvanteLLMToolFunc
function M.read_file_toplevel_symbols(opts, on_log)
local RepoMap = require("avante.repo_map")
local abs_path = get_abs_path(opts.rel_path)
@@ -133,9 +125,7 @@ function M.read_file_toplevel_symbols(opts, on_log)
end
---@param opts { rel_path: string }
---@param on_log? fun(log: string): nil
---@return string content
---@return string|nil error
---@type AvanteLLMToolFunc
function M.read_file(opts, on_log)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end
@@ -148,9 +138,7 @@ function M.read_file(opts, on_log)
end
---@param opts { rel_path: string }
---@param on_log? fun(log: string): nil
---@return boolean success
---@return string|nil error
---@type AvanteLLMToolFunc
function M.create_file(opts, on_log)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
@@ -169,9 +157,7 @@ function M.create_file(opts, on_log)
end
---@param opts { rel_path: string, new_rel_path: string }
---@param on_log? fun(log: string): nil
---@return boolean success
---@return string|nil error
---@type AvanteLLMToolFunc
function M.rename_file(opts, on_log)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
@@ -189,9 +175,7 @@ function M.rename_file(opts, on_log)
end
---@param opts { rel_path: string, new_rel_path: string }
---@param on_log? fun(log: string): nil
---@return boolean success
---@return string|nil error
---@type AvanteLLMToolFunc
function M.copy_file(opts, on_log)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
@@ -206,9 +190,7 @@ function M.copy_file(opts, on_log)
end
---@param opts { rel_path: string }
---@param on_log? fun(log: string): nil
---@return boolean success
---@return string|nil error
---@type AvanteLLMToolFunc
function M.delete_file(opts, on_log)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
@@ -221,9 +203,7 @@ function M.delete_file(opts, on_log)
end
---@param opts { rel_path: string }
---@param on_log? fun(log: string): nil
---@return boolean success
---@return string|nil error
---@type AvanteLLMToolFunc
function M.create_dir(opts, on_log)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
@@ -237,9 +217,7 @@ function M.create_dir(opts, on_log)
end
---@param opts { rel_path: string, new_rel_path: string }
---@param on_log? fun(log: string): nil
---@return boolean success
---@return string|nil error
---@type AvanteLLMToolFunc
function M.rename_dir(opts, on_log)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
@@ -257,9 +235,7 @@ function M.rename_dir(opts, on_log)
end
---@param opts { rel_path: string }
---@param on_log? fun(log: string): nil
---@return boolean success
---@return string|nil error
---@type AvanteLLMToolFunc
function M.delete_dir(opts, on_log)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
@@ -274,9 +250,7 @@ function M.delete_dir(opts, on_log)
end
---@param opts { rel_path: string, command: string }
---@param on_log? fun(log: string): nil
---@return string|boolean result
---@return string|nil error
---@type AvanteLLMToolFunc
function M.run_command(opts, on_log)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
@@ -300,9 +274,7 @@ function M.run_command(opts, on_log)
end
---@param opts { query: string }
---@param on_log? fun(log: string): nil
---@return string|nil result
---@return string|nil error
---@type AvanteLLMToolFunc
function M.web_search(opts, on_log)
local provider_type = Config.web_search_engine.provider
if provider_type == nil then return nil, "Search engine provider is not set" end
@@ -405,9 +377,7 @@ function M.web_search(opts, on_log)
end
---@param opts { url: string }
---@param on_log? fun(log: string): nil
---@return string|nil result
---@return string|nil error
---@type AvanteLLMToolFunc
function M.fetch(opts, on_log)
if on_log then on_log("url: " .. opts.url) end
local Html2Md = require("avante.html2md")
@@ -417,9 +387,7 @@ function M.fetch(opts, on_log)
end
---@param opts { scope?: string }
---@param on_log? fun(log: string): nil
---@return string|nil result
---@return string|nil error
---@type AvanteLLMToolFunc
function M.git_diff(opts, on_log)
local git_cmd = vim.fn.exepath("git")
if git_cmd == "" then return nil, "Git command not found" end
@@ -449,9 +417,7 @@ function M.git_diff(opts, on_log)
end
---@param opts { message: string, scope?: string }
---@param on_log? fun(log: string): nil
---@return boolean success
---@return string|nil error
---@type AvanteLLMToolFunc
function M.git_commit(opts, on_log)
local git_cmd = vim.fn.exepath("git")
if git_cmd == "" then return false, "Git command not found" end
@@ -539,9 +505,7 @@ function M.git_commit(opts, on_log)
end
---@param opts { query: string }
---@param on_log? fun(log: string): nil
---@return string|nil result
---@return string|nil error
---@type AvanteLLMToolFunc
function M.rag_search(opts, on_log)
if not Config.rag_service.enabled then return nil, "Rag service is not enabled" end
if not opts.query then return nil, "No query provided" end
@@ -555,9 +519,7 @@ function M.rag_search(opts, on_log)
end
---@param opts { code: string, rel_path: string, container_image?: string }
---@param on_log? fun(log: string): nil
---@return string|nil result
---@return string|nil error
---@type AvanteLLMToolFunc
function M.python(opts, on_log)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return nil, "No permission to access path: " .. abs_path end
@@ -1147,22 +1109,28 @@ M._tools = {
---@return string | nil error
function M.process_tool_use(tools, tool_use, on_log)
Utils.debug("use tool", tool_use.name, tool_use.input_json)
local tool = vim.iter(tools):find(function(tool) return tool.name == tool_use.name end)
---@type AvanteLLMTool?
local tool = vim.iter(tools):find(function(tool) return tool.name == tool_use.name end) ---@param tool AvanteLLMTool
if tool == nil then return end
local input_json = vim.json.decode(tool_use.input_json)
local func = tool.func or M[tool.name]
if on_log then on_log(tool_use.name, "running tool") end
if on_log then on_log(tool.name, "running tool") end
local result, error = func(input_json, function(log)
if on_log then on_log(tool_use.name, log) end
if on_log then on_log(tool.name, log) end
end)
if on_log then on_log(tool_use.name, "tool finished") end
if on_log then on_log(tool.name, "tool finished") end
-- Utils.debug("result", result)
-- Utils.debug("error", error)
if error ~= nil then
if on_log then on_log(tool_use.name, "Error: " .. error) end
if on_log then on_log(tool.name, "Error: " .. error) end
end
if result ~= nil and type(result) ~= "string" then result = vim.json.encode(result) end
return result, error
local result_str ---@type string?
if type(result) == "string" then
result_str = result
elseif result ~= nil then
result_str = vim.json.encode(result)
end
return result_str, error
end
---@param tool_use AvanteLLMToolUse

View File

@@ -2292,8 +2292,7 @@ function Sidebar:create_input_container(opts)
name = "add_file_to_context",
description = "Add a file to the context",
---@param input { rel_path: string }
---@return string | nil result
---@return string | nil error
---@type AvanteLLMToolFunc
func = function(input)
self.file_selector:add_selected_file(input.rel_path)
return "Added file to context", nil
@@ -2309,8 +2308,7 @@ function Sidebar:create_input_container(opts)
name = "remove_file_from_context",
description = "Remove a file from the context",
---@param input { rel_path: string }
---@return string | nil result
---@return string | nil error
---@type AvanteLLMToolFunc
func = function(input)
self.file_selector:remove_selected_file(input.rel_path)
return "Removed file from context", nil

View File

@@ -323,10 +323,12 @@ vim.g.avante_login = vim.g.avante_login
---@field on_stop AvanteLLMStopCallback
---@field on_tool_log? function(tool_name: string, log: string): nil
---
---@alias AvanteLLMToolFunc fun(input: any, on_log?: (fun(log: string): nil) | nil): (boolean | string | nil, string | nil)
---
---@class AvanteLLMTool
---@field name string
---@field description string
---@field func? fun(input: any): (string | nil, string | nil)
---@field func? AvanteLLMToolFunc
---@field param AvanteLLMToolParam
---@field returns AvanteLLMToolReturn[]
---@field enabled? fun(): boolean