From 054e84840ba4ac908f786a79668587fa614db279 Mon Sep 17 00:00:00 2001 From: Avinash Thakur <19588421+80avin@users.noreply.github.com> Date: Mon, 2 Jun 2025 14:17:03 +0530 Subject: [PATCH] feat: support multiple models in one provider (#2106) --- lua/avante/model_selector.lua | 47 ++++++++++++++++++++++------- lua/avante/providers/copilot.lua | 51 +++++++++++++++++++++++++++++++- lua/avante/types.lua | 1 + 3 files changed, 88 insertions(+), 11 deletions(-) diff --git a/lua/avante/model_selector.lua b/lua/avante/model_selector.lua index 3a1de4f..c49f0a8 100644 --- a/lua/avante/model_selector.lua +++ b/lua/avante/model_selector.lua @@ -1,4 +1,5 @@ local Utils = require("avante.utils") +local Providers = require("avante.providers") local Config = require("avante.config") local Selector = require("avante.ui.selector") @@ -7,14 +8,36 @@ local M = {} ---@param provider_name string ---@param cfg table ----@return table? -local function create_model_entry(provider_name, cfg) +---@return table +local function create_model_entries(provider_name, cfg) + if cfg.models_list then + local models_list = type(cfg.models_list) == "function" and cfg:models_list() or cfg.models_list + if not models_list then return {} end + -- If models_list is defined, use it to create entries + local models = vim + .iter(models_list) + :map( + function(model) + return { + name = model.name or model.id, + display_name = model.display_name or model.name or model.id, + provider_name = provider_name, + model = model.id, + } + end + ) + :totable() + return models + end return cfg.model - and { - name = cfg.display_name or (provider_name .. "/" .. cfg.model), - provider_name = provider_name, - model = cfg.model, - } + and { + { + name = cfg.display_name or (provider_name .. "/" .. cfg.model), + provider_name = provider_name, + model = cfg.model, + }, + } + or {} end function M.open() @@ -22,10 +45,14 @@ function M.open() -- Collect models from main providers and vendors for _, provider_name in ipairs(Config.provider_names) do - local cfg = Config.get_provider_config(provider_name) + local ok, cfg = pcall(function() return Providers[provider_name] end) + if not ok then + Utils.warn("Failed to load provider: " .. provider_name) + goto continue + end if cfg.hide_in_model_selector then goto continue end - local entry = create_model_entry(provider_name, cfg) - if entry then table.insert(models, entry) end + local entries = create_model_entries(provider_name, cfg) + models = vim.list_extend(models, entries) ::continue:: end diff --git a/lua/avante/providers/copilot.lua b/lua/avante/providers/copilot.lua index caf8f91..fa3ea2f 100644 --- a/lua/avante/providers/copilot.lua +++ b/lua/avante/providers/copilot.lua @@ -213,6 +213,55 @@ function M:is_disable_stream() return false end setmetatable(M, { __index = OpenAI }) +function M:models_list() + if M._model_list_cache then return M._model_list_cache end + local curl_opts = { + headers = { + ["Content-Type"] = "application/json", + ["Authorization"] = "Bearer " .. M.state.github_token.token, + ["Copilot-Integration-Id"] = "vscode-chat", + ["Editor-Version"] = ("Neovim/%s.%s.%s"):format(vim.version().major, vim.version().minor, vim.version().patch), + }, + timeout = Config.copilot.timeout, + proxy = Config.copilot.proxy, + insecure = Config.copilot.allow_insecure, + } + + local function handle_response(response) + if response.status == 200 then + local body = vim.json.decode(response.body) + -- ref: https://github.com/CopilotC-Nvim/CopilotChat.nvim/blob/16d897fd43d07e3b54478ccdb2f8a16e4df4f45a/lua/CopilotChat/config/providers.lua#L171-L187 + local models = vim + .iter(body.data) + :filter(function(model) return model.capabilities.type == "chat" and not vim.endswith(model.id, "paygo") end) + :map( + function(model) + return { + id = model.id, + display_name = model.name, + name = "copilot/" .. model.name .. " (" .. model.id .. ")", + provider_name = "copilot", + tokenizer = model.capabilities.tokenizer, + max_input_tokens = model.capabilities.limits.max_prompt_tokens, + max_output_tokens = model.capabilities.limits.max_output_tokens, + policy = not model["policy"] or model["policy"]["state"] == "enabled", + version = model.version, + } + end + ) + :totable() + M._model_list_cache = models + return models + else + error("Failed to get success response: " .. vim.inspect(response)) + return {} + end + end + + local response = curl.get((M.state.github_token.endpoints.api or "") .. "/models", curl_opts) + return handle_response(response) +end + function M:parse_curl_args(prompt_opts) -- refresh token synchronously, only if it has expired -- (this should rarely happen, as we refresh the token in the background) @@ -229,7 +278,7 @@ function M:parse_curl_args(prompt_opts) end return { - url = H.chat_completion_url(provider_conf.endpoint), + url = H.chat_completion_url(M.state.github_token.endpoints.api or provider_conf.endpoint), timeout = provider_conf.timeout, proxy = provider_conf.proxy, insecure = provider_conf.allow_insecure, diff --git a/lua/avante/types.lua b/lua/avante/types.lua index c43a350..1c17592 100644 --- a/lua/avante/types.lua +++ b/lua/avante/types.lua @@ -287,6 +287,7 @@ vim.g.avante_login = vim.g.avante_login ---@field parse_api_key? fun(): string | nil --- ---@class AvanteProviderFunctor +---@field _model_list_cache table ---@field support_prompt_caching boolean | nil ---@field role_map table<"user" | "assistant", string> ---@field parse_messages AvanteMessagesParser