refactor: remove redundant local field to facilitate provider configuration (#858)

This commit is contained in:
yetone
2024-11-17 02:55:40 +08:00
committed by GitHub
parent 4acdcb6e8b
commit ff85b9c1e2
8 changed files with 39 additions and 40 deletions

View File

@@ -108,32 +108,32 @@ E.parse_envvar = function(Opts)
local cmd = type(api_key_name) == "table" and api_key_name or api_key_name:match("^cmd:(.*)")
local key = nil
local value = nil
if cmd ~= nil then
-- NOTE: in case api_key_name is cmd, and users still set envvar
-- We will try to get envvar first
if Opts._shellenv ~= nil and Opts._shellenv ~= M.AVANTE_INTERNAL_KEY then
key = os.getenv(Opts._shellenv)
if key ~= nil then
---@diagnostic disable: no-unknown
E.cache[Opts._shellenv] = key
E.cache[cache_key] = key
if Opts._shellenv ~= nil and Opts._shellenv ~= "" then
value = os.getenv(Opts._shellenv)
if value ~= nil then
E.cache[cache_key] = value
vim.g.avante_login = true
return key
return value
end
end
if type(cmd) == "string" then cmd = vim.split(cmd, " ", { trimempty = true }) end
Utils.debug("running command:", cmd)
local exit_codes = { 0 }
local ok, job_or_err = pcall(vim.system, cmd, { text = true }, function(result)
Utils.debug("command result:", result)
local code = result.code
local stderr = result.stderr or ""
local stdout = result.stdout and vim.split(result.stdout, "\n") or {}
if vim.tbl_contains(exit_codes, code) then
key = stdout[1]
E.cache[cache_key] = key
value = stdout[1]
E.cache[cache_key] = value
vim.g.avante_login = true
else
Utils.error("Failed to get API key: (error code" .. code .. ")\n" .. stderr, { once = true, title = "Avante" })
@@ -145,15 +145,15 @@ E.parse_envvar = function(Opts)
return
end
else
key = os.getenv(api_key_name)
value = os.getenv(api_key_name)
end
if key ~= nil then
E.cache[cache_key] = key
if value ~= nil then
E.cache[cache_key] = value
vim.g.avante_login = true
end
return key
return value
end
--- initialize the environment variable for current neovim session.
@@ -161,17 +161,17 @@ end
---@param opts {refresh: boolean, provider: AvanteProviderFunctor}
---@private
E.setup = function(opts)
if opts.provider["local"] then
local var = opts.provider.api_key_name
if var == nil or var == "" then
vim.g.avante_login = true
return
end
local var = opts.provider.api_key_name
opts.provider.setup()
-- check if var is a all caps string
if var == M.AVANTE_INTERNAL_KEY or type(var) == "table" or var:match("^cmd:(.*)") then return end
if type(var) == "table" or var:match("^cmd:(.*)") then return end
local refresh = opts.refresh or false
@@ -248,16 +248,21 @@ end
E.REQUEST_LOGIN_PATTERN = "AvanteRequestLogin"
---@param provider Provider
E.is_local = function(provider)
local cur = M.get_config(provider)
return cur["local"] ~= nil and cur["local"] or false
---@param provider AvanteDefaultBaseProvider
E.require_api_key = function(provider)
if provider["local"] ~= nil then
if provider["local"] then
vim.deprecate('"local" = true', "api_key_name = ''", "0.1.0", "avante.nvim")
else
vim.deprecate('"local" = false', "api_key_name", "0.1.0", "avante.nvim")
end
return not provider["local"]
end
return provider.api_key_name ~= nil and provider.api_key_name ~= ""
end
M.env = E
M.AVANTE_INTERNAL_KEY = "__avante_env_internal"
M = setmetatable(M, {
---@param t avante.Providers
---@param k Provider
@@ -272,7 +277,6 @@ M = setmetatable(M, {
local BaseOpts = M.get_config(Opts.__inherited_from)
local ok, module = pcall(require, "avante.providers." .. Opts.__inherited_from)
if not ok then error("Failed to load provider: " .. Opts.__inherited_from) end
Opts._shellenv = module.api_key_name ~= M.AVANTE_INTERNAL_KEY and module.api_key_name or nil
t[k] = vim.tbl_deep_extend("keep", Opts, BaseOpts, module)
else
t[k] = Opts
@@ -280,7 +284,6 @@ M = setmetatable(M, {
else
local ok, module = pcall(require, "avante.providers." .. k)
if not ok then error("Failed to load provider: " .. k) end
Opts._shellenv = module.api_key_name ~= M.AVANTE_INTERNAL_KEY and module.api_key_name or nil
t[k] = vim.tbl_deep_extend("keep", Opts, module)
end
@@ -294,8 +297,9 @@ M = setmetatable(M, {
if t[k].has == nil then t[k].has = function() return E.parse_envvar(t[k]) ~= nil end end
if t[k].setup == nil then
local base = M.parse_config(t[k])
t[k].setup = function()
if not E.is_local(k) then t[k].parse_api_key() end
if E.require_api_key(base) then t[k].parse_api_key() end
require("avante.tokenizers").setup(t[k].tokenizer_id)
end
end