feat: support tools in bedrock (#1598)
This commit is contained in:
@@ -34,6 +34,11 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
|||||||
return model_handler.parse_response(self, ctx, data_stream, event_state, opts)
|
return model_handler.parse_response(self, ctx, data_stream, event_state, opts)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function M:transform_tool(tool)
|
||||||
|
local model_handler = M.load_model_handler()
|
||||||
|
return model_handler.transform_tool(self, tool)
|
||||||
|
end
|
||||||
|
|
||||||
function M:build_bedrock_payload(prompt_opts, request_body)
|
function M:build_bedrock_payload(prompt_opts, request_body)
|
||||||
local model_handler = M.load_model_handler()
|
local model_handler = M.load_model_handler()
|
||||||
return model_handler.build_bedrock_payload(self, prompt_opts, request_body)
|
return model_handler.build_bedrock_payload(self, prompt_opts, request_body)
|
||||||
|
|||||||
@@ -6,6 +6,7 @@
|
|||||||
---@field role "user" | "assistant"
|
---@field role "user" | "assistant"
|
||||||
---@field content [AvanteBedrockClaudeTextMessage][]
|
---@field content [AvanteBedrockClaudeTextMessage][]
|
||||||
|
|
||||||
|
local P = require("avante.providers")
|
||||||
local Claude = require("avante.providers.claude")
|
local Claude = require("avante.providers.claude")
|
||||||
|
|
||||||
---@class AvanteBedrockModelHandler
|
---@class AvanteBedrockModelHandler
|
||||||
@@ -20,6 +21,7 @@ M.role_map = {
|
|||||||
M.is_disable_stream = Claude.is_disable_stream
|
M.is_disable_stream = Claude.is_disable_stream
|
||||||
M.parse_messages = Claude.parse_messages
|
M.parse_messages = Claude.parse_messages
|
||||||
M.parse_response = Claude.parse_response
|
M.parse_response = Claude.parse_response
|
||||||
|
M.transform_tool = Claude.transform_tool
|
||||||
|
|
||||||
---@param provider AvanteProviderFunctor
|
---@param provider AvanteProviderFunctor
|
||||||
---@param prompt_opts AvantePromptOptions
|
---@param prompt_opts AvantePromptOptions
|
||||||
@@ -29,10 +31,21 @@ function M.build_bedrock_payload(provider, prompt_opts, request_body)
|
|||||||
local system_prompt = prompt_opts.system_prompt or ""
|
local system_prompt = prompt_opts.system_prompt or ""
|
||||||
local messages = provider:parse_messages(prompt_opts)
|
local messages = provider:parse_messages(prompt_opts)
|
||||||
local max_tokens = request_body.max_tokens or 2000
|
local max_tokens = request_body.max_tokens or 2000
|
||||||
|
|
||||||
|
local provider_conf, _ = P.parse_config(provider)
|
||||||
|
local disable_tools = provider_conf.disable_tools or false
|
||||||
|
local tools = {}
|
||||||
|
if not disable_tools and prompt_opts.tools then
|
||||||
|
for _, tool in ipairs(prompt_opts.tools) do
|
||||||
|
table.insert(tools, provider:transform_tool(tool))
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
local payload = {
|
local payload = {
|
||||||
anthropic_version = "bedrock-2023-05-31",
|
anthropic_version = "bedrock-2023-05-31",
|
||||||
max_tokens = max_tokens,
|
max_tokens = max_tokens,
|
||||||
messages = messages,
|
messages = messages,
|
||||||
|
tools = tools,
|
||||||
system = system_prompt,
|
system = system_prompt,
|
||||||
}
|
}
|
||||||
return vim.tbl_deep_extend("force", payload, request_body or {})
|
return vim.tbl_deep_extend("force", payload, request_body or {})
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ end
|
|||||||
|
|
||||||
---@param tool AvanteLLMTool
|
---@param tool AvanteLLMTool
|
||||||
---@return AvanteClaudeTool
|
---@return AvanteClaudeTool
|
||||||
function M.transform_tool(tool)
|
function M:transform_tool(tool)
|
||||||
local input_schema_properties = {}
|
local input_schema_properties = {}
|
||||||
local required = {}
|
local required = {}
|
||||||
for _, field in ipairs(tool.param.fields) do
|
for _, field in ipairs(tool.param.fields) do
|
||||||
@@ -339,7 +339,7 @@ function M:parse_curl_args(prompt_opts)
|
|||||||
local tools = {}
|
local tools = {}
|
||||||
if not disable_tools and prompt_opts.tools then
|
if not disable_tools and prompt_opts.tools then
|
||||||
for _, tool in ipairs(prompt_opts.tools) do
|
for _, tool in ipairs(prompt_opts.tools) do
|
||||||
table.insert(tools, self.transform_tool(tool))
|
table.insert(tools, self:transform_tool(tool))
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|||||||
@@ -228,7 +228,7 @@ function M:parse_curl_args(prompt_opts)
|
|||||||
local tools = {}
|
local tools = {}
|
||||||
if not disable_tools and prompt_opts.tools then
|
if not disable_tools and prompt_opts.tools then
|
||||||
for _, tool in ipairs(prompt_opts.tools) do
|
for _, tool in ipairs(prompt_opts.tools) do
|
||||||
table.insert(tools, OpenAI.transform_tool(tool))
|
table.insert(tools, OpenAI:transform_tool(tool))
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ function M:is_disable_stream() return false end
|
|||||||
|
|
||||||
---@param tool AvanteLLMTool
|
---@param tool AvanteLLMTool
|
||||||
---@return AvanteOpenAITool
|
---@return AvanteOpenAITool
|
||||||
function M.transform_tool(tool)
|
function M:transform_tool(tool)
|
||||||
local input_schema_properties = {}
|
local input_schema_properties = {}
|
||||||
local required = {}
|
local required = {}
|
||||||
for _, field in ipairs(tool.param.fields) do
|
for _, field in ipairs(tool.param.fields) do
|
||||||
@@ -303,7 +303,7 @@ function M:parse_curl_args(prompt_opts)
|
|||||||
if not disable_tools and prompt_opts.tools then
|
if not disable_tools and prompt_opts.tools then
|
||||||
tools = {}
|
tools = {}
|
||||||
for _, tool in ipairs(prompt_opts.tools) do
|
for _, tool in ipairs(prompt_opts.tools) do
|
||||||
table.insert(tools, self.transform_tool(tool))
|
table.insert(tools, self:transform_tool(tool))
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ function M:parse_curl_args(prompt_opts)
|
|||||||
local tools = {}
|
local tools = {}
|
||||||
if not disable_tools and prompt_opts.tools then
|
if not disable_tools and prompt_opts.tools then
|
||||||
for _, tool in ipairs(prompt_opts.tools) do
|
for _, tool in ipairs(prompt_opts.tools) do
|
||||||
table.insert(tools, P.claude.transform_tool(tool))
|
table.insert(tools, P.claude:transform_tool(tool))
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|||||||
@@ -277,7 +277,7 @@ vim.g.avante_login = vim.g.avante_login
|
|||||||
---@field parse_api_key fun(): string | nil
|
---@field parse_api_key fun(): string | nil
|
||||||
---@field parse_stream_data? AvanteStreamParser
|
---@field parse_stream_data? AvanteStreamParser
|
||||||
---@field on_error? fun(result: table<string, any>): nil
|
---@field on_error? fun(result: table<string, any>): nil
|
||||||
---@field transform_tool? fun(tool: AvanteLLMTool): AvanteOpenAITool | AvanteClaudeTool
|
---@field transform_tool? fun(self: AvanteProviderFunctor, tool: AvanteLLMTool): AvanteOpenAITool | AvanteClaudeTool
|
||||||
---@field get_rate_limit_sleep_time? fun(self: AvanteProviderFunctor, headers: table<string, string>): integer | nil
|
---@field get_rate_limit_sleep_time? fun(self: AvanteProviderFunctor, headers: table<string, string>): integer | nil
|
||||||
---
|
---
|
||||||
---@alias AvanteBedrockPayloadBuilder fun(self: AvanteBedrockModelHandler | AvanteBedrockProviderFunctor, prompt_opts: AvantePromptOptions, request_body: table<string, any>): table<string, any>
|
---@alias AvanteBedrockPayloadBuilder fun(self: AvanteBedrockModelHandler | AvanteBedrockProviderFunctor, prompt_opts: AvantePromptOptions, request_body: table<string, any>): table<string, any>
|
||||||
|
|||||||
Reference in New Issue
Block a user