refactor: llm tool parameters (#2449)
This commit is contained in:
@@ -8,9 +8,10 @@ local Helpers = require("avante.llm_tools.helpers")
|
||||
local M = {}
|
||||
|
||||
---@type AvanteLLMToolFunc<{ path: string }>
|
||||
function M.read_file_toplevel_symbols(opts, on_log, on_complete, session_ctx)
|
||||
function M.read_file_toplevel_symbols(input, opts)
|
||||
local on_log = opts.on_log
|
||||
local RepoMap = require("avante.repo_map")
|
||||
local abs_path = Helpers.get_abs_path(opts.path)
|
||||
local abs_path = Helpers.get_abs_path(input.path)
|
||||
if not Helpers.has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end
|
||||
if on_log then on_log("path: " .. abs_path) end
|
||||
if not Path:new(abs_path):exists() then return "", "File does not exists: " .. abs_path end
|
||||
@@ -24,50 +25,47 @@ function M.read_file_toplevel_symbols(opts, on_log, on_complete, session_ctx)
|
||||
end
|
||||
|
||||
---@type AvanteLLMToolFunc<{ command: "view" | "str_replace" | "create" | "insert" | "undo_edit", path: string, old_str?: string, new_str?: string, file_text?: string, insert_line?: integer, new_str?: string, view_range?: integer[] }>
|
||||
function M.str_replace_editor(opts, on_log, on_complete, session_ctx)
|
||||
if opts.command == "undo_edit" then
|
||||
return require("avante.llm_tools.undo_edit").func(opts, on_log, on_complete, session_ctx)
|
||||
end
|
||||
---@cast opts any
|
||||
return M.str_replace_based_edit_tool(opts, on_log, on_complete, session_ctx)
|
||||
function M.str_replace_editor(input, opts)
|
||||
if input.command == "undo_edit" then return require("avante.llm_tools.undo_edit").func(input, opts) end
|
||||
---@cast input any
|
||||
return M.str_replace_based_edit_tool(input, opts)
|
||||
end
|
||||
|
||||
---@type AvanteLLMToolFunc<{ command: "view" | "str_replace" | "create" | "insert", path: string, old_str?: string, new_str?: string, file_text?: string, insert_line?: integer, new_str?: string, view_range?: integer[], streaming?: boolean }>
|
||||
function M.str_replace_based_edit_tool(opts, on_log, on_complete, session_ctx)
|
||||
if not opts.command then return false, "command not provided" end
|
||||
if on_log then on_log("command: " .. opts.command) end
|
||||
function M.str_replace_based_edit_tool(input, opts)
|
||||
local on_log = opts.on_log
|
||||
local on_complete = opts.on_complete
|
||||
if not input.command then return false, "command not provided" end
|
||||
if on_log then on_log("command: " .. input.command) end
|
||||
if not on_complete then return false, "on_complete not provided" end
|
||||
local abs_path = Helpers.get_abs_path(opts.path)
|
||||
local abs_path = Helpers.get_abs_path(input.path)
|
||||
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||
if opts.command == "view" then
|
||||
if input.command == "view" then
|
||||
local view = require("avante.llm_tools.view")
|
||||
local opts_ = { path = opts.path }
|
||||
if opts.view_range then
|
||||
local start_line, end_line = unpack(opts.view_range)
|
||||
opts_.start_line = start_line
|
||||
opts_.end_line = end_line
|
||||
local input_ = { path = input.path }
|
||||
if input.view_range then
|
||||
local start_line, end_line = unpack(input.view_range)
|
||||
input_.start_line = start_line
|
||||
input_.end_line = end_line
|
||||
end
|
||||
return view(opts_, on_log, on_complete, session_ctx)
|
||||
return view(input_, opts)
|
||||
end
|
||||
if opts.command == "str_replace" then
|
||||
if opts.new_str == nil and opts.file_text ~= nil then
|
||||
opts.new_str = opts.file_text
|
||||
opts.file_text = nil
|
||||
if input.command == "str_replace" then
|
||||
if input.new_str == nil and input.file_text ~= nil then
|
||||
input.new_str = input.file_text
|
||||
input.file_text = nil
|
||||
end
|
||||
return require("avante.llm_tools.str_replace").func(opts, on_log, on_complete, session_ctx)
|
||||
return require("avante.llm_tools.str_replace").func(input, opts)
|
||||
end
|
||||
if opts.command == "create" then
|
||||
return require("avante.llm_tools.create").func(opts, on_log, on_complete, session_ctx)
|
||||
end
|
||||
if opts.command == "insert" then
|
||||
return require("avante.llm_tools.insert").func(opts, on_log, on_complete, session_ctx)
|
||||
end
|
||||
return false, "Unknown command: " .. opts.command
|
||||
if input.command == "create" then return require("avante.llm_tools.create").func(input, opts) end
|
||||
if input.command == "insert" then return require("avante.llm_tools.insert").func(input, opts) end
|
||||
return false, "Unknown command: " .. input.command
|
||||
end
|
||||
|
||||
---@type AvanteLLMToolFunc<{ abs_path: string }>
|
||||
function M.read_global_file(opts, on_log, on_complete, session_ctx)
|
||||
local abs_path = Helpers.get_abs_path(opts.abs_path)
|
||||
function M.read_global_file(input, opts)
|
||||
local on_log = opts.on_log
|
||||
local abs_path = Helpers.get_abs_path(input.abs_path)
|
||||
if Helpers.is_ignored(abs_path) then return "", "This file is ignored: " .. abs_path end
|
||||
if on_log then on_log("path: " .. abs_path) end
|
||||
local file = io.open(abs_path, "r")
|
||||
@@ -78,11 +76,13 @@ function M.read_global_file(opts, on_log, on_complete, session_ctx)
|
||||
end
|
||||
|
||||
---@type AvanteLLMToolFunc<{ abs_path: string, content: string }>
|
||||
function M.write_global_file(opts, on_log, on_complete, session_ctx)
|
||||
local abs_path = Helpers.get_abs_path(opts.abs_path)
|
||||
function M.write_global_file(input, opts)
|
||||
local on_log = opts.on_log
|
||||
local on_complete = opts.on_complete
|
||||
local abs_path = Helpers.get_abs_path(input.abs_path)
|
||||
if Helpers.is_ignored(abs_path) then return false, "This file is ignored: " .. abs_path end
|
||||
if on_log then on_log("path: " .. abs_path) end
|
||||
if on_log then on_log("content: " .. opts.content) end
|
||||
if on_log then on_log("content: " .. input.content) end
|
||||
if not on_complete then return false, "on_complete not provided" end
|
||||
Helpers.confirm("Are you sure you want to write to the file: " .. abs_path, function(ok)
|
||||
if not ok then
|
||||
@@ -94,18 +94,20 @@ function M.write_global_file(opts, on_log, on_complete, session_ctx)
|
||||
on_complete(false, "file not found: " .. abs_path)
|
||||
return
|
||||
end
|
||||
file:write(opts.content)
|
||||
file:write(input.content)
|
||||
file:close()
|
||||
on_complete(true, nil)
|
||||
end, nil, session_ctx, "write_global_file")
|
||||
end, nil, opts.session_ctx, "write_global_file")
|
||||
end
|
||||
|
||||
---@type AvanteLLMToolFunc<{ source_path: string, destination_path: string }>
|
||||
function M.move_path(opts, on_log, on_complete, session_ctx)
|
||||
local abs_path = Helpers.get_abs_path(opts.source_path)
|
||||
function M.move_path(input, opts)
|
||||
local on_log = opts.on_log
|
||||
local on_complete = opts.on_complete
|
||||
local abs_path = Helpers.get_abs_path(input.source_path)
|
||||
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||
if not Path:new(abs_path):exists() then return false, "The source path not found: " .. abs_path end
|
||||
local new_abs_path = Helpers.get_abs_path(opts.destination_path)
|
||||
local new_abs_path = Helpers.get_abs_path(input.destination_path)
|
||||
if on_log then on_log(abs_path .. " -> " .. new_abs_path) end
|
||||
if not Helpers.has_permission_to_access(new_abs_path) then
|
||||
return false, "No permission to access path: " .. new_abs_path
|
||||
@@ -123,17 +125,19 @@ function M.move_path(opts, on_log, on_complete, session_ctx)
|
||||
on_complete(true, nil)
|
||||
end,
|
||||
nil,
|
||||
session_ctx,
|
||||
opts.session_ctx,
|
||||
"move_path"
|
||||
)
|
||||
end
|
||||
|
||||
---@type AvanteLLMToolFunc<{ source_path: string, destination_path: string }>
|
||||
function M.copy_path(opts, on_log, on_complete, session_ctx)
|
||||
local abs_path = Helpers.get_abs_path(opts.source_path)
|
||||
function M.copy_path(input, opts)
|
||||
local on_log = opts.on_log
|
||||
local on_complete = opts.on_complete
|
||||
local abs_path = Helpers.get_abs_path(input.source_path)
|
||||
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||
if not Path:new(abs_path):exists() then return false, "The source path not found: " .. abs_path end
|
||||
local new_abs_path = Helpers.get_abs_path(opts.destination_path)
|
||||
local new_abs_path = Helpers.get_abs_path(input.destination_path)
|
||||
if not Helpers.has_permission_to_access(new_abs_path) then
|
||||
return false, "No permission to access path: " .. new_abs_path
|
||||
end
|
||||
@@ -169,14 +173,16 @@ function M.copy_path(opts, on_log, on_complete, session_ctx)
|
||||
on_complete(true, nil)
|
||||
end,
|
||||
nil,
|
||||
session_ctx,
|
||||
opts.session_ctx,
|
||||
"copy_path"
|
||||
)
|
||||
end
|
||||
|
||||
---@type AvanteLLMToolFunc<{ path: string }>
|
||||
function M.delete_path(opts, on_log, on_complete, session_ctx)
|
||||
local abs_path = Helpers.get_abs_path(opts.path)
|
||||
function M.delete_path(input, opts)
|
||||
local on_log = opts.on_log
|
||||
local on_complete = opts.on_complete
|
||||
local abs_path = Helpers.get_abs_path(input.path)
|
||||
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||
if not Path:new(abs_path):exists() then return false, "Path not found: " .. abs_path end
|
||||
if not on_complete then return false, "on_complete not provided" end
|
||||
@@ -188,12 +194,14 @@ function M.delete_path(opts, on_log, on_complete, session_ctx)
|
||||
if on_log then on_log("Deleting path: " .. abs_path) end
|
||||
os.remove(abs_path)
|
||||
on_complete(true, nil)
|
||||
end, nil, session_ctx, "delete_path")
|
||||
end, nil, opts.session_ctx, "delete_path")
|
||||
end
|
||||
|
||||
---@type AvanteLLMToolFunc<{ path: string }>
|
||||
function M.create_dir(opts, on_log, on_complete, session_ctx)
|
||||
local abs_path = Helpers.get_abs_path(opts.path)
|
||||
function M.create_dir(input, opts)
|
||||
local on_log = opts.on_log
|
||||
local on_complete = opts.on_complete
|
||||
local abs_path = Helpers.get_abs_path(input.path)
|
||||
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||
if Path:new(abs_path):exists() then return false, "Directory already exists: " .. abs_path end
|
||||
if not on_complete then return false, "on_complete not provided" end
|
||||
@@ -205,16 +213,17 @@ function M.create_dir(opts, on_log, on_complete, session_ctx)
|
||||
if on_log then on_log("Creating directory: " .. abs_path) end
|
||||
Path:new(abs_path):mkdir({ parents = true })
|
||||
on_complete(true, nil)
|
||||
end, nil, session_ctx, "create_dir")
|
||||
end, nil, opts.session_ctx, "create_dir")
|
||||
end
|
||||
|
||||
---@type AvanteLLMToolFunc<{ query: string }>
|
||||
function M.web_search(opts, on_log, on_complete, session_ctx)
|
||||
function M.web_search(input, opts)
|
||||
local on_log = opts.on_log
|
||||
local provider_type = Config.web_search_engine.provider
|
||||
local proxy = Config.web_search_engine.proxy
|
||||
if provider_type == nil then return nil, "Search engine provider is not set" end
|
||||
if on_log then on_log("provider: " .. provider_type) end
|
||||
if on_log then on_log("query: " .. opts.query) end
|
||||
if on_log then on_log("query: " .. input.query) end
|
||||
local search_engine = Config.web_search_engine.providers[provider_type]
|
||||
if search_engine == nil then return nil, "No search engine found: " .. provider_type end
|
||||
if provider_type ~= "searxng" and search_engine.api_key_name == "" then return nil, "No API key provided" end
|
||||
@@ -229,7 +238,7 @@ function M.web_search(opts, on_log, on_complete, session_ctx)
|
||||
["Authorization"] = "Bearer " .. api_key,
|
||||
},
|
||||
body = vim.json.encode(vim.tbl_deep_extend("force", {
|
||||
query = opts.query,
|
||||
query = input.query,
|
||||
}, search_engine.extra_request_body)),
|
||||
}
|
||||
if proxy then curl_opts.proxy = proxy end
|
||||
@@ -240,7 +249,7 @@ function M.web_search(opts, on_log, on_complete, session_ctx)
|
||||
elseif provider_type == "serpapi" then
|
||||
local query_params = vim.tbl_deep_extend("force", {
|
||||
api_key = api_key,
|
||||
q = opts.query,
|
||||
q = input.query,
|
||||
}, search_engine.extra_request_body)
|
||||
local query_string = ""
|
||||
for key, value in pairs(query_params) do
|
||||
@@ -259,7 +268,7 @@ function M.web_search(opts, on_log, on_complete, session_ctx)
|
||||
elseif provider_type == "searchapi" then
|
||||
local query_params = vim.tbl_deep_extend("force", {
|
||||
api_key = api_key,
|
||||
q = opts.query,
|
||||
q = input.query,
|
||||
}, search_engine.extra_request_body)
|
||||
local query_string = ""
|
||||
for key, value in pairs(query_params) do
|
||||
@@ -283,7 +292,7 @@ function M.web_search(opts, on_log, on_complete, session_ctx)
|
||||
local query_params = vim.tbl_deep_extend("force", {
|
||||
key = api_key,
|
||||
cx = engine_id,
|
||||
q = opts.query,
|
||||
q = input.query,
|
||||
}, search_engine.extra_request_body)
|
||||
local query_string = ""
|
||||
for key, value in pairs(query_params) do
|
||||
@@ -301,7 +310,7 @@ function M.web_search(opts, on_log, on_complete, session_ctx)
|
||||
return search_engine.format_response_body(jsn)
|
||||
elseif provider_type == "kagi" then
|
||||
local query_params = vim.tbl_deep_extend("force", {
|
||||
q = opts.query,
|
||||
q = input.query,
|
||||
}, search_engine.extra_request_body)
|
||||
local query_string = ""
|
||||
for key, value in pairs(query_params) do
|
||||
@@ -320,7 +329,7 @@ function M.web_search(opts, on_log, on_complete, session_ctx)
|
||||
return search_engine.format_response_body(jsn)
|
||||
elseif provider_type == "brave" then
|
||||
local query_params = vim.tbl_deep_extend("force", {
|
||||
q = opts.query,
|
||||
q = input.query,
|
||||
}, search_engine.extra_request_body)
|
||||
local query_string = ""
|
||||
for key, value in pairs(query_params) do
|
||||
@@ -343,7 +352,7 @@ function M.web_search(opts, on_log, on_complete, session_ctx)
|
||||
return nil, "Environment variable " .. search_engine.api_url_name .. " is not set"
|
||||
end
|
||||
local query_params = vim.tbl_deep_extend("force", {
|
||||
q = opts.query,
|
||||
q = input.query,
|
||||
}, search_engine.extra_request_body)
|
||||
local query_string = ""
|
||||
for key, value in pairs(query_params) do
|
||||
@@ -362,16 +371,18 @@ function M.web_search(opts, on_log, on_complete, session_ctx)
|
||||
end
|
||||
|
||||
---@type AvanteLLMToolFunc<{ url: string }>
|
||||
function M.fetch(opts, on_log, on_complete, session_ctx)
|
||||
if on_log then on_log("url: " .. opts.url) end
|
||||
function M.fetch(input, opts)
|
||||
local on_log = opts.on_log
|
||||
if on_log then on_log("url: " .. input.url) end
|
||||
local Html2Md = require("avante.html2md")
|
||||
local res, err = Html2Md.fetch_md(opts.url)
|
||||
local res, err = Html2Md.fetch_md(input.url)
|
||||
if err then return nil, err end
|
||||
return res, nil
|
||||
end
|
||||
|
||||
---@type AvanteLLMToolFunc<{ scope?: string }>
|
||||
function M.git_diff(opts, on_log, on_complete, session_ctx)
|
||||
function M.git_diff(input, opts)
|
||||
local on_log = opts.on_log
|
||||
local git_cmd = vim.fn.exepath("git")
|
||||
if git_cmd == "" then return nil, "Git command not found" end
|
||||
local project_root = Utils.get_project_root()
|
||||
@@ -382,7 +393,7 @@ function M.git_diff(opts, on_log, on_complete, session_ctx)
|
||||
if git_dir == "" then return nil, "Not a git repository" end
|
||||
|
||||
-- Get the diff
|
||||
local scope = opts.scope or ""
|
||||
local scope = input.scope or ""
|
||||
local cmd = string.format("git diff --cached %s", scope)
|
||||
if on_log then on_log("Running command: " .. cmd) end
|
||||
local diff = vim.fn.system(cmd)
|
||||
@@ -400,7 +411,10 @@ function M.git_diff(opts, on_log, on_complete, session_ctx)
|
||||
end
|
||||
|
||||
---@type AvanteLLMToolFunc<{ message: string, scope?: string }>
|
||||
function M.git_commit(opts, on_log, on_complete, session_ctx)
|
||||
function M.git_commit(input, opts)
|
||||
local on_log = opts.on_log
|
||||
local on_complete = opts.on_complete
|
||||
|
||||
local git_cmd = vim.fn.exepath("git")
|
||||
if git_cmd == "" then return false, "Git command not found" end
|
||||
local project_root = Utils.get_project_root()
|
||||
@@ -444,7 +458,7 @@ function M.git_commit(opts, on_log, on_complete, session_ctx)
|
||||
|
||||
-- Prepare commit message
|
||||
local commit_msg_lines = {}
|
||||
for line in opts.message:gmatch("[^\r\n]+") do
|
||||
for line in input.message:gmatch("[^\r\n]+") do
|
||||
commit_msg_lines[#commit_msg_lines + 1] = line:gsub('"', '\\"')
|
||||
end
|
||||
commit_msg_lines[#commit_msg_lines + 1] = ""
|
||||
@@ -464,8 +478,8 @@ function M.git_commit(opts, on_log, on_complete, session_ctx)
|
||||
return
|
||||
end
|
||||
-- Stage changes if scope is provided
|
||||
if opts.scope then
|
||||
local stage_cmd = string.format("git add %s", opts.scope)
|
||||
if input.scope then
|
||||
local stage_cmd = string.format("git add %s", input.scope)
|
||||
if on_log then on_log("Staging files: " .. stage_cmd) end
|
||||
local stage_result = vim.fn.system(stage_cmd)
|
||||
if vim.v.shell_error ~= 0 then
|
||||
@@ -494,20 +508,23 @@ function M.git_commit(opts, on_log, on_complete, session_ctx)
|
||||
end
|
||||
|
||||
on_complete(true, nil)
|
||||
end, nil, session_ctx, "git_commit")
|
||||
end, nil, opts.session_ctx, "git_commit")
|
||||
end
|
||||
|
||||
---@type AvanteLLMToolFunc<{ query: string }>
|
||||
function M.rag_search(opts, on_log, on_complete, session_ctx)
|
||||
function M.rag_search(input, opts)
|
||||
local on_log = opts.on_log
|
||||
local on_complete = opts.on_complete
|
||||
if not on_complete then return nil, "on_complete not provided" end
|
||||
if not Config.rag_service.enabled then return nil, "Rag service is not enabled" end
|
||||
if not opts.query then return nil, "No query provided" end
|
||||
if on_log then on_log("query: " .. opts.query) end
|
||||
if not input.query then return nil, "No query provided" end
|
||||
if on_log then on_log("query: " .. input.query) end
|
||||
local root = Utils.get_project_root()
|
||||
local uri = "file://" .. root
|
||||
if uri:sub(-1) ~= "/" then uri = uri .. "/" end
|
||||
RagService.retrieve(
|
||||
uri,
|
||||
opts.query,
|
||||
input.query,
|
||||
vim.schedule_wrap(function(resp, err)
|
||||
if err then
|
||||
on_complete(nil, err)
|
||||
@@ -519,13 +536,15 @@ function M.rag_search(opts, on_log, on_complete, session_ctx)
|
||||
end
|
||||
|
||||
---@type AvanteLLMToolFunc<{ code: string, path: string, container_image?: string }>
|
||||
function M.python(opts, on_log, on_complete, session_ctx)
|
||||
local abs_path = Helpers.get_abs_path(opts.path)
|
||||
function M.python(input, opts)
|
||||
local on_log = opts.on_log
|
||||
local on_complete = opts.on_complete
|
||||
local abs_path = Helpers.get_abs_path(input.path)
|
||||
if not Helpers.has_permission_to_access(abs_path) then return nil, "No permission to access path: " .. abs_path end
|
||||
if not Path:new(abs_path):exists() then return nil, "Path not found: " .. abs_path end
|
||||
if on_log then on_log("cwd: " .. abs_path) end
|
||||
if on_log then on_log("code:\n" .. opts.code) end
|
||||
local container_image = opts.container_image or "python:3.11-slim-bookworm"
|
||||
if on_log then on_log("code:\n" .. input.code) end
|
||||
local container_image = input.container_image or "python:3.11-slim-bookworm"
|
||||
if not on_complete then return nil, "on_complete not provided" end
|
||||
Helpers.confirm(
|
||||
"Are you sure you want to run the following python code in the `"
|
||||
@@ -533,7 +552,7 @@ function M.python(opts, on_log, on_complete, session_ctx)
|
||||
.. "` container, in the directory: `"
|
||||
.. abs_path
|
||||
.. "`?\n"
|
||||
.. opts.code,
|
||||
.. input.code,
|
||||
function(ok, reason)
|
||||
if not ok then
|
||||
on_complete(nil, "User declined, reason: " .. (reason or "unknown"))
|
||||
@@ -562,7 +581,7 @@ function M.python(opts, on_log, on_complete, session_ctx)
|
||||
container_image,
|
||||
"python",
|
||||
"-c",
|
||||
opts.code,
|
||||
input.code,
|
||||
},
|
||||
{
|
||||
text = true,
|
||||
@@ -576,7 +595,7 @@ function M.python(opts, on_log, on_complete, session_ctx)
|
||||
)
|
||||
end,
|
||||
nil,
|
||||
session_ctx,
|
||||
opts.session_ctx,
|
||||
"python"
|
||||
)
|
||||
end
|
||||
@@ -1189,14 +1208,19 @@ You can delete the first file by providing a path of "directory1/a/something.txt
|
||||
--- compatibility alias for old calls & tests
|
||||
M.run_python = M.python
|
||||
|
||||
---@class avante.ProcessToolUseOpts
|
||||
---@field session_ctx table
|
||||
---@field on_log? fun(tool_id: string, tool_name: string, log: string, state: AvanteLLMToolUseState): nil
|
||||
---@field set_tool_use_store? fun(tool_id: string, key: string, value: any): nil
|
||||
---@field on_complete? fun(result: string | nil, error: string | nil): nil
|
||||
|
||||
---@param tools AvanteLLMTool[]
|
||||
---@param tool_use AvanteLLMToolUse
|
||||
---@param on_log? fun(tool_id: string, tool_name: string, log: string, state: AvanteLLMToolUseState): nil
|
||||
---@param on_complete? fun(result: string | nil, error: string | nil): nil
|
||||
---@param session_ctx? table
|
||||
---@return string | nil result
|
||||
---@return string | nil error
|
||||
function M.process_tool_use(tools, tool_use, on_log, on_complete, session_ctx)
|
||||
function M.process_tool_use(tools, tool_use, opts)
|
||||
local on_log = opts.on_log
|
||||
local on_complete = opts.on_complete
|
||||
-- Check if execution is already cancelled
|
||||
if Helpers.is_cancelled then
|
||||
Utils.debug("Tool execution cancelled before starting: " .. tool_use.name)
|
||||
@@ -1274,25 +1298,32 @@ function M.process_tool_use(tools, tool_use, on_log, on_complete, session_ctx)
|
||||
return result_str, err
|
||||
end
|
||||
|
||||
local result, err = func(input_json, function(log)
|
||||
-- Check for cancellation during logging
|
||||
if Helpers.is_cancelled then return end
|
||||
if on_log then on_log(tool_use.id, tool_use.name, log, "running") end
|
||||
end, function(result, err)
|
||||
-- Check for cancellation before completing
|
||||
if Helpers.is_cancelled then
|
||||
Helpers.is_cancelled = false
|
||||
if on_complete then on_complete(nil, Helpers.CANCEL_TOKEN) end
|
||||
return
|
||||
end
|
||||
local result, err = func(input_json, {
|
||||
session_ctx = opts.session_ctx or {},
|
||||
on_log = function(log)
|
||||
-- Check for cancellation during logging
|
||||
if Helpers.is_cancelled then return end
|
||||
if on_log then on_log(tool_use.id, tool_use.name, log, "running") end
|
||||
end,
|
||||
set_store = function(key, value)
|
||||
if opts.set_tool_use_store then opts.set_tool_use_store(tool_use.id, key, value) end
|
||||
end,
|
||||
on_complete = function(result, err)
|
||||
-- Check for cancellation before completing
|
||||
if Helpers.is_cancelled then
|
||||
Helpers.is_cancelled = false
|
||||
if on_complete then on_complete(nil, Helpers.CANCEL_TOKEN) end
|
||||
return
|
||||
end
|
||||
|
||||
result, err = handle_result(result, err)
|
||||
if on_complete == nil then
|
||||
Utils.error("asynchronous tool " .. tool_use.name .. " result not handled")
|
||||
return
|
||||
end
|
||||
on_complete(result, err)
|
||||
end, session_ctx)
|
||||
result, err = handle_result(result, err)
|
||||
if on_complete == nil then
|
||||
Utils.error("asynchronous tool " .. tool_use.name .. " result not handled")
|
||||
return
|
||||
end
|
||||
on_complete(result, err)
|
||||
end,
|
||||
})
|
||||
|
||||
-- Result and error being nil means that the tool was executed asynchronously
|
||||
if result == nil and err == nil and on_complete then return end
|
||||
|
||||
Reference in New Issue
Block a user