feat: add vertex claude provider (#1549)

* feat: Add vertex claude provider

* remove debug logging

---------

Co-authored-by: Miguelo Sana <miguelo@incubeta.com>
This commit is contained in:
miguelosana
2025-03-10 15:43:10 +01:00
committed by GitHub
parent 7d7f93d093
commit 3eaaaa8f5f
3 changed files with 68 additions and 0 deletions

View File

@@ -264,6 +264,14 @@ M._defaults = {
num_ctx = 4096,
},
},
---@type AvanteSupportedProvider
vertex_claude = {
endpoint = "https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION/publishers/antrhopic/models",
model = "claude-3-5-sonnet-v2@20241022",
timeout = 30000, -- Timeout in milliseconds
temperature = 0,
max_tokens = 4096,
},
---To add support for custom provider, follow the format below
---See https://github.com/yetone/avante.nvim/wiki#custom-providers for more details
---@type {[string]: AvanteProvider}

View File

@@ -19,6 +19,7 @@ local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil }
---@field cohere AvanteProviderFunctor
---@field bedrock AvanteBedrockProviderFunctor
---@field ollama AvanteProviderFunctor
---@field vertex_claude AvanteProviderFunctor
local M = {}
---@class EnvironmentHandler

View File

@@ -0,0 +1,59 @@
local P = require("avante.providers")
local Vertex = require("avante.providers.vertex")
---@class AvanteProviderFunctor
local M = {}
M.role_map = {
user = "user",
assistant = "assistant",
}
M.is_disable_stream = P.claude.is_disable_stream
M.parse_messages = P.claude.parse_messages
M.parse_response = P.claude.parse_response
M.parse_api_key = Vertex.parse_api_key
M.on_error = Vertex.on_error
Vertex.api_key_name = "cmd:gcloud auth print-access-token"
---@param prompt_opts AvantePromptOptions
function M:parse_curl_args(prompt_opts)
local provider_conf, request_body = P.parse_config(self)
local location = vim.fn.getenv("LOCATION")
local project_id = vim.fn.getenv("PROJECT_ID")
local model_id = provider_conf.model or "default-model-id"
if location == nil or location == vim.NIL then location = "default-location" end
if project_id == nil or project_id == vim.NIL then project_id = "default-project-id" end
local url = provider_conf.endpoint:gsub("LOCATION", location):gsub("PROJECT_ID", project_id)
url = string.format("%s/%s:streamRawPredict", url, model_id)
local system_prompt = prompt_opts.system_prompt or ""
local messages = self:parse_messages(prompt_opts)
request_body = vim.tbl_deep_extend("force", request_body, {
anthropic_version = "vertex-2023-10-16",
temperature = 0,
max_tokens = 4096,
stream = true,
messages = messages,
system = {
{
type = "text",
text = system_prompt,
cache_control = { type = "ephemeral" },
},
},
})
return {
url = url,
headers = {
["Authorization"] = "Bearer " .. Vertex.parse_api_key(),
["Content-Type"] = "application/json; charset=utf-8",
},
body = vim.tbl_deep_extend("force", {}, request_body),
}
end
return M