fix: the last used model caused the loss of the model configured in the profile in the model selector (#2600)
This commit is contained in:
@@ -6,64 +6,80 @@ local Selector = require("avante.ui.selector")
|
||||
---@class avante.ModelSelector
|
||||
local M = {}
|
||||
|
||||
M.models_list_invoked = {}
|
||||
M.models_list_returned = {}
|
||||
M.list_models_invoked = {}
|
||||
M.list_models_returned = {}
|
||||
|
||||
local models_list_cached_result = {}
|
||||
local list_models_cached_result = {}
|
||||
|
||||
---@param provider_name string
|
||||
---@param provider_cfg table
|
||||
---@return table
|
||||
local function create_model_entries(provider_name, provider_cfg)
|
||||
if provider_cfg.models_list and provider_cfg.__inherited_from == nil then
|
||||
local models_list
|
||||
if type(provider_cfg.models_list) == "function" then
|
||||
if M.models_list_invoked[provider_cfg.models_list] then return {} end
|
||||
M.models_list_invoked[provider_cfg.models_list] = true
|
||||
local cached_result = models_list_cached_result[provider_cfg.models_list]
|
||||
local res = {}
|
||||
if provider_cfg.list_models and provider_cfg.__inherited_from == nil then
|
||||
local models
|
||||
if type(provider_cfg.list_models) == "function" then
|
||||
if M.list_models_invoked[provider_cfg.list_models] then return {} end
|
||||
M.list_models_invoked[provider_cfg.list_models] = true
|
||||
local cached_result = list_models_cached_result[provider_cfg.list_models]
|
||||
if cached_result then
|
||||
models_list = cached_result
|
||||
models = cached_result
|
||||
else
|
||||
models_list = provider_cfg:models_list()
|
||||
models_list_cached_result[provider_cfg.models_list] = models_list
|
||||
models = provider_cfg:list_models()
|
||||
list_models_cached_result[provider_cfg.list_models] = models
|
||||
end
|
||||
else
|
||||
if M.models_list_returned[provider_cfg.models_list] then return {} end
|
||||
M.models_list_returned[provider_cfg.models_list] = true
|
||||
models_list = provider_cfg.models_list
|
||||
if M.list_models_returned[provider_cfg.list_models] then return {} end
|
||||
M.list_models_returned[provider_cfg.list_models] = true
|
||||
models = provider_cfg.list_models
|
||||
end
|
||||
if models then
|
||||
-- If list_models is defined, use it to create entries
|
||||
res = vim
|
||||
.iter(models)
|
||||
:map(
|
||||
function(model)
|
||||
return {
|
||||
name = model.name or model.id,
|
||||
display_name = model.display_name or model.name or model.id,
|
||||
provider_name = provider_name,
|
||||
model = model.id,
|
||||
}
|
||||
end
|
||||
)
|
||||
:totable()
|
||||
end
|
||||
if not models_list then return {} end
|
||||
-- If models_list is defined, use it to create entries
|
||||
local models = vim
|
||||
.iter(models_list)
|
||||
:map(
|
||||
function(model)
|
||||
return {
|
||||
name = model.name or model.id,
|
||||
display_name = model.display_name or model.name or model.id,
|
||||
provider_name = provider_name,
|
||||
model = model.id,
|
||||
}
|
||||
end
|
||||
)
|
||||
:totable()
|
||||
return models
|
||||
end
|
||||
return provider_cfg.model
|
||||
and {
|
||||
{
|
||||
name = provider_cfg.display_name or (provider_name .. "/" .. provider_cfg.model),
|
||||
display_name = provider_cfg.display_name or (provider_name .. "/" .. provider_cfg.model),
|
||||
if provider_cfg.model then
|
||||
local seen = vim.iter(res):find(function(item) return item.model == provider_cfg.model end)
|
||||
if not seen then
|
||||
table.insert(res, {
|
||||
name = provider_cfg.display_name or (provider_name .. "/" .. provider_cfg.model),
|
||||
display_name = provider_cfg.display_name or (provider_name .. "/" .. provider_cfg.model),
|
||||
provider_name = provider_name,
|
||||
model = provider_cfg.model,
|
||||
})
|
||||
end
|
||||
end
|
||||
if provider_cfg.model_names then
|
||||
for _, model_name in ipairs(provider_cfg.model_names) do
|
||||
local seen = vim.iter(res):find(function(item) return item.model == model_name end)
|
||||
if not seen then
|
||||
table.insert(res, {
|
||||
name = provider_cfg.display_name or (provider_name .. "/" .. model_name),
|
||||
display_name = provider_cfg.display_name or (provider_name .. "/" .. model_name),
|
||||
provider_name = provider_name,
|
||||
model = provider_cfg.model,
|
||||
},
|
||||
}
|
||||
or {}
|
||||
model = model_name,
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
return res
|
||||
end
|
||||
|
||||
function M.open()
|
||||
M.models_list_invoked = {}
|
||||
M.models_list_returned = {}
|
||||
M.list_models_invoked = {}
|
||||
M.list_models_returned = {}
|
||||
local models = {}
|
||||
|
||||
-- Collect models from providers
|
||||
|
||||
Reference in New Issue
Block a user