From 3eaaaa8f5fcb9ee617ac425a87a156eeb3b6642c Mon Sep 17 00:00:00 2001 From: miguelosana Date: Mon, 10 Mar 2025 15:43:10 +0100 Subject: [PATCH] feat: add vertex claude provider (#1549) * feat: Add vertex claude provider * remove debug logging --------- Co-authored-by: Miguelo Sana --- lua/avante/config.lua | 8 ++++ lua/avante/providers/init.lua | 1 + lua/avante/providers/vertex_claude.lua | 59 ++++++++++++++++++++++++++ 3 files changed, 68 insertions(+) create mode 100644 lua/avante/providers/vertex_claude.lua diff --git a/lua/avante/config.lua b/lua/avante/config.lua index 9cb20a2..7a92f5c 100644 --- a/lua/avante/config.lua +++ b/lua/avante/config.lua @@ -264,6 +264,14 @@ M._defaults = { num_ctx = 4096, }, }, + ---@type AvanteSupportedProvider + vertex_claude = { + endpoint = "https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION/publishers/antrhopic/models", + model = "claude-3-5-sonnet-v2@20241022", + timeout = 30000, -- Timeout in milliseconds + temperature = 0, + max_tokens = 4096, + }, ---To add support for custom provider, follow the format below ---See https://github.com/yetone/avante.nvim/wiki#custom-providers for more details ---@type {[string]: AvanteProvider} diff --git a/lua/avante/providers/init.lua b/lua/avante/providers/init.lua index bfc8d0e..dd6fe98 100644 --- a/lua/avante/providers/init.lua +++ b/lua/avante/providers/init.lua @@ -19,6 +19,7 @@ local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil } ---@field cohere AvanteProviderFunctor ---@field bedrock AvanteBedrockProviderFunctor ---@field ollama AvanteProviderFunctor +---@field vertex_claude AvanteProviderFunctor local M = {} ---@class EnvironmentHandler diff --git a/lua/avante/providers/vertex_claude.lua b/lua/avante/providers/vertex_claude.lua new file mode 100644 index 0000000..ee9bc81 --- /dev/null +++ b/lua/avante/providers/vertex_claude.lua @@ -0,0 +1,59 @@ +local P = require("avante.providers") +local Vertex = require("avante.providers.vertex") + +---@class AvanteProviderFunctor +local M = {} + +M.role_map = { + user = "user", + assistant = "assistant", +} + +M.is_disable_stream = P.claude.is_disable_stream +M.parse_messages = P.claude.parse_messages +M.parse_response = P.claude.parse_response +M.parse_api_key = Vertex.parse_api_key +M.on_error = Vertex.on_error + +Vertex.api_key_name = "cmd:gcloud auth print-access-token" + +---@param prompt_opts AvantePromptOptions +function M:parse_curl_args(prompt_opts) + local provider_conf, request_body = P.parse_config(self) + local location = vim.fn.getenv("LOCATION") + local project_id = vim.fn.getenv("PROJECT_ID") + local model_id = provider_conf.model or "default-model-id" + if location == nil or location == vim.NIL then location = "default-location" end + if project_id == nil or project_id == vim.NIL then project_id = "default-project-id" end + local url = provider_conf.endpoint:gsub("LOCATION", location):gsub("PROJECT_ID", project_id) + + url = string.format("%s/%s:streamRawPredict", url, model_id) + + local system_prompt = prompt_opts.system_prompt or "" + local messages = self:parse_messages(prompt_opts) + request_body = vim.tbl_deep_extend("force", request_body, { + anthropic_version = "vertex-2023-10-16", + temperature = 0, + max_tokens = 4096, + stream = true, + messages = messages, + system = { + { + type = "text", + text = system_prompt, + cache_control = { type = "ephemeral" }, + }, + }, + }) + + return { + url = url, + headers = { + ["Authorization"] = "Bearer " .. Vertex.parse_api_key(), + ["Content-Type"] = "application/json; charset=utf-8", + }, + body = vim.tbl_deep_extend("force", {}, request_body), + } +end + +return M