273 lines
8.5 KiB
Lua
273 lines
8.5 KiB
Lua
local api, fn = vim.api, vim.fn
|
|
|
|
local Config = require("avante.config")
|
|
local Utils = require("avante.utils")
|
|
|
|
---@class avante.Providers
|
|
---@field openai AvanteProviderFunctor
|
|
---@field claude AvanteProviderFunctor
|
|
---@field copilot AvanteProviderFunctor
|
|
---@field azure AvanteProviderFunctor
|
|
---@field gemini AvanteProviderFunctor
|
|
---@field cohere AvanteProviderFunctor
|
|
---@field bedrock AvanteBedrockProviderFunctor
|
|
---@field ollama AvanteProviderFunctor
|
|
---@field vertex_claude AvanteProviderFunctor
|
|
---@field watsonx_code_assistant AvanteProviderFunctor
|
|
local M = {}
|
|
|
|
---@class EnvironmentHandler
|
|
local E = {}
|
|
|
|
---@private
|
|
---@type table<string, string>
|
|
E.cache = {}
|
|
|
|
---@param Opts AvanteSupportedProvider | AvanteProviderFunctor | AvanteBedrockProviderFunctor
|
|
---@return string | nil
|
|
function E.parse_envvar(Opts)
|
|
-- First try the scoped version (e.g., AVANTE_ANTHROPIC_API_KEY)
|
|
local scoped_key_name = nil
|
|
if Opts.api_key_name and type(Opts.api_key_name) == "string" and Opts.api_key_name ~= "" then
|
|
-- Only add AVANTE_ prefix if it's a regular environment variable (not a cmd: or already prefixed)
|
|
if not Opts.api_key_name:match("^cmd:") and not Opts.api_key_name:match("^AVANTE_") then
|
|
scoped_key_name = "AVANTE_" .. Opts.api_key_name
|
|
end
|
|
end
|
|
|
|
-- Try scoped key first if available
|
|
if scoped_key_name then
|
|
local scoped_value = Utils.environment.parse(scoped_key_name, Opts._shellenv)
|
|
if scoped_value ~= nil then
|
|
vim.g.avante_login = true
|
|
return scoped_value
|
|
end
|
|
end
|
|
|
|
-- Fall back to the original global key
|
|
local value = Utils.environment.parse(Opts.api_key_name, Opts._shellenv)
|
|
if value ~= nil then
|
|
vim.g.avante_login = true
|
|
return value
|
|
end
|
|
|
|
return nil
|
|
end
|
|
|
|
--- initialize the environment variable for current neovim session.
|
|
--- This will only run once and spawn a UI for users to input the envvar.
|
|
---@param opts {refresh: boolean, provider: AvanteProviderFunctor | AvanteBedrockProviderFunctor}
|
|
---@private
|
|
function E.setup(opts)
|
|
opts.provider.setup()
|
|
|
|
local var = opts.provider.api_key_name
|
|
|
|
if var == nil or var == "" then
|
|
vim.g.avante_login = true
|
|
return
|
|
end
|
|
|
|
if type(var) ~= "table" and vim.env[var] ~= nil then
|
|
vim.g.avante_login = true
|
|
return
|
|
end
|
|
|
|
-- check if var is a all caps string
|
|
if type(var) == "table" or var:match("^cmd:(.*)") then return end
|
|
|
|
local refresh = opts.refresh or false
|
|
|
|
---@param value string
|
|
---@return nil
|
|
local function on_confirm(value)
|
|
if value then
|
|
vim.fn.setenv(var, value)
|
|
vim.g.avante_login = true
|
|
else
|
|
if not opts.provider.is_env_set() then
|
|
Utils.warn("Failed to set " .. var .. ". Avante won't work as expected", { once = true })
|
|
end
|
|
end
|
|
end
|
|
|
|
local function mount_input_ui()
|
|
vim.defer_fn(function()
|
|
-- only mount if given buffer is not of buftype ministarter, dashboard, alpha, qf
|
|
local exclude_filetypes = {
|
|
"NvimTree",
|
|
"Outline",
|
|
"help",
|
|
"dashboard",
|
|
"alpha",
|
|
"qf",
|
|
"ministarter",
|
|
"TelescopePrompt",
|
|
"gitcommit",
|
|
"gitrebase",
|
|
"DressingInput",
|
|
"snacks_input",
|
|
"noice",
|
|
}
|
|
|
|
if not vim.tbl_contains(exclude_filetypes, vim.bo.filetype) and not opts.provider.is_env_set() then
|
|
local Input = require("avante.ui.input")
|
|
local input = Input:new({
|
|
provider = Config.input.provider,
|
|
title = "Enter " .. var .. ": ",
|
|
default = "",
|
|
conceal = true, -- Password input should be concealed
|
|
provider_opts = Config.input.provider_opts,
|
|
on_submit = on_confirm,
|
|
})
|
|
input:open()
|
|
end
|
|
end, 200)
|
|
end
|
|
|
|
if refresh then return mount_input_ui() end
|
|
|
|
api.nvim_create_autocmd("User", {
|
|
pattern = E.REQUEST_LOGIN_PATTERN,
|
|
callback = mount_input_ui,
|
|
})
|
|
end
|
|
|
|
E.REQUEST_LOGIN_PATTERN = "AvanteRequestLogin"
|
|
|
|
---@param provider AvanteDefaultBaseProvider
|
|
function E.require_api_key(provider) return provider.api_key_name ~= nil and provider.api_key_name ~= "" end
|
|
|
|
M.env = E
|
|
|
|
M = setmetatable(M, {
|
|
---@param t avante.Providers
|
|
---@param k avante.ProviderName
|
|
__index = function(t, k)
|
|
if Config.providers[k] == nil then error("Failed to find provider: " .. k, 2) end
|
|
|
|
local provider_config = M.get_config(k)
|
|
|
|
if provider_config.__inherited_from ~= nil then
|
|
local base_provider_config = M.get_config(provider_config.__inherited_from)
|
|
local ok, module = pcall(require, "avante.providers." .. provider_config.__inherited_from)
|
|
if not ok then error("Failed to load provider: " .. provider_config.__inherited_from, 2) end
|
|
provider_config = Utils.deep_extend_with_metatable("force", module, base_provider_config, provider_config)
|
|
else
|
|
local ok, module = pcall(require, "avante.providers." .. k)
|
|
if ok then
|
|
provider_config = Utils.deep_extend_with_metatable("force", module, provider_config)
|
|
elseif provider_config.parse_curl_args == nil then
|
|
error(
|
|
string.format(
|
|
'The configuration of your provider "%s" is incorrect, missing the `__inherited_from` attribute or a custom `parse_curl_args` function. Please fix your provider configuration. For more details, see: https://github.com/yetone/avante.nvim/wiki/Custom-providers',
|
|
k
|
|
)
|
|
)
|
|
end
|
|
end
|
|
|
|
t[k] = provider_config
|
|
|
|
if rawget(t[k], "parse_api_key") == nil then t[k].parse_api_key = function() return E.parse_envvar(t[k]) end end
|
|
|
|
-- default to gpt-4o as tokenizer
|
|
if t[k].tokenizer_id == nil then t[k].tokenizer_id = "gpt-4o" end
|
|
|
|
if rawget(t[k], "is_env_set") == nil then
|
|
t[k].is_env_set = function()
|
|
if not E.require_api_key(t[k]) then return true end
|
|
if type(t[k].api_key_name) == "string" and t[k].api_key_name:match("^cmd:") then return true end
|
|
local ok, result = pcall(t[k].parse_api_key)
|
|
if not ok then return false end
|
|
return result ~= nil
|
|
end
|
|
end
|
|
|
|
if rawget(t[k], "setup") == nil then
|
|
local provider_conf = M.parse_config(t[k])
|
|
t[k].setup = function()
|
|
if E.require_api_key(provider_conf) then
|
|
if not (type(provider_conf.api_key_name) == "string" and provider_conf.api_key_name:match("^cmd:")) then
|
|
t[k].parse_api_key()
|
|
end
|
|
end
|
|
require("avante.tokenizers").setup(t[k].tokenizer_id)
|
|
end
|
|
end
|
|
|
|
return t[k]
|
|
end,
|
|
})
|
|
|
|
function M.setup()
|
|
vim.g.avante_login = false
|
|
|
|
if Config.acp_providers[Config.provider] then return end
|
|
|
|
---@type AvanteProviderFunctor | AvanteBedrockProviderFunctor
|
|
local provider = M[Config.provider]
|
|
|
|
E.setup({ provider = provider })
|
|
|
|
if Config.auto_suggestions_provider then
|
|
local auto_suggestions_provider = M[Config.auto_suggestions_provider]
|
|
if auto_suggestions_provider and auto_suggestions_provider ~= provider then
|
|
E.setup({ provider = auto_suggestions_provider })
|
|
end
|
|
end
|
|
|
|
if Config.memory_summary_provider then
|
|
local memory_summary_provider = M[Config.memory_summary_provider]
|
|
if memory_summary_provider and memory_summary_provider ~= provider then
|
|
E.setup({ provider = memory_summary_provider })
|
|
end
|
|
end
|
|
end
|
|
|
|
---@param provider_name avante.ProviderName
|
|
function M.refresh(provider_name)
|
|
require("avante.config").override({ provider = provider_name })
|
|
|
|
if Config.acp_providers[provider_name] then
|
|
Config.provider = provider_name
|
|
else
|
|
---@type AvanteProviderFunctor | AvanteBedrockProviderFunctor
|
|
local p = M[Config.provider]
|
|
E.setup({ provider = p, refresh = true })
|
|
end
|
|
Utils.info("Switch to provider: " .. provider_name, { once = true, title = "Avante" })
|
|
end
|
|
|
|
---@param opts AvanteProvider | AvanteSupportedProvider | AvanteProviderFunctor | AvanteBedrockProviderFunctor
|
|
---@return AvanteDefaultBaseProvider provider_opts
|
|
---@return table<string, any> request_body
|
|
function M.parse_config(opts)
|
|
---@type AvanteDefaultBaseProvider
|
|
local provider_opts = {}
|
|
|
|
for key, value in pairs(opts) do
|
|
if key ~= "extra_request_body" then provider_opts[key] = value end
|
|
end
|
|
|
|
---@type table<string, any>
|
|
local request_body = opts.extra_request_body or {}
|
|
|
|
return provider_opts, request_body
|
|
end
|
|
|
|
---@param provider_name avante.ProviderName
|
|
function M.get_config(provider_name)
|
|
provider_name = provider_name or Config.provider
|
|
local cur = Config.get_provider_config(provider_name)
|
|
return type(cur) == "function" and cur() or cur
|
|
end
|
|
|
|
function M.get_memory_summary_provider()
|
|
local provider_name = Config.memory_summary_provider
|
|
if provider_name == nil then provider_name = Config.provider end
|
|
return M[provider_name]
|
|
end
|
|
|
|
return M
|