diff --git a/lua/avante/config.lua b/lua/avante/config.lua index 87bec79..efadd4a 100644 --- a/lua/avante/config.lua +++ b/lua/avante/config.lua @@ -762,6 +762,12 @@ local function apply_model_selection(config, model_name, provider_name) if target_provider_name ~= current_provider_name or model_name ~= current_model_name then config.provider = target_provider_name target_provider.model = model_name + if not target_provider.model_names then target_provider.model_names = {} end + for _, model_name_ in ipairs({ model_name, current_model_name }) do + if not vim.tbl_contains(target_provider.model_names, model_name_) then + table.insert(target_provider.model_names, model_name_) + end + end Utils.info(string.format("Using previously selected model: %s/%s", target_provider_name, model_name)) end end diff --git a/lua/avante/model_selector.lua b/lua/avante/model_selector.lua index 85ac10a..87566cc 100644 --- a/lua/avante/model_selector.lua +++ b/lua/avante/model_selector.lua @@ -6,64 +6,80 @@ local Selector = require("avante.ui.selector") ---@class avante.ModelSelector local M = {} -M.models_list_invoked = {} -M.models_list_returned = {} +M.list_models_invoked = {} +M.list_models_returned = {} -local models_list_cached_result = {} +local list_models_cached_result = {} ---@param provider_name string ---@param provider_cfg table ---@return table local function create_model_entries(provider_name, provider_cfg) - if provider_cfg.models_list and provider_cfg.__inherited_from == nil then - local models_list - if type(provider_cfg.models_list) == "function" then - if M.models_list_invoked[provider_cfg.models_list] then return {} end - M.models_list_invoked[provider_cfg.models_list] = true - local cached_result = models_list_cached_result[provider_cfg.models_list] + local res = {} + if provider_cfg.list_models and provider_cfg.__inherited_from == nil then + local models + if type(provider_cfg.list_models) == "function" then + if M.list_models_invoked[provider_cfg.list_models] then return {} end + M.list_models_invoked[provider_cfg.list_models] = true + local cached_result = list_models_cached_result[provider_cfg.list_models] if cached_result then - models_list = cached_result + models = cached_result else - models_list = provider_cfg:models_list() - models_list_cached_result[provider_cfg.models_list] = models_list + models = provider_cfg:list_models() + list_models_cached_result[provider_cfg.list_models] = models end else - if M.models_list_returned[provider_cfg.models_list] then return {} end - M.models_list_returned[provider_cfg.models_list] = true - models_list = provider_cfg.models_list + if M.list_models_returned[provider_cfg.list_models] then return {} end + M.list_models_returned[provider_cfg.list_models] = true + models = provider_cfg.list_models + end + if models then + -- If list_models is defined, use it to create entries + res = vim + .iter(models) + :map( + function(model) + return { + name = model.name or model.id, + display_name = model.display_name or model.name or model.id, + provider_name = provider_name, + model = model.id, + } + end + ) + :totable() end - if not models_list then return {} end - -- If models_list is defined, use it to create entries - local models = vim - .iter(models_list) - :map( - function(model) - return { - name = model.name or model.id, - display_name = model.display_name or model.name or model.id, - provider_name = provider_name, - model = model.id, - } - end - ) - :totable() - return models end - return provider_cfg.model - and { - { - name = provider_cfg.display_name or (provider_name .. "/" .. provider_cfg.model), - display_name = provider_cfg.display_name or (provider_name .. "/" .. provider_cfg.model), + if provider_cfg.model then + local seen = vim.iter(res):find(function(item) return item.model == provider_cfg.model end) + if not seen then + table.insert(res, { + name = provider_cfg.display_name or (provider_name .. "/" .. provider_cfg.model), + display_name = provider_cfg.display_name or (provider_name .. "/" .. provider_cfg.model), + provider_name = provider_name, + model = provider_cfg.model, + }) + end + end + if provider_cfg.model_names then + for _, model_name in ipairs(provider_cfg.model_names) do + local seen = vim.iter(res):find(function(item) return item.model == model_name end) + if not seen then + table.insert(res, { + name = provider_cfg.display_name or (provider_name .. "/" .. model_name), + display_name = provider_cfg.display_name or (provider_name .. "/" .. model_name), provider_name = provider_name, - model = provider_cfg.model, - }, - } - or {} + model = model_name, + }) + end + end + end + return res end function M.open() - M.models_list_invoked = {} - M.models_list_returned = {} + M.list_models_invoked = {} + M.list_models_returned = {} local models = {} -- Collect models from providers diff --git a/lua/avante/providers/copilot.lua b/lua/avante/providers/copilot.lua index a5774f9..caf50cc 100644 --- a/lua/avante/providers/copilot.lua +++ b/lua/avante/providers/copilot.lua @@ -214,7 +214,7 @@ function M:is_disable_stream() return false end setmetatable(M, { __index = OpenAI }) -function M:models_list() +function M:list_models() if M._model_list_cache then return M._model_list_cache end if not M._is_setup then M.setup() end -- refresh token synchronously, only if it has expired diff --git a/lua/avante/providers/ollama.lua b/lua/avante/providers/ollama.lua index 650bc4b..0c236da 100644 --- a/lua/avante/providers/ollama.lua +++ b/lua/avante/providers/ollama.lua @@ -233,7 +233,7 @@ M.on_error = function(result) end -- List available models using Ollama's tags API -function M:models_list() +function M:list_models() -- Return cached models if available if self._model_list_cache then return self._model_list_cache end diff --git a/lua/avante/types.lua b/lua/avante/types.lua index 38ca135..236adfe 100644 --- a/lua/avante/types.lua +++ b/lua/avante/types.lua @@ -235,6 +235,7 @@ vim.g.avante_login = vim.g.avante_login ---@field endpoint? string ---@field extra_request_body? table ---@field model? string +---@field model_names? string[] ---@field local? boolean ---@field proxy? string ---@field keep_alive? string @@ -344,7 +345,7 @@ vim.g.avante_login = vim.g.avante_login ---@field on_error? fun(result: table): nil ---@field transform_tool? fun(self: AvanteProviderFunctor, tool: AvanteLLMTool): AvanteOpenAITool | AvanteClaudeTool ---@field get_rate_limit_sleep_time? fun(self: AvanteProviderFunctor, headers: table): integer | nil ----@field models_list? fun(self): AvanteProviderModelList | nil +---@field list_models? fun(self): AvanteProviderModelList | nil --- ---@alias AvanteBedrockPayloadBuilder fun(self: AvanteBedrockModelHandler | AvanteBedrockProviderFunctor, prompt_opts: AvantePromptOptions, request_body: table): table ---