feat: add generic type for llm func (#1373)

This commit is contained in:
yetone
2025-02-24 21:54:36 +08:00
committed by GitHub
parent fe496a9573
commit b627b335dd
3 changed files with 21 additions and 41 deletions

View File

@@ -33,8 +33,7 @@ local function has_permission_to_access(abs_path)
return not Utils.is_ignored(abs_path, gitignore_patterns, gitignore_negate_patterns)
end
---@param opts { rel_path: string, max_depth?: integer }
---@type AvanteLLMToolFunc
---@type AvanteLLMToolFunc<{ rel_path: string, max_depth?: integer }>
function M.list_files(opts, on_log)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end
@@ -53,8 +52,7 @@ function M.list_files(opts, on_log)
return vim.json.encode(filepaths), nil
end
---@param opts { rel_path: string, keyword: string }
---@type AvanteLLMToolFunc
---@type AvanteLLMToolFunc<{ rel_path: string, keyword: string }>
function M.search_files(opts, on_log)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end
@@ -70,8 +68,7 @@ function M.search_files(opts, on_log)
return vim.json.encode(filepaths), nil
end
---@param opts { rel_path: string, keyword: string }
---@type AvanteLLMToolFunc
---@type AvanteLLMToolFunc<{ rel_path: string, keyword: string }>
function M.search(opts, on_log)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end
@@ -107,8 +104,7 @@ function M.search(opts, on_log)
return vim.json.encode(filepaths), nil
end
---@param opts { rel_path: string }
---@type AvanteLLMToolFunc
---@type AvanteLLMToolFunc<{ rel_path: string }>
function M.read_file_toplevel_symbols(opts, on_log)
local RepoMap = require("avante.repo_map")
local abs_path = get_abs_path(opts.rel_path)
@@ -124,8 +120,7 @@ function M.read_file_toplevel_symbols(opts, on_log)
return definitions, nil
end
---@param opts { rel_path: string }
---@type AvanteLLMToolFunc
---@type AvanteLLMToolFunc<{ rel_path: string }>
function M.read_file(opts, on_log)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end
@@ -137,8 +132,7 @@ function M.read_file(opts, on_log)
return content, nil
end
---@param opts { rel_path: string }
---@type AvanteLLMToolFunc
---@type AvanteLLMToolFunc<{ rel_path: string }>
function M.create_file(opts, on_log)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
@@ -156,8 +150,7 @@ function M.create_file(opts, on_log)
return true, nil
end
---@param opts { rel_path: string, new_rel_path: string }
---@type AvanteLLMToolFunc
---@type AvanteLLMToolFunc<{ rel_path: string, new_rel_path: string }>
function M.rename_file(opts, on_log)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
@@ -174,8 +167,7 @@ function M.rename_file(opts, on_log)
return true, nil
end
---@param opts { rel_path: string, new_rel_path: string }
---@type AvanteLLMToolFunc
---@type AvanteLLMToolFunc<{ rel_path: string, new_rel_path: string }>
function M.copy_file(opts, on_log)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
@@ -189,8 +181,7 @@ function M.copy_file(opts, on_log)
return true, nil
end
---@param opts { rel_path: string }
---@type AvanteLLMToolFunc
---@type AvanteLLMToolFunc<{ rel_path: string }>
function M.delete_file(opts, on_log)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
@@ -202,8 +193,7 @@ function M.delete_file(opts, on_log)
return true, nil
end
---@param opts { rel_path: string }
---@type AvanteLLMToolFunc
---@type AvanteLLMToolFunc<{ rel_path: string }>
function M.create_dir(opts, on_log)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
@@ -216,8 +206,7 @@ function M.create_dir(opts, on_log)
return true, nil
end
---@param opts { rel_path: string, new_rel_path: string }
---@type AvanteLLMToolFunc
---@type AvanteLLMToolFunc<{ rel_path: string, new_rel_path: string }>
function M.rename_dir(opts, on_log)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
@@ -234,8 +223,7 @@ function M.rename_dir(opts, on_log)
return true, nil
end
---@param opts { rel_path: string }
---@type AvanteLLMToolFunc
---@type AvanteLLMToolFunc<{ rel_path: string }>
function M.delete_dir(opts, on_log)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
@@ -249,8 +237,7 @@ function M.delete_dir(opts, on_log)
return true, nil
end
---@param opts { rel_path: string, command: string }
---@type AvanteLLMToolFunc
---@type AvanteLLMToolFunc<{ rel_path: string, command: string }>
function M.run_command(opts, on_log)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
@@ -273,8 +260,7 @@ function M.run_command(opts, on_log)
return res.stdout, nil
end
---@param opts { query: string }
---@type AvanteLLMToolFunc
---@type AvanteLLMToolFunc<{ query: string }>
function M.web_search(opts, on_log)
local provider_type = Config.web_search_engine.provider
if provider_type == nil then return nil, "Search engine provider is not set" end
@@ -376,8 +362,7 @@ function M.web_search(opts, on_log)
end
end
---@param opts { url: string }
---@type AvanteLLMToolFunc
---@type AvanteLLMToolFunc<{ url: string }>
function M.fetch(opts, on_log)
if on_log then on_log("url: " .. opts.url) end
local Html2Md = require("avante.html2md")
@@ -386,8 +371,7 @@ function M.fetch(opts, on_log)
return res, nil
end
---@param opts { scope?: string }
---@type AvanteLLMToolFunc
---@type AvanteLLMToolFunc<{ scope?: string }>
function M.git_diff(opts, on_log)
local git_cmd = vim.fn.exepath("git")
if git_cmd == "" then return nil, "Git command not found" end
@@ -416,8 +400,7 @@ function M.git_diff(opts, on_log)
return diff, nil
end
---@param opts { message: string, scope?: string }
---@type AvanteLLMToolFunc
---@type AvanteLLMToolFunc<{ message: string, scope?: string }>
function M.git_commit(opts, on_log)
local git_cmd = vim.fn.exepath("git")
if git_cmd == "" then return false, "Git command not found" end
@@ -504,8 +487,7 @@ function M.git_commit(opts, on_log)
return true, nil
end
---@param opts { query: string }
---@type AvanteLLMToolFunc
---@type AvanteLLMToolFunc<{ query: string }>
function M.rag_search(opts, on_log)
if not Config.rag_service.enabled then return nil, "Rag service is not enabled" end
if not opts.query then return nil, "No query provided" end
@@ -518,8 +500,7 @@ function M.rag_search(opts, on_log)
return vim.json.encode(resp), nil
end
---@param opts { code: string, rel_path: string, container_image?: string }
---@type AvanteLLMToolFunc
---@type AvanteLLMToolFunc<{ code: string, rel_path: string, container_image?: string }>
function M.python(opts, on_log)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return nil, "No permission to access path: " .. abs_path end

View File

@@ -2291,8 +2291,7 @@ function Sidebar:create_input_container(opts)
table.insert(tools, {
name = "add_file_to_context",
description = "Add a file to the context",
---@param input { rel_path: string }
---@type AvanteLLMToolFunc
---@type AvanteLLMToolFunc<{ rel_path: string }>
func = function(input)
self.file_selector:add_selected_file(input.rel_path)
return "Added file to context", nil

View File

@@ -323,7 +323,7 @@ vim.g.avante_login = vim.g.avante_login
---@field on_stop AvanteLLMStopCallback
---@field on_tool_log? function(tool_name: string, log: string): nil
---
---@alias AvanteLLMToolFunc fun(input: any, on_log?: (fun(log: string): nil) | nil): (boolean | string | nil, string | nil)
---@alias AvanteLLMToolFunc<T> fun(input: T, on_log?: (fun(log: string): nil) | nil): (boolean | string | nil, string | nil)
---
---@class AvanteLLMTool
---@field name string