diff --git a/lua/avante/model_selector.lua b/lua/avante/model_selector.lua index d6d4392..dc2f1e6 100644 --- a/lua/avante/model_selector.lua +++ b/lua/avante/model_selector.lua @@ -12,25 +12,25 @@ M.models_list_returned = {} local models_list_cached_result = {} ---@param provider_name string ----@param cfg table +---@param provider_cfg table ---@return table -local function create_model_entries(provider_name, cfg) - if cfg.models_list then +local function create_model_entries(provider_name, provider_cfg) + if provider_cfg.models_list then local models_list - if type(cfg.models_list) == "function" then - if M.models_list_invoked[cfg.models_list] then return {} end - M.models_list_invoked[cfg.models_list] = true - local cached_result = models_list_cached_result[cfg.models_list] + if type(provider_cfg.models_list) == "function" then + if M.models_list_invoked[provider_cfg.models_list] then return {} end + M.models_list_invoked[provider_cfg.models_list] = true + local cached_result = models_list_cached_result[provider_cfg.models_list] if cached_result then models_list = cached_result else - models_list = cfg.models_list() - models_list_cached_result[cfg.models_list] = models_list + models_list = provider_cfg.models_list() + models_list_cached_result[provider_cfg.models_list] = models_list end else - if M.models_list_returned[cfg.models_list] then return {} end - M.models_list_returned[cfg.models_list] = true - models_list = cfg.models_list + if M.models_list_returned[provider_cfg.models_list] then return {} end + M.models_list_returned[provider_cfg.models_list] = true + models_list = provider_cfg.models_list end if not models_list then return {} end -- If models_list is defined, use it to create entries @@ -49,12 +49,12 @@ local function create_model_entries(provider_name, cfg) :totable() return models end - return cfg.model + return provider_cfg.model and { { - name = cfg.display_name or (provider_name .. "/" .. cfg.model), + name = provider_cfg.display_name or (provider_name .. "/" .. provider_cfg.model), provider_name = provider_name, - model = cfg.model, + model = provider_cfg.model, }, } or {} @@ -67,13 +67,14 @@ function M.open() -- Collect models from main providers and vendors for _, provider_name in ipairs(Config.provider_names) do - local ok, cfg = pcall(function() return Providers[provider_name] end) + local ok, provider_cfg = pcall(function() return Providers[provider_name] end) if not ok then Utils.warn("Failed to load provider: " .. provider_name) goto continue end - if cfg.hide_in_model_selector then goto continue end - local entries = create_model_entries(provider_name, cfg) + if provider_cfg.hide_in_model_selector then goto continue end + if not provider_cfg.is_env_set() then goto continue end + local entries = create_model_entries(provider_name, provider_cfg) models = vim.list_extend(models, entries) ::continue:: end diff --git a/lua/avante/providers/copilot.lua b/lua/avante/providers/copilot.lua index ad258b4..00e48a4 100644 --- a/lua/avante/providers/copilot.lua +++ b/lua/avante/providers/copilot.lua @@ -30,7 +30,7 @@ local curl = require("plenary.curl") local Config = require("avante.config") local Path = require("plenary.path") local Utils = require("avante.utils") -local P = require("avante.providers") +local Providers = require("avante.providers") local OpenAI = require("avante.providers").openai local H = {} @@ -271,7 +271,7 @@ function M:parse_curl_args(prompt_opts) -- (this should rarely happen, as we refresh the token in the background) H.refresh_token(false, false) - local provider_conf, request_body = P.parse_config(self) + local provider_conf, request_body = Providers.parse_config(self) local disable_tools = provider_conf.disable_tools or false local tools = {} @@ -347,6 +347,11 @@ end M._is_setup = false +function M.is_env_set() + local ok = pcall(function() H.get_oauth_token() end) + return ok +end + function M.setup() local copilot_token_file = Path:new(copilot_path)