feat: add vertex claude provider (#1549)
* feat: Add vertex claude provider * remove debug logging --------- Co-authored-by: Miguelo Sana <miguelo@incubeta.com>
This commit is contained in:
@@ -264,6 +264,14 @@ M._defaults = {
|
||||
num_ctx = 4096,
|
||||
},
|
||||
},
|
||||
---@type AvanteSupportedProvider
|
||||
vertex_claude = {
|
||||
endpoint = "https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION/publishers/antrhopic/models",
|
||||
model = "claude-3-5-sonnet-v2@20241022",
|
||||
timeout = 30000, -- Timeout in milliseconds
|
||||
temperature = 0,
|
||||
max_tokens = 4096,
|
||||
},
|
||||
---To add support for custom provider, follow the format below
|
||||
---See https://github.com/yetone/avante.nvim/wiki#custom-providers for more details
|
||||
---@type {[string]: AvanteProvider}
|
||||
|
||||
@@ -19,6 +19,7 @@ local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil }
|
||||
---@field cohere AvanteProviderFunctor
|
||||
---@field bedrock AvanteBedrockProviderFunctor
|
||||
---@field ollama AvanteProviderFunctor
|
||||
---@field vertex_claude AvanteProviderFunctor
|
||||
local M = {}
|
||||
|
||||
---@class EnvironmentHandler
|
||||
|
||||
59
lua/avante/providers/vertex_claude.lua
Normal file
59
lua/avante/providers/vertex_claude.lua
Normal file
@@ -0,0 +1,59 @@
|
||||
local P = require("avante.providers")
|
||||
local Vertex = require("avante.providers.vertex")
|
||||
|
||||
---@class AvanteProviderFunctor
|
||||
local M = {}
|
||||
|
||||
M.role_map = {
|
||||
user = "user",
|
||||
assistant = "assistant",
|
||||
}
|
||||
|
||||
M.is_disable_stream = P.claude.is_disable_stream
|
||||
M.parse_messages = P.claude.parse_messages
|
||||
M.parse_response = P.claude.parse_response
|
||||
M.parse_api_key = Vertex.parse_api_key
|
||||
M.on_error = Vertex.on_error
|
||||
|
||||
Vertex.api_key_name = "cmd:gcloud auth print-access-token"
|
||||
|
||||
---@param prompt_opts AvantePromptOptions
|
||||
function M:parse_curl_args(prompt_opts)
|
||||
local provider_conf, request_body = P.parse_config(self)
|
||||
local location = vim.fn.getenv("LOCATION")
|
||||
local project_id = vim.fn.getenv("PROJECT_ID")
|
||||
local model_id = provider_conf.model or "default-model-id"
|
||||
if location == nil or location == vim.NIL then location = "default-location" end
|
||||
if project_id == nil or project_id == vim.NIL then project_id = "default-project-id" end
|
||||
local url = provider_conf.endpoint:gsub("LOCATION", location):gsub("PROJECT_ID", project_id)
|
||||
|
||||
url = string.format("%s/%s:streamRawPredict", url, model_id)
|
||||
|
||||
local system_prompt = prompt_opts.system_prompt or ""
|
||||
local messages = self:parse_messages(prompt_opts)
|
||||
request_body = vim.tbl_deep_extend("force", request_body, {
|
||||
anthropic_version = "vertex-2023-10-16",
|
||||
temperature = 0,
|
||||
max_tokens = 4096,
|
||||
stream = true,
|
||||
messages = messages,
|
||||
system = {
|
||||
{
|
||||
type = "text",
|
||||
text = system_prompt,
|
||||
cache_control = { type = "ephemeral" },
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
return {
|
||||
url = url,
|
||||
headers = {
|
||||
["Authorization"] = "Bearer " .. Vertex.parse_api_key(),
|
||||
["Content-Type"] = "application/json; charset=utf-8",
|
||||
},
|
||||
body = vim.tbl_deep_extend("force", {}, request_body),
|
||||
}
|
||||
end
|
||||
|
||||
return M
|
||||
Reference in New Issue
Block a user