feat: support multiple models in one provider (#2106)
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
local Utils = require("avante.utils")
|
local Utils = require("avante.utils")
|
||||||
|
local Providers = require("avante.providers")
|
||||||
local Config = require("avante.config")
|
local Config = require("avante.config")
|
||||||
local Selector = require("avante.ui.selector")
|
local Selector = require("avante.ui.selector")
|
||||||
|
|
||||||
@@ -7,14 +8,36 @@ local M = {}
|
|||||||
|
|
||||||
---@param provider_name string
|
---@param provider_name string
|
||||||
---@param cfg table
|
---@param cfg table
|
||||||
---@return table?
|
---@return table
|
||||||
local function create_model_entry(provider_name, cfg)
|
local function create_model_entries(provider_name, cfg)
|
||||||
|
if cfg.models_list then
|
||||||
|
local models_list = type(cfg.models_list) == "function" and cfg:models_list() or cfg.models_list
|
||||||
|
if not models_list then return {} end
|
||||||
|
-- If models_list is defined, use it to create entries
|
||||||
|
local models = vim
|
||||||
|
.iter(models_list)
|
||||||
|
:map(
|
||||||
|
function(model)
|
||||||
|
return {
|
||||||
|
name = model.name or model.id,
|
||||||
|
display_name = model.display_name or model.name or model.id,
|
||||||
|
provider_name = provider_name,
|
||||||
|
model = model.id,
|
||||||
|
}
|
||||||
|
end
|
||||||
|
)
|
||||||
|
:totable()
|
||||||
|
return models
|
||||||
|
end
|
||||||
return cfg.model
|
return cfg.model
|
||||||
and {
|
and {
|
||||||
name = cfg.display_name or (provider_name .. "/" .. cfg.model),
|
{
|
||||||
provider_name = provider_name,
|
name = cfg.display_name or (provider_name .. "/" .. cfg.model),
|
||||||
model = cfg.model,
|
provider_name = provider_name,
|
||||||
}
|
model = cfg.model,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
or {}
|
||||||
end
|
end
|
||||||
|
|
||||||
function M.open()
|
function M.open()
|
||||||
@@ -22,10 +45,14 @@ function M.open()
|
|||||||
|
|
||||||
-- Collect models from main providers and vendors
|
-- Collect models from main providers and vendors
|
||||||
for _, provider_name in ipairs(Config.provider_names) do
|
for _, provider_name in ipairs(Config.provider_names) do
|
||||||
local cfg = Config.get_provider_config(provider_name)
|
local ok, cfg = pcall(function() return Providers[provider_name] end)
|
||||||
|
if not ok then
|
||||||
|
Utils.warn("Failed to load provider: " .. provider_name)
|
||||||
|
goto continue
|
||||||
|
end
|
||||||
if cfg.hide_in_model_selector then goto continue end
|
if cfg.hide_in_model_selector then goto continue end
|
||||||
local entry = create_model_entry(provider_name, cfg)
|
local entries = create_model_entries(provider_name, cfg)
|
||||||
if entry then table.insert(models, entry) end
|
models = vim.list_extend(models, entries)
|
||||||
::continue::
|
::continue::
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|||||||
@@ -213,6 +213,55 @@ function M:is_disable_stream() return false end
|
|||||||
|
|
||||||
setmetatable(M, { __index = OpenAI })
|
setmetatable(M, { __index = OpenAI })
|
||||||
|
|
||||||
|
function M:models_list()
|
||||||
|
if M._model_list_cache then return M._model_list_cache end
|
||||||
|
local curl_opts = {
|
||||||
|
headers = {
|
||||||
|
["Content-Type"] = "application/json",
|
||||||
|
["Authorization"] = "Bearer " .. M.state.github_token.token,
|
||||||
|
["Copilot-Integration-Id"] = "vscode-chat",
|
||||||
|
["Editor-Version"] = ("Neovim/%s.%s.%s"):format(vim.version().major, vim.version().minor, vim.version().patch),
|
||||||
|
},
|
||||||
|
timeout = Config.copilot.timeout,
|
||||||
|
proxy = Config.copilot.proxy,
|
||||||
|
insecure = Config.copilot.allow_insecure,
|
||||||
|
}
|
||||||
|
|
||||||
|
local function handle_response(response)
|
||||||
|
if response.status == 200 then
|
||||||
|
local body = vim.json.decode(response.body)
|
||||||
|
-- ref: https://github.com/CopilotC-Nvim/CopilotChat.nvim/blob/16d897fd43d07e3b54478ccdb2f8a16e4df4f45a/lua/CopilotChat/config/providers.lua#L171-L187
|
||||||
|
local models = vim
|
||||||
|
.iter(body.data)
|
||||||
|
:filter(function(model) return model.capabilities.type == "chat" and not vim.endswith(model.id, "paygo") end)
|
||||||
|
:map(
|
||||||
|
function(model)
|
||||||
|
return {
|
||||||
|
id = model.id,
|
||||||
|
display_name = model.name,
|
||||||
|
name = "copilot/" .. model.name .. " (" .. model.id .. ")",
|
||||||
|
provider_name = "copilot",
|
||||||
|
tokenizer = model.capabilities.tokenizer,
|
||||||
|
max_input_tokens = model.capabilities.limits.max_prompt_tokens,
|
||||||
|
max_output_tokens = model.capabilities.limits.max_output_tokens,
|
||||||
|
policy = not model["policy"] or model["policy"]["state"] == "enabled",
|
||||||
|
version = model.version,
|
||||||
|
}
|
||||||
|
end
|
||||||
|
)
|
||||||
|
:totable()
|
||||||
|
M._model_list_cache = models
|
||||||
|
return models
|
||||||
|
else
|
||||||
|
error("Failed to get success response: " .. vim.inspect(response))
|
||||||
|
return {}
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
local response = curl.get((M.state.github_token.endpoints.api or "") .. "/models", curl_opts)
|
||||||
|
return handle_response(response)
|
||||||
|
end
|
||||||
|
|
||||||
function M:parse_curl_args(prompt_opts)
|
function M:parse_curl_args(prompt_opts)
|
||||||
-- refresh token synchronously, only if it has expired
|
-- refresh token synchronously, only if it has expired
|
||||||
-- (this should rarely happen, as we refresh the token in the background)
|
-- (this should rarely happen, as we refresh the token in the background)
|
||||||
@@ -229,7 +278,7 @@ function M:parse_curl_args(prompt_opts)
|
|||||||
end
|
end
|
||||||
|
|
||||||
return {
|
return {
|
||||||
url = H.chat_completion_url(provider_conf.endpoint),
|
url = H.chat_completion_url(M.state.github_token.endpoints.api or provider_conf.endpoint),
|
||||||
timeout = provider_conf.timeout,
|
timeout = provider_conf.timeout,
|
||||||
proxy = provider_conf.proxy,
|
proxy = provider_conf.proxy,
|
||||||
insecure = provider_conf.allow_insecure,
|
insecure = provider_conf.allow_insecure,
|
||||||
|
|||||||
@@ -287,6 +287,7 @@ vim.g.avante_login = vim.g.avante_login
|
|||||||
---@field parse_api_key? fun(): string | nil
|
---@field parse_api_key? fun(): string | nil
|
||||||
---
|
---
|
||||||
---@class AvanteProviderFunctor
|
---@class AvanteProviderFunctor
|
||||||
|
---@field _model_list_cache table
|
||||||
---@field support_prompt_caching boolean | nil
|
---@field support_prompt_caching boolean | nil
|
||||||
---@field role_map table<"user" | "assistant", string>
|
---@field role_map table<"user" | "assistant", string>
|
||||||
---@field parse_messages AvanteMessagesParser
|
---@field parse_messages AvanteMessagesParser
|
||||||
|
|||||||
Reference in New Issue
Block a user