fix: the last used model caused the loss of the model configured in the profile in the model selector (#2600)
This commit is contained in:
@@ -762,6 +762,12 @@ local function apply_model_selection(config, model_name, provider_name)
|
||||
if target_provider_name ~= current_provider_name or model_name ~= current_model_name then
|
||||
config.provider = target_provider_name
|
||||
target_provider.model = model_name
|
||||
if not target_provider.model_names then target_provider.model_names = {} end
|
||||
for _, model_name_ in ipairs({ model_name, current_model_name }) do
|
||||
if not vim.tbl_contains(target_provider.model_names, model_name_) then
|
||||
table.insert(target_provider.model_names, model_name_)
|
||||
end
|
||||
end
|
||||
Utils.info(string.format("Using previously selected model: %s/%s", target_provider_name, model_name))
|
||||
end
|
||||
end
|
||||
|
||||
@@ -6,64 +6,80 @@ local Selector = require("avante.ui.selector")
|
||||
---@class avante.ModelSelector
|
||||
local M = {}
|
||||
|
||||
M.models_list_invoked = {}
|
||||
M.models_list_returned = {}
|
||||
M.list_models_invoked = {}
|
||||
M.list_models_returned = {}
|
||||
|
||||
local models_list_cached_result = {}
|
||||
local list_models_cached_result = {}
|
||||
|
||||
---@param provider_name string
|
||||
---@param provider_cfg table
|
||||
---@return table
|
||||
local function create_model_entries(provider_name, provider_cfg)
|
||||
if provider_cfg.models_list and provider_cfg.__inherited_from == nil then
|
||||
local models_list
|
||||
if type(provider_cfg.models_list) == "function" then
|
||||
if M.models_list_invoked[provider_cfg.models_list] then return {} end
|
||||
M.models_list_invoked[provider_cfg.models_list] = true
|
||||
local cached_result = models_list_cached_result[provider_cfg.models_list]
|
||||
local res = {}
|
||||
if provider_cfg.list_models and provider_cfg.__inherited_from == nil then
|
||||
local models
|
||||
if type(provider_cfg.list_models) == "function" then
|
||||
if M.list_models_invoked[provider_cfg.list_models] then return {} end
|
||||
M.list_models_invoked[provider_cfg.list_models] = true
|
||||
local cached_result = list_models_cached_result[provider_cfg.list_models]
|
||||
if cached_result then
|
||||
models_list = cached_result
|
||||
models = cached_result
|
||||
else
|
||||
models_list = provider_cfg:models_list()
|
||||
models_list_cached_result[provider_cfg.models_list] = models_list
|
||||
models = provider_cfg:list_models()
|
||||
list_models_cached_result[provider_cfg.list_models] = models
|
||||
end
|
||||
else
|
||||
if M.models_list_returned[provider_cfg.models_list] then return {} end
|
||||
M.models_list_returned[provider_cfg.models_list] = true
|
||||
models_list = provider_cfg.models_list
|
||||
if M.list_models_returned[provider_cfg.list_models] then return {} end
|
||||
M.list_models_returned[provider_cfg.list_models] = true
|
||||
models = provider_cfg.list_models
|
||||
end
|
||||
if models then
|
||||
-- If list_models is defined, use it to create entries
|
||||
res = vim
|
||||
.iter(models)
|
||||
:map(
|
||||
function(model)
|
||||
return {
|
||||
name = model.name or model.id,
|
||||
display_name = model.display_name or model.name or model.id,
|
||||
provider_name = provider_name,
|
||||
model = model.id,
|
||||
}
|
||||
end
|
||||
)
|
||||
:totable()
|
||||
end
|
||||
if not models_list then return {} end
|
||||
-- If models_list is defined, use it to create entries
|
||||
local models = vim
|
||||
.iter(models_list)
|
||||
:map(
|
||||
function(model)
|
||||
return {
|
||||
name = model.name or model.id,
|
||||
display_name = model.display_name or model.name or model.id,
|
||||
provider_name = provider_name,
|
||||
model = model.id,
|
||||
}
|
||||
end
|
||||
)
|
||||
:totable()
|
||||
return models
|
||||
end
|
||||
return provider_cfg.model
|
||||
and {
|
||||
{
|
||||
name = provider_cfg.display_name or (provider_name .. "/" .. provider_cfg.model),
|
||||
display_name = provider_cfg.display_name or (provider_name .. "/" .. provider_cfg.model),
|
||||
if provider_cfg.model then
|
||||
local seen = vim.iter(res):find(function(item) return item.model == provider_cfg.model end)
|
||||
if not seen then
|
||||
table.insert(res, {
|
||||
name = provider_cfg.display_name or (provider_name .. "/" .. provider_cfg.model),
|
||||
display_name = provider_cfg.display_name or (provider_name .. "/" .. provider_cfg.model),
|
||||
provider_name = provider_name,
|
||||
model = provider_cfg.model,
|
||||
})
|
||||
end
|
||||
end
|
||||
if provider_cfg.model_names then
|
||||
for _, model_name in ipairs(provider_cfg.model_names) do
|
||||
local seen = vim.iter(res):find(function(item) return item.model == model_name end)
|
||||
if not seen then
|
||||
table.insert(res, {
|
||||
name = provider_cfg.display_name or (provider_name .. "/" .. model_name),
|
||||
display_name = provider_cfg.display_name or (provider_name .. "/" .. model_name),
|
||||
provider_name = provider_name,
|
||||
model = provider_cfg.model,
|
||||
},
|
||||
}
|
||||
or {}
|
||||
model = model_name,
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
return res
|
||||
end
|
||||
|
||||
function M.open()
|
||||
M.models_list_invoked = {}
|
||||
M.models_list_returned = {}
|
||||
M.list_models_invoked = {}
|
||||
M.list_models_returned = {}
|
||||
local models = {}
|
||||
|
||||
-- Collect models from providers
|
||||
|
||||
@@ -214,7 +214,7 @@ function M:is_disable_stream() return false end
|
||||
|
||||
setmetatable(M, { __index = OpenAI })
|
||||
|
||||
function M:models_list()
|
||||
function M:list_models()
|
||||
if M._model_list_cache then return M._model_list_cache end
|
||||
if not M._is_setup then M.setup() end
|
||||
-- refresh token synchronously, only if it has expired
|
||||
|
||||
@@ -233,7 +233,7 @@ M.on_error = function(result)
|
||||
end
|
||||
|
||||
-- List available models using Ollama's tags API
|
||||
function M:models_list()
|
||||
function M:list_models()
|
||||
-- Return cached models if available
|
||||
if self._model_list_cache then return self._model_list_cache end
|
||||
|
||||
|
||||
@@ -235,6 +235,7 @@ vim.g.avante_login = vim.g.avante_login
|
||||
---@field endpoint? string
|
||||
---@field extra_request_body? table<string, any>
|
||||
---@field model? string
|
||||
---@field model_names? string[]
|
||||
---@field local? boolean
|
||||
---@field proxy? string
|
||||
---@field keep_alive? string
|
||||
@@ -344,7 +345,7 @@ vim.g.avante_login = vim.g.avante_login
|
||||
---@field on_error? fun(result: table<string, any>): nil
|
||||
---@field transform_tool? fun(self: AvanteProviderFunctor, tool: AvanteLLMTool): AvanteOpenAITool | AvanteClaudeTool
|
||||
---@field get_rate_limit_sleep_time? fun(self: AvanteProviderFunctor, headers: table<string, string>): integer | nil
|
||||
---@field models_list? fun(self): AvanteProviderModelList | nil
|
||||
---@field list_models? fun(self): AvanteProviderModelList | nil
|
||||
---
|
||||
---@alias AvanteBedrockPayloadBuilder fun(self: AvanteBedrockModelHandler | AvanteBedrockProviderFunctor, prompt_opts: AvantePromptOptions, request_body: table<string, any>): table<string, any>
|
||||
---
|
||||
|
||||
Reference in New Issue
Block a user