feat(providers): fail gracefully when a provider is misconfigured (#2768)
This commit is contained in:
@@ -548,8 +548,11 @@ function M.curl(opts)
|
||||
if orig_on_stop then return orig_on_stop(stop_opts) end
|
||||
end
|
||||
|
||||
---@type AvanteCurlOutput
|
||||
local spec = provider:parse_curl_args(prompt_opts)
|
||||
if not spec then
|
||||
handler_opts.on_stop({ reason = "error", error = "Provider configuration error" })
|
||||
return
|
||||
end
|
||||
|
||||
---@type string
|
||||
local current_event_state = nil
|
||||
|
||||
@@ -20,6 +20,8 @@ M.api_key_name = "AZURE_OPENAI_API_KEY"
|
||||
-- Inherit from OpenAI class
|
||||
setmetatable(M, { __index = O })
|
||||
|
||||
---@param prompt_opts AvantePromptOptions
|
||||
---@return AvanteCurlOutput|nil
|
||||
function M:parse_curl_args(prompt_opts)
|
||||
local provider_conf, request_body = P.parse_config(self)
|
||||
local disable_tools = provider_conf.disable_tools or false
|
||||
@@ -29,10 +31,15 @@ function M:parse_curl_args(prompt_opts)
|
||||
}
|
||||
|
||||
if P.env.require_api_key(provider_conf) then
|
||||
local api_key = self.parse_api_key()
|
||||
if not api_key then
|
||||
Utils.error("Azure: API key is not set. Please set " .. M.api_key_name)
|
||||
return nil
|
||||
end
|
||||
if provider_conf.entra then
|
||||
headers["Authorization"] = "Bearer " .. self.parse_api_key()
|
||||
headers["Authorization"] = "Bearer " .. api_key
|
||||
else
|
||||
headers["api-key"] = self.parse_api_key()
|
||||
headers["api-key"] = api_key
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
@@ -125,7 +125,7 @@ function M:parse_response_without_stream(data, event_state, opts)
|
||||
end
|
||||
|
||||
---@param prompt_opts AvantePromptOptions
|
||||
---@return table
|
||||
---@return AvanteCurlOutput|nil
|
||||
function M:parse_curl_args(prompt_opts)
|
||||
local provider_conf, request_body = P.parse_config(self)
|
||||
|
||||
@@ -136,9 +136,12 @@ function M:parse_curl_args(prompt_opts)
|
||||
if profile ~= nil and profile ~= "" then
|
||||
---@diagnostic disable-next-line: undefined-field
|
||||
region = provider_conf.aws_region
|
||||
if not region or region == "" then
|
||||
Utils.error("Bedrock: no aws_region specified in bedrock config")
|
||||
return nil
|
||||
end
|
||||
|
||||
local awsCreds = M:get_aws_credentials(region, profile)
|
||||
if not region or region == "" then error("No aws_region specified in bedrock config") end
|
||||
|
||||
access_key_id = awsCreds.access_key_id
|
||||
secret_access_key = awsCreds.secret_access_key
|
||||
@@ -153,7 +156,8 @@ function M:parse_curl_args(prompt_opts)
|
||||
region = parts[3]
|
||||
session_token = parts[4]
|
||||
else
|
||||
error("API key not set correctly")
|
||||
Utils.error("Bedrock: API key not set correctly")
|
||||
return nil
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
@@ -315,7 +315,7 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
||||
end
|
||||
|
||||
---@param prompt_opts AvantePromptOptions
|
||||
---@return table
|
||||
---@return AvanteCurlOutput|nil
|
||||
function M:parse_curl_args(prompt_opts)
|
||||
local provider_conf, request_body = P.parse_config(self)
|
||||
local disable_tools = provider_conf.disable_tools or false
|
||||
@@ -326,7 +326,14 @@ function M:parse_curl_args(prompt_opts)
|
||||
["anthropic-beta"] = "prompt-caching-2024-07-31",
|
||||
}
|
||||
|
||||
if P.env.require_api_key(provider_conf) then headers["x-api-key"] = self.parse_api_key() end
|
||||
if P.env.require_api_key(provider_conf) then
|
||||
local api_key = self.parse_api_key()
|
||||
if not api_key then
|
||||
Utils.error("Claude: API key is not set. Please set " .. M.api_key_name)
|
||||
return nil
|
||||
end
|
||||
headers["x-api-key"] = api_key
|
||||
end
|
||||
|
||||
local messages = self:parse_messages(prompt_opts)
|
||||
|
||||
|
||||
@@ -76,6 +76,8 @@ function M:parse_stream_data(ctx, data, opts)
|
||||
end
|
||||
end
|
||||
|
||||
---@param prompt_opts AvantePromptOptions
|
||||
---@return AvanteCurlOutput|nil
|
||||
function M:parse_curl_args(prompt_opts)
|
||||
local provider_conf, request_body = P.parse_config(self)
|
||||
|
||||
@@ -89,7 +91,14 @@ function M:parse_curl_args(prompt_opts)
|
||||
.. "."
|
||||
.. vim.version().patch,
|
||||
}
|
||||
if P.env.require_api_key(provider_conf) then headers["Authorization"] = "Bearer " .. self.parse_api_key() end
|
||||
if P.env.require_api_key(provider_conf) then
|
||||
local api_key = self.parse_api_key()
|
||||
if not api_key then
|
||||
Utils.error("Cohere: API key is not set. Please set " .. M.api_key_name)
|
||||
return nil
|
||||
end
|
||||
headers["Authorization"] = "Bearer " .. api_key
|
||||
end
|
||||
|
||||
return {
|
||||
url = Utils.url_join(provider_conf.endpoint, "/chat"),
|
||||
|
||||
@@ -312,11 +312,16 @@ function M:parse_response(ctx, data_stream, _, opts)
|
||||
end
|
||||
end
|
||||
|
||||
---@param prompt_opts AvantePromptOptions
|
||||
---@return AvanteCurlOutput|nil
|
||||
function M:parse_curl_args(prompt_opts)
|
||||
local provider_conf, request_body = Providers.parse_config(self)
|
||||
|
||||
local api_key = self:parse_api_key()
|
||||
if api_key == nil then error("Cannot get the gemini api key!") end
|
||||
if api_key == nil then
|
||||
Utils.error("Gemini: API key is not set. Please set " .. M.api_key_name)
|
||||
return nil
|
||||
end
|
||||
|
||||
return {
|
||||
url = Utils.url_join(
|
||||
|
||||
@@ -190,12 +190,21 @@ function M:parse_stream_data(ctx, data, opts)
|
||||
end
|
||||
|
||||
---@param prompt_opts AvantePromptOptions
|
||||
---@return AvanteCurlOutput|nil
|
||||
function M:parse_curl_args(prompt_opts)
|
||||
local provider_conf, request_body = Providers.parse_config(self)
|
||||
local keep_alive = provider_conf.keep_alive or "5m"
|
||||
|
||||
if not provider_conf.model or provider_conf.model == "" then error("Ollama model must be specified in config") end
|
||||
if not provider_conf.endpoint then error("Ollama requires endpoint configuration") end
|
||||
if not provider_conf.model or provider_conf.model == "" then
|
||||
Utils.error("Ollama: model must be specified in config")
|
||||
return nil
|
||||
end
|
||||
|
||||
if not provider_conf.endpoint then
|
||||
Utils.error("Ollama: endpoint must be specified in config")
|
||||
return nil
|
||||
end
|
||||
|
||||
local headers = {
|
||||
["Content-Type"] = "application/json",
|
||||
["Accept"] = "application/json",
|
||||
|
||||
@@ -504,6 +504,8 @@ function M:parse_response_without_stream(data, _, opts)
|
||||
end
|
||||
end
|
||||
|
||||
---@param prompt_opts AvantePromptOptions
|
||||
---@return AvanteCurlOutput|nil
|
||||
function M:parse_curl_args(prompt_opts)
|
||||
local provider_conf, request_body = Providers.parse_config(self)
|
||||
local disable_tools = provider_conf.disable_tools or false
|
||||
@@ -515,7 +517,8 @@ function M:parse_curl_args(prompt_opts)
|
||||
if Providers.env.require_api_key(provider_conf) then
|
||||
local api_key = self.parse_api_key()
|
||||
if api_key == nil then
|
||||
error(Config.provider .. " API key is not set, please set it in your environment variable or config file")
|
||||
Utils.error(Config.provider .. ": API key is not set, please set it in your environment variable or config file")
|
||||
return nil
|
||||
end
|
||||
headers["Authorization"] = "Bearer " .. api_key
|
||||
end
|
||||
|
||||
@@ -230,7 +230,7 @@ vim.g.avante_login = vim.g.avante_login
|
||||
---@alias AvanteMessagesParser fun(self: AvanteProviderFunctor, opts: AvantePromptOptions): AvanteChatMessage[]
|
||||
---
|
||||
---@class AvanteCurlOutput: {url: string, proxy: string, insecure: boolean, body: table<string, any> | string, headers: table<string, string>, rawArgs: string[] | nil}
|
||||
---@alias AvanteCurlArgsParser fun(self: AvanteProviderFunctor, prompt_opts: AvantePromptOptions): AvanteCurlOutput
|
||||
---@alias AvanteCurlArgsParser fun(self: AvanteProviderFunctor, prompt_opts: AvantePromptOptions): (AvanteCurlOutput | nil)
|
||||
---
|
||||
---@alias AvanteResponseParser fun(self: AvanteProviderFunctor, ctx: any, data_stream: string, event_state: string, opts: AvanteHandlerOptions): nil
|
||||
---
|
||||
|
||||
Reference in New Issue
Block a user