From fee1aacdfcb24a52b0061ad352a31e2239e07100 Mon Sep 17 00:00:00 2001 From: miguelosana Date: Wed, 12 Mar 2025 11:43:44 +0100 Subject: [PATCH] feat: Add claude tools to vertex claude provider (#1559) * feat: Add claude tools to vertex claude provider * fix: export transform_tool from claude.lua * Include type for transfor_tool --------- Co-authored-by: Miguelo Sana --- lua/avante/providers/claude.lua | 26 +++++++++++++------------- lua/avante/providers/vertex_claude.lua | 16 ++++++++++++++++ lua/avante/types.lua | 1 + 3 files changed, 30 insertions(+), 13 deletions(-) diff --git a/lua/avante/providers/claude.lua b/lua/avante/providers/claude.lua index 3d3d908..aacf0b8 100644 --- a/lua/avante/providers/claude.lua +++ b/lua/avante/providers/claude.lua @@ -2,9 +2,20 @@ local Utils = require("avante.utils") local Clipboard = require("avante.clipboard") local P = require("avante.providers") +---@class AvanteProviderFunctor +local M = {} + +M.api_key_name = "ANTHROPIC_API_KEY" +M.support_prompt_caching = true + +M.role_map = { + user = "user", + assistant = "assistant", +} + ---@param tool AvanteLLMTool ---@return AvanteClaudeTool -local function transform_tool(tool) +function M.transform_tool(tool) local input_schema_properties = {} local required = {} for _, field in ipairs(tool.param.fields) do @@ -25,17 +36,6 @@ local function transform_tool(tool) } end ----@class AvanteProviderFunctor -local M = {} - -M.api_key_name = "ANTHROPIC_API_KEY" -M.support_prompt_caching = true - -M.role_map = { - user = "user", - assistant = "assistant", -} - function M:parse_messages(opts) ---@type AvanteClaudeMessage[] local messages = {} @@ -278,7 +278,7 @@ function M:parse_curl_args(prompt_opts) local tools = {} if not disable_tools and prompt_opts.tools then for _, tool in ipairs(prompt_opts.tools) do - table.insert(tools, transform_tool(tool)) + table.insert(tools, self.transform_tool(tool)) end end diff --git a/lua/avante/providers/vertex_claude.lua b/lua/avante/providers/vertex_claude.lua index 44cf736..ab2ca0a 100644 --- a/lua/avante/providers/vertex_claude.lua +++ b/lua/avante/providers/vertex_claude.lua @@ -19,6 +19,7 @@ Vertex.api_key_name = "cmd:gcloud auth print-access-token" ---@param prompt_opts AvantePromptOptions function M:parse_curl_args(prompt_opts) local provider_conf, request_body = P.parse_config(self) + local disable_tools = provider_conf.disable_tools or false local location = vim.fn.getenv("LOCATION") local project_id = vim.fn.getenv("PROJECT_ID") local model_id = provider_conf.model or "default-model-id" @@ -30,6 +31,20 @@ function M:parse_curl_args(prompt_opts) local system_prompt = prompt_opts.system_prompt or "" local messages = self:parse_messages(prompt_opts) + + local tools = {} + if not disable_tools and prompt_opts.tools then + for _, tool in ipairs(prompt_opts.tools) do + table.insert(tools, P.claude.transform_tool(tool)) + end + end + + if self.support_prompt_caching and #tools > 0 then + local last_tool = vim.deepcopy(tools[#tools]) + last_tool.cache_control = { type = "ephemeral" } + tools[#tools] = last_tool + end + request_body = vim.tbl_deep_extend("force", request_body, { anthropic_version = "vertex-2023-10-16", temperature = 0, @@ -43,6 +58,7 @@ function M:parse_curl_args(prompt_opts) cache_control = { type = "ephemeral" }, }, }, + tools = tools, }) return { diff --git a/lua/avante/types.lua b/lua/avante/types.lua index 895fb2c..94288eb 100644 --- a/lua/avante/types.lua +++ b/lua/avante/types.lua @@ -269,6 +269,7 @@ vim.g.avante_login = vim.g.avante_login ---@field parse_api_key fun(): string | nil ---@field parse_stream_data? AvanteStreamParser ---@field on_error? fun(result: table): nil +---@field transform_tool? fun(tool: AvanteLLMTool): AvanteOpenAITool | AvanteClaudeTool --- ---@alias AvanteBedrockPayloadBuilder fun(self: AvanteBedrockModelHandler | AvanteBedrockProviderFunctor, prompt_opts: AvantePromptOptions, request_body: table): table ---