From b627b335dd5fff573c15e694230e7767d4e2b284 Mon Sep 17 00:00:00 2001 From: yetone Date: Mon, 24 Feb 2025 21:54:36 +0800 Subject: [PATCH] feat: add generic type for llm func (#1373) --- lua/avante/llm_tools.lua | 57 ++++++++++++++-------------------------- lua/avante/sidebar.lua | 3 +-- lua/avante/types.lua | 2 +- 3 files changed, 21 insertions(+), 41 deletions(-) diff --git a/lua/avante/llm_tools.lua b/lua/avante/llm_tools.lua index d4988a9..2e837f2 100644 --- a/lua/avante/llm_tools.lua +++ b/lua/avante/llm_tools.lua @@ -33,8 +33,7 @@ local function has_permission_to_access(abs_path) return not Utils.is_ignored(abs_path, gitignore_patterns, gitignore_negate_patterns) end ----@param opts { rel_path: string, max_depth?: integer } ----@type AvanteLLMToolFunc +---@type AvanteLLMToolFunc<{ rel_path: string, max_depth?: integer }> function M.list_files(opts, on_log) local abs_path = get_abs_path(opts.rel_path) if not has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end @@ -53,8 +52,7 @@ function M.list_files(opts, on_log) return vim.json.encode(filepaths), nil end ----@param opts { rel_path: string, keyword: string } ----@type AvanteLLMToolFunc +---@type AvanteLLMToolFunc<{ rel_path: string, keyword: string }> function M.search_files(opts, on_log) local abs_path = get_abs_path(opts.rel_path) if not has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end @@ -70,8 +68,7 @@ function M.search_files(opts, on_log) return vim.json.encode(filepaths), nil end ----@param opts { rel_path: string, keyword: string } ----@type AvanteLLMToolFunc +---@type AvanteLLMToolFunc<{ rel_path: string, keyword: string }> function M.search(opts, on_log) local abs_path = get_abs_path(opts.rel_path) if not has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end @@ -107,8 +104,7 @@ function M.search(opts, on_log) return vim.json.encode(filepaths), nil end ----@param opts { rel_path: string } ----@type AvanteLLMToolFunc +---@type AvanteLLMToolFunc<{ rel_path: string }> function M.read_file_toplevel_symbols(opts, on_log) local RepoMap = require("avante.repo_map") local abs_path = get_abs_path(opts.rel_path) @@ -124,8 +120,7 @@ function M.read_file_toplevel_symbols(opts, on_log) return definitions, nil end ----@param opts { rel_path: string } ----@type AvanteLLMToolFunc +---@type AvanteLLMToolFunc<{ rel_path: string }> function M.read_file(opts, on_log) local abs_path = get_abs_path(opts.rel_path) if not has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end @@ -137,8 +132,7 @@ function M.read_file(opts, on_log) return content, nil end ----@param opts { rel_path: string } ----@type AvanteLLMToolFunc +---@type AvanteLLMToolFunc<{ rel_path: string }> function M.create_file(opts, on_log) local abs_path = get_abs_path(opts.rel_path) if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end @@ -156,8 +150,7 @@ function M.create_file(opts, on_log) return true, nil end ----@param opts { rel_path: string, new_rel_path: string } ----@type AvanteLLMToolFunc +---@type AvanteLLMToolFunc<{ rel_path: string, new_rel_path: string }> function M.rename_file(opts, on_log) local abs_path = get_abs_path(opts.rel_path) if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end @@ -174,8 +167,7 @@ function M.rename_file(opts, on_log) return true, nil end ----@param opts { rel_path: string, new_rel_path: string } ----@type AvanteLLMToolFunc +---@type AvanteLLMToolFunc<{ rel_path: string, new_rel_path: string }> function M.copy_file(opts, on_log) local abs_path = get_abs_path(opts.rel_path) if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end @@ -189,8 +181,7 @@ function M.copy_file(opts, on_log) return true, nil end ----@param opts { rel_path: string } ----@type AvanteLLMToolFunc +---@type AvanteLLMToolFunc<{ rel_path: string }> function M.delete_file(opts, on_log) local abs_path = get_abs_path(opts.rel_path) if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end @@ -202,8 +193,7 @@ function M.delete_file(opts, on_log) return true, nil end ----@param opts { rel_path: string } ----@type AvanteLLMToolFunc +---@type AvanteLLMToolFunc<{ rel_path: string }> function M.create_dir(opts, on_log) local abs_path = get_abs_path(opts.rel_path) if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end @@ -216,8 +206,7 @@ function M.create_dir(opts, on_log) return true, nil end ----@param opts { rel_path: string, new_rel_path: string } ----@type AvanteLLMToolFunc +---@type AvanteLLMToolFunc<{ rel_path: string, new_rel_path: string }> function M.rename_dir(opts, on_log) local abs_path = get_abs_path(opts.rel_path) if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end @@ -234,8 +223,7 @@ function M.rename_dir(opts, on_log) return true, nil end ----@param opts { rel_path: string } ----@type AvanteLLMToolFunc +---@type AvanteLLMToolFunc<{ rel_path: string }> function M.delete_dir(opts, on_log) local abs_path = get_abs_path(opts.rel_path) if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end @@ -249,8 +237,7 @@ function M.delete_dir(opts, on_log) return true, nil end ----@param opts { rel_path: string, command: string } ----@type AvanteLLMToolFunc +---@type AvanteLLMToolFunc<{ rel_path: string, command: string }> function M.run_command(opts, on_log) local abs_path = get_abs_path(opts.rel_path) if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end @@ -273,8 +260,7 @@ function M.run_command(opts, on_log) return res.stdout, nil end ----@param opts { query: string } ----@type AvanteLLMToolFunc +---@type AvanteLLMToolFunc<{ query: string }> function M.web_search(opts, on_log) local provider_type = Config.web_search_engine.provider if provider_type == nil then return nil, "Search engine provider is not set" end @@ -376,8 +362,7 @@ function M.web_search(opts, on_log) end end ----@param opts { url: string } ----@type AvanteLLMToolFunc +---@type AvanteLLMToolFunc<{ url: string }> function M.fetch(opts, on_log) if on_log then on_log("url: " .. opts.url) end local Html2Md = require("avante.html2md") @@ -386,8 +371,7 @@ function M.fetch(opts, on_log) return res, nil end ----@param opts { scope?: string } ----@type AvanteLLMToolFunc +---@type AvanteLLMToolFunc<{ scope?: string }> function M.git_diff(opts, on_log) local git_cmd = vim.fn.exepath("git") if git_cmd == "" then return nil, "Git command not found" end @@ -416,8 +400,7 @@ function M.git_diff(opts, on_log) return diff, nil end ----@param opts { message: string, scope?: string } ----@type AvanteLLMToolFunc +---@type AvanteLLMToolFunc<{ message: string, scope?: string }> function M.git_commit(opts, on_log) local git_cmd = vim.fn.exepath("git") if git_cmd == "" then return false, "Git command not found" end @@ -504,8 +487,7 @@ function M.git_commit(opts, on_log) return true, nil end ----@param opts { query: string } ----@type AvanteLLMToolFunc +---@type AvanteLLMToolFunc<{ query: string }> function M.rag_search(opts, on_log) if not Config.rag_service.enabled then return nil, "Rag service is not enabled" end if not opts.query then return nil, "No query provided" end @@ -518,8 +500,7 @@ function M.rag_search(opts, on_log) return vim.json.encode(resp), nil end ----@param opts { code: string, rel_path: string, container_image?: string } ----@type AvanteLLMToolFunc +---@type AvanteLLMToolFunc<{ code: string, rel_path: string, container_image?: string }> function M.python(opts, on_log) local abs_path = get_abs_path(opts.rel_path) if not has_permission_to_access(abs_path) then return nil, "No permission to access path: " .. abs_path end diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua index 2001483..7049ee0 100644 --- a/lua/avante/sidebar.lua +++ b/lua/avante/sidebar.lua @@ -2291,8 +2291,7 @@ function Sidebar:create_input_container(opts) table.insert(tools, { name = "add_file_to_context", description = "Add a file to the context", - ---@param input { rel_path: string } - ---@type AvanteLLMToolFunc + ---@type AvanteLLMToolFunc<{ rel_path: string }> func = function(input) self.file_selector:add_selected_file(input.rel_path) return "Added file to context", nil diff --git a/lua/avante/types.lua b/lua/avante/types.lua index 79ae650..04ed5dc 100644 --- a/lua/avante/types.lua +++ b/lua/avante/types.lua @@ -323,7 +323,7 @@ vim.g.avante_login = vim.g.avante_login ---@field on_stop AvanteLLMStopCallback ---@field on_tool_log? function(tool_name: string, log: string): nil --- ----@alias AvanteLLMToolFunc fun(input: any, on_log?: (fun(log: string): nil) | nil): (boolean | string | nil, string | nil) +---@alias AvanteLLMToolFunc fun(input: T, on_log?: (fun(log: string): nil) | nil): (boolean | string | nil, string | nil) --- ---@class AvanteLLMTool ---@field name string