feat(providers): fail gracefully when a provider is misconfigured (#2768)

This commit is contained in:
Dmitry Torokhov
2025-10-15 03:43:55 -07:00
committed by GitHub
parent 250b7a26b4
commit 0716819a0e
9 changed files with 61 additions and 14 deletions

View File

@@ -548,8 +548,11 @@ function M.curl(opts)
if orig_on_stop then return orig_on_stop(stop_opts) end if orig_on_stop then return orig_on_stop(stop_opts) end
end end
---@type AvanteCurlOutput
local spec = provider:parse_curl_args(prompt_opts) local spec = provider:parse_curl_args(prompt_opts)
if not spec then
handler_opts.on_stop({ reason = "error", error = "Provider configuration error" })
return
end
---@type string ---@type string
local current_event_state = nil local current_event_state = nil

View File

@@ -20,6 +20,8 @@ M.api_key_name = "AZURE_OPENAI_API_KEY"
-- Inherit from OpenAI class -- Inherit from OpenAI class
setmetatable(M, { __index = O }) setmetatable(M, { __index = O })
---@param prompt_opts AvantePromptOptions
---@return AvanteCurlOutput|nil
function M:parse_curl_args(prompt_opts) function M:parse_curl_args(prompt_opts)
local provider_conf, request_body = P.parse_config(self) local provider_conf, request_body = P.parse_config(self)
local disable_tools = provider_conf.disable_tools or false local disable_tools = provider_conf.disable_tools or false
@@ -29,10 +31,15 @@ function M:parse_curl_args(prompt_opts)
} }
if P.env.require_api_key(provider_conf) then if P.env.require_api_key(provider_conf) then
local api_key = self.parse_api_key()
if not api_key then
Utils.error("Azure: API key is not set. Please set " .. M.api_key_name)
return nil
end
if provider_conf.entra then if provider_conf.entra then
headers["Authorization"] = "Bearer " .. self.parse_api_key() headers["Authorization"] = "Bearer " .. api_key
else else
headers["api-key"] = self.parse_api_key() headers["api-key"] = api_key
end end
end end

View File

@@ -125,7 +125,7 @@ function M:parse_response_without_stream(data, event_state, opts)
end end
---@param prompt_opts AvantePromptOptions ---@param prompt_opts AvantePromptOptions
---@return table ---@return AvanteCurlOutput|nil
function M:parse_curl_args(prompt_opts) function M:parse_curl_args(prompt_opts)
local provider_conf, request_body = P.parse_config(self) local provider_conf, request_body = P.parse_config(self)
@@ -136,9 +136,12 @@ function M:parse_curl_args(prompt_opts)
if profile ~= nil and profile ~= "" then if profile ~= nil and profile ~= "" then
---@diagnostic disable-next-line: undefined-field ---@diagnostic disable-next-line: undefined-field
region = provider_conf.aws_region region = provider_conf.aws_region
if not region or region == "" then
Utils.error("Bedrock: no aws_region specified in bedrock config")
return nil
end
local awsCreds = M:get_aws_credentials(region, profile) local awsCreds = M:get_aws_credentials(region, profile)
if not region or region == "" then error("No aws_region specified in bedrock config") end
access_key_id = awsCreds.access_key_id access_key_id = awsCreds.access_key_id
secret_access_key = awsCreds.secret_access_key secret_access_key = awsCreds.secret_access_key
@@ -153,7 +156,8 @@ function M:parse_curl_args(prompt_opts)
region = parts[3] region = parts[3]
session_token = parts[4] session_token = parts[4]
else else
error("API key not set correctly") Utils.error("Bedrock: API key not set correctly")
return nil
end end
end end

View File

@@ -315,7 +315,7 @@ function M:parse_response(ctx, data_stream, event_state, opts)
end end
---@param prompt_opts AvantePromptOptions ---@param prompt_opts AvantePromptOptions
---@return table ---@return AvanteCurlOutput|nil
function M:parse_curl_args(prompt_opts) function M:parse_curl_args(prompt_opts)
local provider_conf, request_body = P.parse_config(self) local provider_conf, request_body = P.parse_config(self)
local disable_tools = provider_conf.disable_tools or false local disable_tools = provider_conf.disable_tools or false
@@ -326,7 +326,14 @@ function M:parse_curl_args(prompt_opts)
["anthropic-beta"] = "prompt-caching-2024-07-31", ["anthropic-beta"] = "prompt-caching-2024-07-31",
} }
if P.env.require_api_key(provider_conf) then headers["x-api-key"] = self.parse_api_key() end if P.env.require_api_key(provider_conf) then
local api_key = self.parse_api_key()
if not api_key then
Utils.error("Claude: API key is not set. Please set " .. M.api_key_name)
return nil
end
headers["x-api-key"] = api_key
end
local messages = self:parse_messages(prompt_opts) local messages = self:parse_messages(prompt_opts)

View File

@@ -76,6 +76,8 @@ function M:parse_stream_data(ctx, data, opts)
end end
end end
---@param prompt_opts AvantePromptOptions
---@return AvanteCurlOutput|nil
function M:parse_curl_args(prompt_opts) function M:parse_curl_args(prompt_opts)
local provider_conf, request_body = P.parse_config(self) local provider_conf, request_body = P.parse_config(self)
@@ -89,7 +91,14 @@ function M:parse_curl_args(prompt_opts)
.. "." .. "."
.. vim.version().patch, .. vim.version().patch,
} }
if P.env.require_api_key(provider_conf) then headers["Authorization"] = "Bearer " .. self.parse_api_key() end if P.env.require_api_key(provider_conf) then
local api_key = self.parse_api_key()
if not api_key then
Utils.error("Cohere: API key is not set. Please set " .. M.api_key_name)
return nil
end
headers["Authorization"] = "Bearer " .. api_key
end
return { return {
url = Utils.url_join(provider_conf.endpoint, "/chat"), url = Utils.url_join(provider_conf.endpoint, "/chat"),

View File

@@ -312,11 +312,16 @@ function M:parse_response(ctx, data_stream, _, opts)
end end
end end
---@param prompt_opts AvantePromptOptions
---@return AvanteCurlOutput|nil
function M:parse_curl_args(prompt_opts) function M:parse_curl_args(prompt_opts)
local provider_conf, request_body = Providers.parse_config(self) local provider_conf, request_body = Providers.parse_config(self)
local api_key = self:parse_api_key() local api_key = self:parse_api_key()
if api_key == nil then error("Cannot get the gemini api key!") end if api_key == nil then
Utils.error("Gemini: API key is not set. Please set " .. M.api_key_name)
return nil
end
return { return {
url = Utils.url_join( url = Utils.url_join(

View File

@@ -190,12 +190,21 @@ function M:parse_stream_data(ctx, data, opts)
end end
---@param prompt_opts AvantePromptOptions ---@param prompt_opts AvantePromptOptions
---@return AvanteCurlOutput|nil
function M:parse_curl_args(prompt_opts) function M:parse_curl_args(prompt_opts)
local provider_conf, request_body = Providers.parse_config(self) local provider_conf, request_body = Providers.parse_config(self)
local keep_alive = provider_conf.keep_alive or "5m" local keep_alive = provider_conf.keep_alive or "5m"
if not provider_conf.model or provider_conf.model == "" then error("Ollama model must be specified in config") end if not provider_conf.model or provider_conf.model == "" then
if not provider_conf.endpoint then error("Ollama requires endpoint configuration") end Utils.error("Ollama: model must be specified in config")
return nil
end
if not provider_conf.endpoint then
Utils.error("Ollama: endpoint must be specified in config")
return nil
end
local headers = { local headers = {
["Content-Type"] = "application/json", ["Content-Type"] = "application/json",
["Accept"] = "application/json", ["Accept"] = "application/json",

View File

@@ -504,6 +504,8 @@ function M:parse_response_without_stream(data, _, opts)
end end
end end
---@param prompt_opts AvantePromptOptions
---@return AvanteCurlOutput|nil
function M:parse_curl_args(prompt_opts) function M:parse_curl_args(prompt_opts)
local provider_conf, request_body = Providers.parse_config(self) local provider_conf, request_body = Providers.parse_config(self)
local disable_tools = provider_conf.disable_tools or false local disable_tools = provider_conf.disable_tools or false
@@ -515,7 +517,8 @@ function M:parse_curl_args(prompt_opts)
if Providers.env.require_api_key(provider_conf) then if Providers.env.require_api_key(provider_conf) then
local api_key = self.parse_api_key() local api_key = self.parse_api_key()
if api_key == nil then if api_key == nil then
error(Config.provider .. " API key is not set, please set it in your environment variable or config file") Utils.error(Config.provider .. ": API key is not set, please set it in your environment variable or config file")
return nil
end end
headers["Authorization"] = "Bearer " .. api_key headers["Authorization"] = "Bearer " .. api_key
end end

View File

@@ -230,7 +230,7 @@ vim.g.avante_login = vim.g.avante_login
---@alias AvanteMessagesParser fun(self: AvanteProviderFunctor, opts: AvantePromptOptions): AvanteChatMessage[] ---@alias AvanteMessagesParser fun(self: AvanteProviderFunctor, opts: AvantePromptOptions): AvanteChatMessage[]
--- ---
---@class AvanteCurlOutput: {url: string, proxy: string, insecure: boolean, body: table<string, any> | string, headers: table<string, string>, rawArgs: string[] | nil} ---@class AvanteCurlOutput: {url: string, proxy: string, insecure: boolean, body: table<string, any> | string, headers: table<string, string>, rawArgs: string[] | nil}
---@alias AvanteCurlArgsParser fun(self: AvanteProviderFunctor, prompt_opts: AvantePromptOptions): AvanteCurlOutput ---@alias AvanteCurlArgsParser fun(self: AvanteProviderFunctor, prompt_opts: AvantePromptOptions): (AvanteCurlOutput | nil)
--- ---
---@alias AvanteResponseParser fun(self: AvanteProviderFunctor, ctx: any, data_stream: string, event_state: string, opts: AvanteHandlerOptions): nil ---@alias AvanteResponseParser fun(self: AvanteProviderFunctor, ctx: any, data_stream: string, event_state: string, opts: AvanteHandlerOptions): nil
--- ---