refactor: llm tool parameters (#2449)

This commit is contained in:
yetone
2025-07-15 16:40:25 +08:00
committed by GitHub
parent 0c6a8f5688
commit b8bb0fd969
25 changed files with 627 additions and 381 deletions

View File

@@ -8,9 +8,10 @@ local Helpers = require("avante.llm_tools.helpers")
local M = {}
---@type AvanteLLMToolFunc<{ path: string }>
function M.read_file_toplevel_symbols(opts, on_log, on_complete, session_ctx)
function M.read_file_toplevel_symbols(input, opts)
local on_log = opts.on_log
local RepoMap = require("avante.repo_map")
local abs_path = Helpers.get_abs_path(opts.path)
local abs_path = Helpers.get_abs_path(input.path)
if not Helpers.has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end
if on_log then on_log("path: " .. abs_path) end
if not Path:new(abs_path):exists() then return "", "File does not exists: " .. abs_path end
@@ -24,50 +25,47 @@ function M.read_file_toplevel_symbols(opts, on_log, on_complete, session_ctx)
end
---@type AvanteLLMToolFunc<{ command: "view" | "str_replace" | "create" | "insert" | "undo_edit", path: string, old_str?: string, new_str?: string, file_text?: string, insert_line?: integer, new_str?: string, view_range?: integer[] }>
function M.str_replace_editor(opts, on_log, on_complete, session_ctx)
if opts.command == "undo_edit" then
return require("avante.llm_tools.undo_edit").func(opts, on_log, on_complete, session_ctx)
end
---@cast opts any
return M.str_replace_based_edit_tool(opts, on_log, on_complete, session_ctx)
function M.str_replace_editor(input, opts)
if input.command == "undo_edit" then return require("avante.llm_tools.undo_edit").func(input, opts) end
---@cast input any
return M.str_replace_based_edit_tool(input, opts)
end
---@type AvanteLLMToolFunc<{ command: "view" | "str_replace" | "create" | "insert", path: string, old_str?: string, new_str?: string, file_text?: string, insert_line?: integer, new_str?: string, view_range?: integer[], streaming?: boolean }>
function M.str_replace_based_edit_tool(opts, on_log, on_complete, session_ctx)
if not opts.command then return false, "command not provided" end
if on_log then on_log("command: " .. opts.command) end
function M.str_replace_based_edit_tool(input, opts)
local on_log = opts.on_log
local on_complete = opts.on_complete
if not input.command then return false, "command not provided" end
if on_log then on_log("command: " .. input.command) end
if not on_complete then return false, "on_complete not provided" end
local abs_path = Helpers.get_abs_path(opts.path)
local abs_path = Helpers.get_abs_path(input.path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if opts.command == "view" then
if input.command == "view" then
local view = require("avante.llm_tools.view")
local opts_ = { path = opts.path }
if opts.view_range then
local start_line, end_line = unpack(opts.view_range)
opts_.start_line = start_line
opts_.end_line = end_line
local input_ = { path = input.path }
if input.view_range then
local start_line, end_line = unpack(input.view_range)
input_.start_line = start_line
input_.end_line = end_line
end
return view(opts_, on_log, on_complete, session_ctx)
return view(input_, opts)
end
if opts.command == "str_replace" then
if opts.new_str == nil and opts.file_text ~= nil then
opts.new_str = opts.file_text
opts.file_text = nil
if input.command == "str_replace" then
if input.new_str == nil and input.file_text ~= nil then
input.new_str = input.file_text
input.file_text = nil
end
return require("avante.llm_tools.str_replace").func(opts, on_log, on_complete, session_ctx)
return require("avante.llm_tools.str_replace").func(input, opts)
end
if opts.command == "create" then
return require("avante.llm_tools.create").func(opts, on_log, on_complete, session_ctx)
end
if opts.command == "insert" then
return require("avante.llm_tools.insert").func(opts, on_log, on_complete, session_ctx)
end
return false, "Unknown command: " .. opts.command
if input.command == "create" then return require("avante.llm_tools.create").func(input, opts) end
if input.command == "insert" then return require("avante.llm_tools.insert").func(input, opts) end
return false, "Unknown command: " .. input.command
end
---@type AvanteLLMToolFunc<{ abs_path: string }>
function M.read_global_file(opts, on_log, on_complete, session_ctx)
local abs_path = Helpers.get_abs_path(opts.abs_path)
function M.read_global_file(input, opts)
local on_log = opts.on_log
local abs_path = Helpers.get_abs_path(input.abs_path)
if Helpers.is_ignored(abs_path) then return "", "This file is ignored: " .. abs_path end
if on_log then on_log("path: " .. abs_path) end
local file = io.open(abs_path, "r")
@@ -78,11 +76,13 @@ function M.read_global_file(opts, on_log, on_complete, session_ctx)
end
---@type AvanteLLMToolFunc<{ abs_path: string, content: string }>
function M.write_global_file(opts, on_log, on_complete, session_ctx)
local abs_path = Helpers.get_abs_path(opts.abs_path)
function M.write_global_file(input, opts)
local on_log = opts.on_log
local on_complete = opts.on_complete
local abs_path = Helpers.get_abs_path(input.abs_path)
if Helpers.is_ignored(abs_path) then return false, "This file is ignored: " .. abs_path end
if on_log then on_log("path: " .. abs_path) end
if on_log then on_log("content: " .. opts.content) end
if on_log then on_log("content: " .. input.content) end
if not on_complete then return false, "on_complete not provided" end
Helpers.confirm("Are you sure you want to write to the file: " .. abs_path, function(ok)
if not ok then
@@ -94,18 +94,20 @@ function M.write_global_file(opts, on_log, on_complete, session_ctx)
on_complete(false, "file not found: " .. abs_path)
return
end
file:write(opts.content)
file:write(input.content)
file:close()
on_complete(true, nil)
end, nil, session_ctx, "write_global_file")
end, nil, opts.session_ctx, "write_global_file")
end
---@type AvanteLLMToolFunc<{ source_path: string, destination_path: string }>
function M.move_path(opts, on_log, on_complete, session_ctx)
local abs_path = Helpers.get_abs_path(opts.source_path)
function M.move_path(input, opts)
local on_log = opts.on_log
local on_complete = opts.on_complete
local abs_path = Helpers.get_abs_path(input.source_path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "The source path not found: " .. abs_path end
local new_abs_path = Helpers.get_abs_path(opts.destination_path)
local new_abs_path = Helpers.get_abs_path(input.destination_path)
if on_log then on_log(abs_path .. " -> " .. new_abs_path) end
if not Helpers.has_permission_to_access(new_abs_path) then
return false, "No permission to access path: " .. new_abs_path
@@ -123,17 +125,19 @@ function M.move_path(opts, on_log, on_complete, session_ctx)
on_complete(true, nil)
end,
nil,
session_ctx,
opts.session_ctx,
"move_path"
)
end
---@type AvanteLLMToolFunc<{ source_path: string, destination_path: string }>
function M.copy_path(opts, on_log, on_complete, session_ctx)
local abs_path = Helpers.get_abs_path(opts.source_path)
function M.copy_path(input, opts)
local on_log = opts.on_log
local on_complete = opts.on_complete
local abs_path = Helpers.get_abs_path(input.source_path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "The source path not found: " .. abs_path end
local new_abs_path = Helpers.get_abs_path(opts.destination_path)
local new_abs_path = Helpers.get_abs_path(input.destination_path)
if not Helpers.has_permission_to_access(new_abs_path) then
return false, "No permission to access path: " .. new_abs_path
end
@@ -169,14 +173,16 @@ function M.copy_path(opts, on_log, on_complete, session_ctx)
on_complete(true, nil)
end,
nil,
session_ctx,
opts.session_ctx,
"copy_path"
)
end
---@type AvanteLLMToolFunc<{ path: string }>
function M.delete_path(opts, on_log, on_complete, session_ctx)
local abs_path = Helpers.get_abs_path(opts.path)
function M.delete_path(input, opts)
local on_log = opts.on_log
local on_complete = opts.on_complete
local abs_path = Helpers.get_abs_path(input.path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "Path not found: " .. abs_path end
if not on_complete then return false, "on_complete not provided" end
@@ -188,12 +194,14 @@ function M.delete_path(opts, on_log, on_complete, session_ctx)
if on_log then on_log("Deleting path: " .. abs_path) end
os.remove(abs_path)
on_complete(true, nil)
end, nil, session_ctx, "delete_path")
end, nil, opts.session_ctx, "delete_path")
end
---@type AvanteLLMToolFunc<{ path: string }>
function M.create_dir(opts, on_log, on_complete, session_ctx)
local abs_path = Helpers.get_abs_path(opts.path)
function M.create_dir(input, opts)
local on_log = opts.on_log
local on_complete = opts.on_complete
local abs_path = Helpers.get_abs_path(input.path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if Path:new(abs_path):exists() then return false, "Directory already exists: " .. abs_path end
if not on_complete then return false, "on_complete not provided" end
@@ -205,16 +213,17 @@ function M.create_dir(opts, on_log, on_complete, session_ctx)
if on_log then on_log("Creating directory: " .. abs_path) end
Path:new(abs_path):mkdir({ parents = true })
on_complete(true, nil)
end, nil, session_ctx, "create_dir")
end, nil, opts.session_ctx, "create_dir")
end
---@type AvanteLLMToolFunc<{ query: string }>
function M.web_search(opts, on_log, on_complete, session_ctx)
function M.web_search(input, opts)
local on_log = opts.on_log
local provider_type = Config.web_search_engine.provider
local proxy = Config.web_search_engine.proxy
if provider_type == nil then return nil, "Search engine provider is not set" end
if on_log then on_log("provider: " .. provider_type) end
if on_log then on_log("query: " .. opts.query) end
if on_log then on_log("query: " .. input.query) end
local search_engine = Config.web_search_engine.providers[provider_type]
if search_engine == nil then return nil, "No search engine found: " .. provider_type end
if provider_type ~= "searxng" and search_engine.api_key_name == "" then return nil, "No API key provided" end
@@ -229,7 +238,7 @@ function M.web_search(opts, on_log, on_complete, session_ctx)
["Authorization"] = "Bearer " .. api_key,
},
body = vim.json.encode(vim.tbl_deep_extend("force", {
query = opts.query,
query = input.query,
}, search_engine.extra_request_body)),
}
if proxy then curl_opts.proxy = proxy end
@@ -240,7 +249,7 @@ function M.web_search(opts, on_log, on_complete, session_ctx)
elseif provider_type == "serpapi" then
local query_params = vim.tbl_deep_extend("force", {
api_key = api_key,
q = opts.query,
q = input.query,
}, search_engine.extra_request_body)
local query_string = ""
for key, value in pairs(query_params) do
@@ -259,7 +268,7 @@ function M.web_search(opts, on_log, on_complete, session_ctx)
elseif provider_type == "searchapi" then
local query_params = vim.tbl_deep_extend("force", {
api_key = api_key,
q = opts.query,
q = input.query,
}, search_engine.extra_request_body)
local query_string = ""
for key, value in pairs(query_params) do
@@ -283,7 +292,7 @@ function M.web_search(opts, on_log, on_complete, session_ctx)
local query_params = vim.tbl_deep_extend("force", {
key = api_key,
cx = engine_id,
q = opts.query,
q = input.query,
}, search_engine.extra_request_body)
local query_string = ""
for key, value in pairs(query_params) do
@@ -301,7 +310,7 @@ function M.web_search(opts, on_log, on_complete, session_ctx)
return search_engine.format_response_body(jsn)
elseif provider_type == "kagi" then
local query_params = vim.tbl_deep_extend("force", {
q = opts.query,
q = input.query,
}, search_engine.extra_request_body)
local query_string = ""
for key, value in pairs(query_params) do
@@ -320,7 +329,7 @@ function M.web_search(opts, on_log, on_complete, session_ctx)
return search_engine.format_response_body(jsn)
elseif provider_type == "brave" then
local query_params = vim.tbl_deep_extend("force", {
q = opts.query,
q = input.query,
}, search_engine.extra_request_body)
local query_string = ""
for key, value in pairs(query_params) do
@@ -343,7 +352,7 @@ function M.web_search(opts, on_log, on_complete, session_ctx)
return nil, "Environment variable " .. search_engine.api_url_name .. " is not set"
end
local query_params = vim.tbl_deep_extend("force", {
q = opts.query,
q = input.query,
}, search_engine.extra_request_body)
local query_string = ""
for key, value in pairs(query_params) do
@@ -362,16 +371,18 @@ function M.web_search(opts, on_log, on_complete, session_ctx)
end
---@type AvanteLLMToolFunc<{ url: string }>
function M.fetch(opts, on_log, on_complete, session_ctx)
if on_log then on_log("url: " .. opts.url) end
function M.fetch(input, opts)
local on_log = opts.on_log
if on_log then on_log("url: " .. input.url) end
local Html2Md = require("avante.html2md")
local res, err = Html2Md.fetch_md(opts.url)
local res, err = Html2Md.fetch_md(input.url)
if err then return nil, err end
return res, nil
end
---@type AvanteLLMToolFunc<{ scope?: string }>
function M.git_diff(opts, on_log, on_complete, session_ctx)
function M.git_diff(input, opts)
local on_log = opts.on_log
local git_cmd = vim.fn.exepath("git")
if git_cmd == "" then return nil, "Git command not found" end
local project_root = Utils.get_project_root()
@@ -382,7 +393,7 @@ function M.git_diff(opts, on_log, on_complete, session_ctx)
if git_dir == "" then return nil, "Not a git repository" end
-- Get the diff
local scope = opts.scope or ""
local scope = input.scope or ""
local cmd = string.format("git diff --cached %s", scope)
if on_log then on_log("Running command: " .. cmd) end
local diff = vim.fn.system(cmd)
@@ -400,7 +411,10 @@ function M.git_diff(opts, on_log, on_complete, session_ctx)
end
---@type AvanteLLMToolFunc<{ message: string, scope?: string }>
function M.git_commit(opts, on_log, on_complete, session_ctx)
function M.git_commit(input, opts)
local on_log = opts.on_log
local on_complete = opts.on_complete
local git_cmd = vim.fn.exepath("git")
if git_cmd == "" then return false, "Git command not found" end
local project_root = Utils.get_project_root()
@@ -444,7 +458,7 @@ function M.git_commit(opts, on_log, on_complete, session_ctx)
-- Prepare commit message
local commit_msg_lines = {}
for line in opts.message:gmatch("[^\r\n]+") do
for line in input.message:gmatch("[^\r\n]+") do
commit_msg_lines[#commit_msg_lines + 1] = line:gsub('"', '\\"')
end
commit_msg_lines[#commit_msg_lines + 1] = ""
@@ -464,8 +478,8 @@ function M.git_commit(opts, on_log, on_complete, session_ctx)
return
end
-- Stage changes if scope is provided
if opts.scope then
local stage_cmd = string.format("git add %s", opts.scope)
if input.scope then
local stage_cmd = string.format("git add %s", input.scope)
if on_log then on_log("Staging files: " .. stage_cmd) end
local stage_result = vim.fn.system(stage_cmd)
if vim.v.shell_error ~= 0 then
@@ -494,20 +508,23 @@ function M.git_commit(opts, on_log, on_complete, session_ctx)
end
on_complete(true, nil)
end, nil, session_ctx, "git_commit")
end, nil, opts.session_ctx, "git_commit")
end
---@type AvanteLLMToolFunc<{ query: string }>
function M.rag_search(opts, on_log, on_complete, session_ctx)
function M.rag_search(input, opts)
local on_log = opts.on_log
local on_complete = opts.on_complete
if not on_complete then return nil, "on_complete not provided" end
if not Config.rag_service.enabled then return nil, "Rag service is not enabled" end
if not opts.query then return nil, "No query provided" end
if on_log then on_log("query: " .. opts.query) end
if not input.query then return nil, "No query provided" end
if on_log then on_log("query: " .. input.query) end
local root = Utils.get_project_root()
local uri = "file://" .. root
if uri:sub(-1) ~= "/" then uri = uri .. "/" end
RagService.retrieve(
uri,
opts.query,
input.query,
vim.schedule_wrap(function(resp, err)
if err then
on_complete(nil, err)
@@ -519,13 +536,15 @@ function M.rag_search(opts, on_log, on_complete, session_ctx)
end
---@type AvanteLLMToolFunc<{ code: string, path: string, container_image?: string }>
function M.python(opts, on_log, on_complete, session_ctx)
local abs_path = Helpers.get_abs_path(opts.path)
function M.python(input, opts)
local on_log = opts.on_log
local on_complete = opts.on_complete
local abs_path = Helpers.get_abs_path(input.path)
if not Helpers.has_permission_to_access(abs_path) then return nil, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return nil, "Path not found: " .. abs_path end
if on_log then on_log("cwd: " .. abs_path) end
if on_log then on_log("code:\n" .. opts.code) end
local container_image = opts.container_image or "python:3.11-slim-bookworm"
if on_log then on_log("code:\n" .. input.code) end
local container_image = input.container_image or "python:3.11-slim-bookworm"
if not on_complete then return nil, "on_complete not provided" end
Helpers.confirm(
"Are you sure you want to run the following python code in the `"
@@ -533,7 +552,7 @@ function M.python(opts, on_log, on_complete, session_ctx)
.. "` container, in the directory: `"
.. abs_path
.. "`?\n"
.. opts.code,
.. input.code,
function(ok, reason)
if not ok then
on_complete(nil, "User declined, reason: " .. (reason or "unknown"))
@@ -562,7 +581,7 @@ function M.python(opts, on_log, on_complete, session_ctx)
container_image,
"python",
"-c",
opts.code,
input.code,
},
{
text = true,
@@ -576,7 +595,7 @@ function M.python(opts, on_log, on_complete, session_ctx)
)
end,
nil,
session_ctx,
opts.session_ctx,
"python"
)
end
@@ -1189,14 +1208,19 @@ You can delete the first file by providing a path of "directory1/a/something.txt
--- compatibility alias for old calls & tests
M.run_python = M.python
---@class avante.ProcessToolUseOpts
---@field session_ctx table
---@field on_log? fun(tool_id: string, tool_name: string, log: string, state: AvanteLLMToolUseState): nil
---@field set_tool_use_store? fun(tool_id: string, key: string, value: any): nil
---@field on_complete? fun(result: string | nil, error: string | nil): nil
---@param tools AvanteLLMTool[]
---@param tool_use AvanteLLMToolUse
---@param on_log? fun(tool_id: string, tool_name: string, log: string, state: AvanteLLMToolUseState): nil
---@param on_complete? fun(result: string | nil, error: string | nil): nil
---@param session_ctx? table
---@return string | nil result
---@return string | nil error
function M.process_tool_use(tools, tool_use, on_log, on_complete, session_ctx)
function M.process_tool_use(tools, tool_use, opts)
local on_log = opts.on_log
local on_complete = opts.on_complete
-- Check if execution is already cancelled
if Helpers.is_cancelled then
Utils.debug("Tool execution cancelled before starting: " .. tool_use.name)
@@ -1274,25 +1298,32 @@ function M.process_tool_use(tools, tool_use, on_log, on_complete, session_ctx)
return result_str, err
end
local result, err = func(input_json, function(log)
-- Check for cancellation during logging
if Helpers.is_cancelled then return end
if on_log then on_log(tool_use.id, tool_use.name, log, "running") end
end, function(result, err)
-- Check for cancellation before completing
if Helpers.is_cancelled then
Helpers.is_cancelled = false
if on_complete then on_complete(nil, Helpers.CANCEL_TOKEN) end
return
end
local result, err = func(input_json, {
session_ctx = opts.session_ctx or {},
on_log = function(log)
-- Check for cancellation during logging
if Helpers.is_cancelled then return end
if on_log then on_log(tool_use.id, tool_use.name, log, "running") end
end,
set_store = function(key, value)
if opts.set_tool_use_store then opts.set_tool_use_store(tool_use.id, key, value) end
end,
on_complete = function(result, err)
-- Check for cancellation before completing
if Helpers.is_cancelled then
Helpers.is_cancelled = false
if on_complete then on_complete(nil, Helpers.CANCEL_TOKEN) end
return
end
result, err = handle_result(result, err)
if on_complete == nil then
Utils.error("asynchronous tool " .. tool_use.name .. " result not handled")
return
end
on_complete(result, err)
end, session_ctx)
result, err = handle_result(result, err)
if on_complete == nil then
Utils.error("asynchronous tool " .. tool_use.name .. " result not handled")
return
end
on_complete(result, err)
end,
})
-- Result and error being nil means that the tool was executed asynchronously
if result == nil and err == nil and on_complete then return end