feat: tools (#1180)
* feat: tools * feat: claude use tools * feat: openai use tools
This commit is contained in:
@@ -8,6 +8,7 @@ local Utils = require("avante.utils")
|
||||
local Config = require("avante.config")
|
||||
local Path = require("avante.path")
|
||||
local P = require("avante.providers")
|
||||
local LLMTools = require("avante.llm_tools")
|
||||
|
||||
---@class avante.LLM
|
||||
local M = {}
|
||||
@@ -45,6 +46,8 @@ M.generate_prompts = function(opts)
|
||||
local project_root = Utils.root.get()
|
||||
Path.prompts.initialize(Path.prompts.get(project_root))
|
||||
|
||||
local system_info = Utils.get_system_info()
|
||||
|
||||
local template_opts = {
|
||||
use_xml_format = Provider.use_xml_format,
|
||||
ask = opts.ask, -- TODO: add mode without ask instruction
|
||||
@@ -53,6 +56,7 @@ M.generate_prompts = function(opts)
|
||||
selected_code = opts.selected_code,
|
||||
project_context = opts.project_context,
|
||||
diagnostics = opts.diagnostics,
|
||||
system_info = system_info,
|
||||
}
|
||||
|
||||
local system_prompt = Path.prompts.render_mode(mode, template_opts)
|
||||
@@ -111,6 +115,10 @@ M.generate_prompts = function(opts)
|
||||
system_prompt = system_prompt,
|
||||
messages = messages,
|
||||
image_paths = image_paths,
|
||||
tools = opts.tools,
|
||||
tool_use = opts.tool_use,
|
||||
tool_result = opts.tool_result,
|
||||
response_content = opts.response_content,
|
||||
}
|
||||
end
|
||||
|
||||
@@ -135,7 +143,28 @@ M._stream = function(opts)
|
||||
local current_event_state = nil
|
||||
|
||||
---@type AvanteHandlerOptions
|
||||
local handler_opts = { on_chunk = opts.on_chunk, on_complete = opts.on_complete }
|
||||
local handler_opts = {
|
||||
on_start = opts.on_start,
|
||||
on_chunk = opts.on_chunk,
|
||||
on_stop = function(stop_opts)
|
||||
if stop_opts.reason == "tool_use" and stop_opts.tool_use then
|
||||
local result, error = LLMTools.process_tool_use(stop_opts.tool_use)
|
||||
local tool_result = {
|
||||
tool_use_id = stop_opts.tool_use.id,
|
||||
content = error ~= nil and error or result,
|
||||
is_error = error ~= nil,
|
||||
}
|
||||
local new_opts = vim.tbl_deep_extend(
|
||||
"force",
|
||||
opts,
|
||||
{ tool_result = tool_result, tool_use = stop_opts.tool_use, response_content = stop_opts.response_content }
|
||||
)
|
||||
return M._stream(new_opts)
|
||||
end
|
||||
return opts.on_stop(stop_opts)
|
||||
end,
|
||||
}
|
||||
|
||||
---@type AvanteCurlOutput
|
||||
local spec = Provider.parse_curl_args(Provider, code_opts)
|
||||
|
||||
@@ -180,7 +209,7 @@ M._stream = function(opts)
|
||||
stream = function(err, data, _)
|
||||
if err then
|
||||
completed = true
|
||||
opts.on_complete(err)
|
||||
handler_opts.on_stop({ reason = "error", error = err })
|
||||
return
|
||||
end
|
||||
if not data then return end
|
||||
@@ -224,7 +253,7 @@ M._stream = function(opts)
|
||||
active_job = nil
|
||||
completed = true
|
||||
cleanup()
|
||||
opts.on_complete(err)
|
||||
handler_opts.on_stop({ reason = "error", error = err })
|
||||
end,
|
||||
callback = function(result)
|
||||
active_job = nil
|
||||
@@ -238,9 +267,10 @@ M._stream = function(opts)
|
||||
vim.schedule(function()
|
||||
if not completed then
|
||||
completed = true
|
||||
opts.on_complete(
|
||||
"API request failed with status " .. result.status .. ". Body: " .. vim.inspect(result.body)
|
||||
)
|
||||
handler_opts.on_stop({
|
||||
reason = "error",
|
||||
error = "API request failed with status " .. result.status .. ". Body: " .. vim.inspect(result.body),
|
||||
})
|
||||
end
|
||||
end)
|
||||
end
|
||||
@@ -335,9 +365,9 @@ M._dual_boost_stream = function(opts, Provider1, Provider2)
|
||||
on_chunk = function(chunk)
|
||||
if chunk then response = response .. chunk end
|
||||
end,
|
||||
on_complete = function(err)
|
||||
if err then
|
||||
Utils.error(string.format("Stream %d failed: %s", index, err))
|
||||
on_stop = function(stop_opts)
|
||||
if stop_opts.error then
|
||||
Utils.error(string.format("Stream %d failed: %s", index, stop_opts.error))
|
||||
return
|
||||
end
|
||||
Utils.debug(string.format("Response %d completed", index))
|
||||
@@ -381,10 +411,15 @@ end
|
||||
---@field instructions string
|
||||
---@field mode LlmMode
|
||||
---@field provider AvanteProviderFunctor | AvanteBedrockProviderFunctor | nil
|
||||
---@field tools? AvanteLLMTool[]
|
||||
---@field tool_result? AvanteLLMToolResult
|
||||
---@field tool_use? AvanteLLMToolUse
|
||||
---@field response_content? string
|
||||
---
|
||||
---@class StreamOptions: GeneratePromptsOptions
|
||||
---@field on_chunk AvanteChunkParser
|
||||
---@field on_complete AvanteCompleteParser
|
||||
---@field on_start AvanteLLMStartCallback
|
||||
---@field on_chunk AvanteLLMChunkCallback
|
||||
---@field on_stop AvanteLLMStopCallback
|
||||
|
||||
---@param opts StreamOptions
|
||||
M.stream = function(opts)
|
||||
@@ -396,12 +431,12 @@ M.stream = function(opts)
|
||||
return original_on_chunk(chunk)
|
||||
end)
|
||||
end
|
||||
if opts.on_complete ~= nil then
|
||||
local original_on_complete = opts.on_complete
|
||||
opts.on_complete = vim.schedule_wrap(function(err)
|
||||
if opts.on_stop ~= nil then
|
||||
local original_on_stop = opts.on_stop
|
||||
opts.on_stop = vim.schedule_wrap(function(stop_opts)
|
||||
if is_completed then return end
|
||||
is_completed = true
|
||||
return original_on_complete(err)
|
||||
if stop_opts.reason == "complete" or stop_opts.reason == "error" then is_completed = true end
|
||||
return original_on_stop(stop_opts)
|
||||
end)
|
||||
end
|
||||
if Config.dual_boost.enabled then
|
||||
|
||||
Reference in New Issue
Block a user