diff --git a/lua/avante/providers/bedrock.lua b/lua/avante/providers/bedrock.lua index 355882f..31e4f4c 100644 --- a/lua/avante/providers/bedrock.lua +++ b/lua/avante/providers/bedrock.lua @@ -34,6 +34,11 @@ function M:parse_response(ctx, data_stream, event_state, opts) return model_handler.parse_response(self, ctx, data_stream, event_state, opts) end +function M:transform_tool(tool) + local model_handler = M.load_model_handler() + return model_handler.transform_tool(self, tool) +end + function M:build_bedrock_payload(prompt_opts, request_body) local model_handler = M.load_model_handler() return model_handler.build_bedrock_payload(self, prompt_opts, request_body) diff --git a/lua/avante/providers/bedrock/claude.lua b/lua/avante/providers/bedrock/claude.lua index e5a16a5..6130d22 100644 --- a/lua/avante/providers/bedrock/claude.lua +++ b/lua/avante/providers/bedrock/claude.lua @@ -6,6 +6,7 @@ ---@field role "user" | "assistant" ---@field content [AvanteBedrockClaudeTextMessage][] +local P = require("avante.providers") local Claude = require("avante.providers.claude") ---@class AvanteBedrockModelHandler @@ -20,6 +21,7 @@ M.role_map = { M.is_disable_stream = Claude.is_disable_stream M.parse_messages = Claude.parse_messages M.parse_response = Claude.parse_response +M.transform_tool = Claude.transform_tool ---@param provider AvanteProviderFunctor ---@param prompt_opts AvantePromptOptions @@ -29,10 +31,21 @@ function M.build_bedrock_payload(provider, prompt_opts, request_body) local system_prompt = prompt_opts.system_prompt or "" local messages = provider:parse_messages(prompt_opts) local max_tokens = request_body.max_tokens or 2000 + + local provider_conf, _ = P.parse_config(provider) + local disable_tools = provider_conf.disable_tools or false + local tools = {} + if not disable_tools and prompt_opts.tools then + for _, tool in ipairs(prompt_opts.tools) do + table.insert(tools, provider:transform_tool(tool)) + end + end + local payload = { anthropic_version = "bedrock-2023-05-31", max_tokens = max_tokens, messages = messages, + tools = tools, system = system_prompt, } return vim.tbl_deep_extend("force", payload, request_body or {}) diff --git a/lua/avante/providers/claude.lua b/lua/avante/providers/claude.lua index 69ed37b..f760943 100644 --- a/lua/avante/providers/claude.lua +++ b/lua/avante/providers/claude.lua @@ -32,7 +32,7 @@ end ---@param tool AvanteLLMTool ---@return AvanteClaudeTool -function M.transform_tool(tool) +function M:transform_tool(tool) local input_schema_properties = {} local required = {} for _, field in ipairs(tool.param.fields) do @@ -339,7 +339,7 @@ function M:parse_curl_args(prompt_opts) local tools = {} if not disable_tools and prompt_opts.tools then for _, tool in ipairs(prompt_opts.tools) do - table.insert(tools, self.transform_tool(tool)) + table.insert(tools, self:transform_tool(tool)) end end diff --git a/lua/avante/providers/copilot.lua b/lua/avante/providers/copilot.lua index 7a48472..ece8394 100644 --- a/lua/avante/providers/copilot.lua +++ b/lua/avante/providers/copilot.lua @@ -228,7 +228,7 @@ function M:parse_curl_args(prompt_opts) local tools = {} if not disable_tools and prompt_opts.tools then for _, tool in ipairs(prompt_opts.tools) do - table.insert(tools, OpenAI.transform_tool(tool)) + table.insert(tools, OpenAI:transform_tool(tool)) end end diff --git a/lua/avante/providers/openai.lua b/lua/avante/providers/openai.lua index 1015306..39624a0 100644 --- a/lua/avante/providers/openai.lua +++ b/lua/avante/providers/openai.lua @@ -17,7 +17,7 @@ function M:is_disable_stream() return false end ---@param tool AvanteLLMTool ---@return AvanteOpenAITool -function M.transform_tool(tool) +function M:transform_tool(tool) local input_schema_properties = {} local required = {} for _, field in ipairs(tool.param.fields) do @@ -303,7 +303,7 @@ function M:parse_curl_args(prompt_opts) if not disable_tools and prompt_opts.tools then tools = {} for _, tool in ipairs(prompt_opts.tools) do - table.insert(tools, self.transform_tool(tool)) + table.insert(tools, self:transform_tool(tool)) end end diff --git a/lua/avante/providers/vertex_claude.lua b/lua/avante/providers/vertex_claude.lua index d6115b2..cb16bce 100644 --- a/lua/avante/providers/vertex_claude.lua +++ b/lua/avante/providers/vertex_claude.lua @@ -36,7 +36,7 @@ function M:parse_curl_args(prompt_opts) local tools = {} if not disable_tools and prompt_opts.tools then for _, tool in ipairs(prompt_opts.tools) do - table.insert(tools, P.claude.transform_tool(tool)) + table.insert(tools, P.claude:transform_tool(tool)) end end diff --git a/lua/avante/types.lua b/lua/avante/types.lua index f1c862b..5fc7054 100644 --- a/lua/avante/types.lua +++ b/lua/avante/types.lua @@ -277,7 +277,7 @@ vim.g.avante_login = vim.g.avante_login ---@field parse_api_key fun(): string | nil ---@field parse_stream_data? AvanteStreamParser ---@field on_error? fun(result: table): nil ----@field transform_tool? fun(tool: AvanteLLMTool): AvanteOpenAITool | AvanteClaudeTool +---@field transform_tool? fun(self: AvanteProviderFunctor, tool: AvanteLLMTool): AvanteOpenAITool | AvanteClaudeTool ---@field get_rate_limit_sleep_time? fun(self: AvanteProviderFunctor, headers: table): integer | nil --- ---@alias AvanteBedrockPayloadBuilder fun(self: AvanteBedrockModelHandler | AvanteBedrockProviderFunctor, prompt_opts: AvantePromptOptions, request_body: table): table