fix: list copilot models (#2115)
This commit is contained in:
@@ -12,25 +12,25 @@ M.models_list_returned = {}
|
|||||||
local models_list_cached_result = {}
|
local models_list_cached_result = {}
|
||||||
|
|
||||||
---@param provider_name string
|
---@param provider_name string
|
||||||
---@param cfg table
|
---@param provider_cfg table
|
||||||
---@return table
|
---@return table
|
||||||
local function create_model_entries(provider_name, cfg)
|
local function create_model_entries(provider_name, provider_cfg)
|
||||||
if cfg.models_list then
|
if provider_cfg.models_list then
|
||||||
local models_list
|
local models_list
|
||||||
if type(cfg.models_list) == "function" then
|
if type(provider_cfg.models_list) == "function" then
|
||||||
if M.models_list_invoked[cfg.models_list] then return {} end
|
if M.models_list_invoked[provider_cfg.models_list] then return {} end
|
||||||
M.models_list_invoked[cfg.models_list] = true
|
M.models_list_invoked[provider_cfg.models_list] = true
|
||||||
local cached_result = models_list_cached_result[cfg.models_list]
|
local cached_result = models_list_cached_result[provider_cfg.models_list]
|
||||||
if cached_result then
|
if cached_result then
|
||||||
models_list = cached_result
|
models_list = cached_result
|
||||||
else
|
else
|
||||||
models_list = cfg.models_list()
|
models_list = provider_cfg.models_list()
|
||||||
models_list_cached_result[cfg.models_list] = models_list
|
models_list_cached_result[provider_cfg.models_list] = models_list
|
||||||
end
|
end
|
||||||
else
|
else
|
||||||
if M.models_list_returned[cfg.models_list] then return {} end
|
if M.models_list_returned[provider_cfg.models_list] then return {} end
|
||||||
M.models_list_returned[cfg.models_list] = true
|
M.models_list_returned[provider_cfg.models_list] = true
|
||||||
models_list = cfg.models_list
|
models_list = provider_cfg.models_list
|
||||||
end
|
end
|
||||||
if not models_list then return {} end
|
if not models_list then return {} end
|
||||||
-- If models_list is defined, use it to create entries
|
-- If models_list is defined, use it to create entries
|
||||||
@@ -49,12 +49,12 @@ local function create_model_entries(provider_name, cfg)
|
|||||||
:totable()
|
:totable()
|
||||||
return models
|
return models
|
||||||
end
|
end
|
||||||
return cfg.model
|
return provider_cfg.model
|
||||||
and {
|
and {
|
||||||
{
|
{
|
||||||
name = cfg.display_name or (provider_name .. "/" .. cfg.model),
|
name = provider_cfg.display_name or (provider_name .. "/" .. provider_cfg.model),
|
||||||
provider_name = provider_name,
|
provider_name = provider_name,
|
||||||
model = cfg.model,
|
model = provider_cfg.model,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
or {}
|
or {}
|
||||||
@@ -67,13 +67,14 @@ function M.open()
|
|||||||
|
|
||||||
-- Collect models from main providers and vendors
|
-- Collect models from main providers and vendors
|
||||||
for _, provider_name in ipairs(Config.provider_names) do
|
for _, provider_name in ipairs(Config.provider_names) do
|
||||||
local ok, cfg = pcall(function() return Providers[provider_name] end)
|
local ok, provider_cfg = pcall(function() return Providers[provider_name] end)
|
||||||
if not ok then
|
if not ok then
|
||||||
Utils.warn("Failed to load provider: " .. provider_name)
|
Utils.warn("Failed to load provider: " .. provider_name)
|
||||||
goto continue
|
goto continue
|
||||||
end
|
end
|
||||||
if cfg.hide_in_model_selector then goto continue end
|
if provider_cfg.hide_in_model_selector then goto continue end
|
||||||
local entries = create_model_entries(provider_name, cfg)
|
if not provider_cfg.is_env_set() then goto continue end
|
||||||
|
local entries = create_model_entries(provider_name, provider_cfg)
|
||||||
models = vim.list_extend(models, entries)
|
models = vim.list_extend(models, entries)
|
||||||
::continue::
|
::continue::
|
||||||
end
|
end
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ local curl = require("plenary.curl")
|
|||||||
local Config = require("avante.config")
|
local Config = require("avante.config")
|
||||||
local Path = require("plenary.path")
|
local Path = require("plenary.path")
|
||||||
local Utils = require("avante.utils")
|
local Utils = require("avante.utils")
|
||||||
local P = require("avante.providers")
|
local Providers = require("avante.providers")
|
||||||
local OpenAI = require("avante.providers").openai
|
local OpenAI = require("avante.providers").openai
|
||||||
|
|
||||||
local H = {}
|
local H = {}
|
||||||
@@ -271,7 +271,7 @@ function M:parse_curl_args(prompt_opts)
|
|||||||
-- (this should rarely happen, as we refresh the token in the background)
|
-- (this should rarely happen, as we refresh the token in the background)
|
||||||
H.refresh_token(false, false)
|
H.refresh_token(false, false)
|
||||||
|
|
||||||
local provider_conf, request_body = P.parse_config(self)
|
local provider_conf, request_body = Providers.parse_config(self)
|
||||||
local disable_tools = provider_conf.disable_tools or false
|
local disable_tools = provider_conf.disable_tools or false
|
||||||
|
|
||||||
local tools = {}
|
local tools = {}
|
||||||
@@ -347,6 +347,11 @@ end
|
|||||||
|
|
||||||
M._is_setup = false
|
M._is_setup = false
|
||||||
|
|
||||||
|
function M.is_env_set()
|
||||||
|
local ok = pcall(function() H.get_oauth_token() end)
|
||||||
|
return ok
|
||||||
|
end
|
||||||
|
|
||||||
function M.setup()
|
function M.setup()
|
||||||
local copilot_token_file = Path:new(copilot_path)
|
local copilot_token_file = Path:new(copilot_path)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user