feat(providers): fail gracefully when a provider is misconfigured (#2768)
This commit is contained in:
@@ -548,8 +548,11 @@ function M.curl(opts)
|
|||||||
if orig_on_stop then return orig_on_stop(stop_opts) end
|
if orig_on_stop then return orig_on_stop(stop_opts) end
|
||||||
end
|
end
|
||||||
|
|
||||||
---@type AvanteCurlOutput
|
|
||||||
local spec = provider:parse_curl_args(prompt_opts)
|
local spec = provider:parse_curl_args(prompt_opts)
|
||||||
|
if not spec then
|
||||||
|
handler_opts.on_stop({ reason = "error", error = "Provider configuration error" })
|
||||||
|
return
|
||||||
|
end
|
||||||
|
|
||||||
---@type string
|
---@type string
|
||||||
local current_event_state = nil
|
local current_event_state = nil
|
||||||
|
|||||||
@@ -20,6 +20,8 @@ M.api_key_name = "AZURE_OPENAI_API_KEY"
|
|||||||
-- Inherit from OpenAI class
|
-- Inherit from OpenAI class
|
||||||
setmetatable(M, { __index = O })
|
setmetatable(M, { __index = O })
|
||||||
|
|
||||||
|
---@param prompt_opts AvantePromptOptions
|
||||||
|
---@return AvanteCurlOutput|nil
|
||||||
function M:parse_curl_args(prompt_opts)
|
function M:parse_curl_args(prompt_opts)
|
||||||
local provider_conf, request_body = P.parse_config(self)
|
local provider_conf, request_body = P.parse_config(self)
|
||||||
local disable_tools = provider_conf.disable_tools or false
|
local disable_tools = provider_conf.disable_tools or false
|
||||||
@@ -29,10 +31,15 @@ function M:parse_curl_args(prompt_opts)
|
|||||||
}
|
}
|
||||||
|
|
||||||
if P.env.require_api_key(provider_conf) then
|
if P.env.require_api_key(provider_conf) then
|
||||||
|
local api_key = self.parse_api_key()
|
||||||
|
if not api_key then
|
||||||
|
Utils.error("Azure: API key is not set. Please set " .. M.api_key_name)
|
||||||
|
return nil
|
||||||
|
end
|
||||||
if provider_conf.entra then
|
if provider_conf.entra then
|
||||||
headers["Authorization"] = "Bearer " .. self.parse_api_key()
|
headers["Authorization"] = "Bearer " .. api_key
|
||||||
else
|
else
|
||||||
headers["api-key"] = self.parse_api_key()
|
headers["api-key"] = api_key
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|||||||
@@ -125,7 +125,7 @@ function M:parse_response_without_stream(data, event_state, opts)
|
|||||||
end
|
end
|
||||||
|
|
||||||
---@param prompt_opts AvantePromptOptions
|
---@param prompt_opts AvantePromptOptions
|
||||||
---@return table
|
---@return AvanteCurlOutput|nil
|
||||||
function M:parse_curl_args(prompt_opts)
|
function M:parse_curl_args(prompt_opts)
|
||||||
local provider_conf, request_body = P.parse_config(self)
|
local provider_conf, request_body = P.parse_config(self)
|
||||||
|
|
||||||
@@ -136,9 +136,12 @@ function M:parse_curl_args(prompt_opts)
|
|||||||
if profile ~= nil and profile ~= "" then
|
if profile ~= nil and profile ~= "" then
|
||||||
---@diagnostic disable-next-line: undefined-field
|
---@diagnostic disable-next-line: undefined-field
|
||||||
region = provider_conf.aws_region
|
region = provider_conf.aws_region
|
||||||
|
if not region or region == "" then
|
||||||
|
Utils.error("Bedrock: no aws_region specified in bedrock config")
|
||||||
|
return nil
|
||||||
|
end
|
||||||
|
|
||||||
local awsCreds = M:get_aws_credentials(region, profile)
|
local awsCreds = M:get_aws_credentials(region, profile)
|
||||||
if not region or region == "" then error("No aws_region specified in bedrock config") end
|
|
||||||
|
|
||||||
access_key_id = awsCreds.access_key_id
|
access_key_id = awsCreds.access_key_id
|
||||||
secret_access_key = awsCreds.secret_access_key
|
secret_access_key = awsCreds.secret_access_key
|
||||||
@@ -153,7 +156,8 @@ function M:parse_curl_args(prompt_opts)
|
|||||||
region = parts[3]
|
region = parts[3]
|
||||||
session_token = parts[4]
|
session_token = parts[4]
|
||||||
else
|
else
|
||||||
error("API key not set correctly")
|
Utils.error("Bedrock: API key not set correctly")
|
||||||
|
return nil
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|||||||
@@ -315,7 +315,7 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
|||||||
end
|
end
|
||||||
|
|
||||||
---@param prompt_opts AvantePromptOptions
|
---@param prompt_opts AvantePromptOptions
|
||||||
---@return table
|
---@return AvanteCurlOutput|nil
|
||||||
function M:parse_curl_args(prompt_opts)
|
function M:parse_curl_args(prompt_opts)
|
||||||
local provider_conf, request_body = P.parse_config(self)
|
local provider_conf, request_body = P.parse_config(self)
|
||||||
local disable_tools = provider_conf.disable_tools or false
|
local disable_tools = provider_conf.disable_tools or false
|
||||||
@@ -326,7 +326,14 @@ function M:parse_curl_args(prompt_opts)
|
|||||||
["anthropic-beta"] = "prompt-caching-2024-07-31",
|
["anthropic-beta"] = "prompt-caching-2024-07-31",
|
||||||
}
|
}
|
||||||
|
|
||||||
if P.env.require_api_key(provider_conf) then headers["x-api-key"] = self.parse_api_key() end
|
if P.env.require_api_key(provider_conf) then
|
||||||
|
local api_key = self.parse_api_key()
|
||||||
|
if not api_key then
|
||||||
|
Utils.error("Claude: API key is not set. Please set " .. M.api_key_name)
|
||||||
|
return nil
|
||||||
|
end
|
||||||
|
headers["x-api-key"] = api_key
|
||||||
|
end
|
||||||
|
|
||||||
local messages = self:parse_messages(prompt_opts)
|
local messages = self:parse_messages(prompt_opts)
|
||||||
|
|
||||||
|
|||||||
@@ -76,6 +76,8 @@ function M:parse_stream_data(ctx, data, opts)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
---@param prompt_opts AvantePromptOptions
|
||||||
|
---@return AvanteCurlOutput|nil
|
||||||
function M:parse_curl_args(prompt_opts)
|
function M:parse_curl_args(prompt_opts)
|
||||||
local provider_conf, request_body = P.parse_config(self)
|
local provider_conf, request_body = P.parse_config(self)
|
||||||
|
|
||||||
@@ -89,7 +91,14 @@ function M:parse_curl_args(prompt_opts)
|
|||||||
.. "."
|
.. "."
|
||||||
.. vim.version().patch,
|
.. vim.version().patch,
|
||||||
}
|
}
|
||||||
if P.env.require_api_key(provider_conf) then headers["Authorization"] = "Bearer " .. self.parse_api_key() end
|
if P.env.require_api_key(provider_conf) then
|
||||||
|
local api_key = self.parse_api_key()
|
||||||
|
if not api_key then
|
||||||
|
Utils.error("Cohere: API key is not set. Please set " .. M.api_key_name)
|
||||||
|
return nil
|
||||||
|
end
|
||||||
|
headers["Authorization"] = "Bearer " .. api_key
|
||||||
|
end
|
||||||
|
|
||||||
return {
|
return {
|
||||||
url = Utils.url_join(provider_conf.endpoint, "/chat"),
|
url = Utils.url_join(provider_conf.endpoint, "/chat"),
|
||||||
|
|||||||
@@ -312,11 +312,16 @@ function M:parse_response(ctx, data_stream, _, opts)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
---@param prompt_opts AvantePromptOptions
|
||||||
|
---@return AvanteCurlOutput|nil
|
||||||
function M:parse_curl_args(prompt_opts)
|
function M:parse_curl_args(prompt_opts)
|
||||||
local provider_conf, request_body = Providers.parse_config(self)
|
local provider_conf, request_body = Providers.parse_config(self)
|
||||||
|
|
||||||
local api_key = self:parse_api_key()
|
local api_key = self:parse_api_key()
|
||||||
if api_key == nil then error("Cannot get the gemini api key!") end
|
if api_key == nil then
|
||||||
|
Utils.error("Gemini: API key is not set. Please set " .. M.api_key_name)
|
||||||
|
return nil
|
||||||
|
end
|
||||||
|
|
||||||
return {
|
return {
|
||||||
url = Utils.url_join(
|
url = Utils.url_join(
|
||||||
|
|||||||
@@ -190,12 +190,21 @@ function M:parse_stream_data(ctx, data, opts)
|
|||||||
end
|
end
|
||||||
|
|
||||||
---@param prompt_opts AvantePromptOptions
|
---@param prompt_opts AvantePromptOptions
|
||||||
|
---@return AvanteCurlOutput|nil
|
||||||
function M:parse_curl_args(prompt_opts)
|
function M:parse_curl_args(prompt_opts)
|
||||||
local provider_conf, request_body = Providers.parse_config(self)
|
local provider_conf, request_body = Providers.parse_config(self)
|
||||||
local keep_alive = provider_conf.keep_alive or "5m"
|
local keep_alive = provider_conf.keep_alive or "5m"
|
||||||
|
|
||||||
if not provider_conf.model or provider_conf.model == "" then error("Ollama model must be specified in config") end
|
if not provider_conf.model or provider_conf.model == "" then
|
||||||
if not provider_conf.endpoint then error("Ollama requires endpoint configuration") end
|
Utils.error("Ollama: model must be specified in config")
|
||||||
|
return nil
|
||||||
|
end
|
||||||
|
|
||||||
|
if not provider_conf.endpoint then
|
||||||
|
Utils.error("Ollama: endpoint must be specified in config")
|
||||||
|
return nil
|
||||||
|
end
|
||||||
|
|
||||||
local headers = {
|
local headers = {
|
||||||
["Content-Type"] = "application/json",
|
["Content-Type"] = "application/json",
|
||||||
["Accept"] = "application/json",
|
["Accept"] = "application/json",
|
||||||
|
|||||||
@@ -504,6 +504,8 @@ function M:parse_response_without_stream(data, _, opts)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
---@param prompt_opts AvantePromptOptions
|
||||||
|
---@return AvanteCurlOutput|nil
|
||||||
function M:parse_curl_args(prompt_opts)
|
function M:parse_curl_args(prompt_opts)
|
||||||
local provider_conf, request_body = Providers.parse_config(self)
|
local provider_conf, request_body = Providers.parse_config(self)
|
||||||
local disable_tools = provider_conf.disable_tools or false
|
local disable_tools = provider_conf.disable_tools or false
|
||||||
@@ -515,7 +517,8 @@ function M:parse_curl_args(prompt_opts)
|
|||||||
if Providers.env.require_api_key(provider_conf) then
|
if Providers.env.require_api_key(provider_conf) then
|
||||||
local api_key = self.parse_api_key()
|
local api_key = self.parse_api_key()
|
||||||
if api_key == nil then
|
if api_key == nil then
|
||||||
error(Config.provider .. " API key is not set, please set it in your environment variable or config file")
|
Utils.error(Config.provider .. ": API key is not set, please set it in your environment variable or config file")
|
||||||
|
return nil
|
||||||
end
|
end
|
||||||
headers["Authorization"] = "Bearer " .. api_key
|
headers["Authorization"] = "Bearer " .. api_key
|
||||||
end
|
end
|
||||||
|
|||||||
@@ -230,7 +230,7 @@ vim.g.avante_login = vim.g.avante_login
|
|||||||
---@alias AvanteMessagesParser fun(self: AvanteProviderFunctor, opts: AvantePromptOptions): AvanteChatMessage[]
|
---@alias AvanteMessagesParser fun(self: AvanteProviderFunctor, opts: AvantePromptOptions): AvanteChatMessage[]
|
||||||
---
|
---
|
||||||
---@class AvanteCurlOutput: {url: string, proxy: string, insecure: boolean, body: table<string, any> | string, headers: table<string, string>, rawArgs: string[] | nil}
|
---@class AvanteCurlOutput: {url: string, proxy: string, insecure: boolean, body: table<string, any> | string, headers: table<string, string>, rawArgs: string[] | nil}
|
||||||
---@alias AvanteCurlArgsParser fun(self: AvanteProviderFunctor, prompt_opts: AvantePromptOptions): AvanteCurlOutput
|
---@alias AvanteCurlArgsParser fun(self: AvanteProviderFunctor, prompt_opts: AvantePromptOptions): (AvanteCurlOutput | nil)
|
||||||
---
|
---
|
||||||
---@alias AvanteResponseParser fun(self: AvanteProviderFunctor, ctx: any, data_stream: string, event_state: string, opts: AvanteHandlerOptions): nil
|
---@alias AvanteResponseParser fun(self: AvanteProviderFunctor, ctx: any, data_stream: string, event_state: string, opts: AvanteHandlerOptions): nil
|
||||||
---
|
---
|
||||||
|
|||||||
Reference in New Issue
Block a user