refactor: providers config (#2117)

This commit is contained in:
yetone
2025-06-03 04:01:01 +08:00
committed by GitHub
parent b89e6d84a0
commit e9ab2ca2fd
12 changed files with 327 additions and 251 deletions

View File

@@ -1,10 +1,13 @@
---@class AvanteAzureProvider: AvanteDefaultBaseProvider
---@field deployment string
---@field api_version string
---@class AvanteAzureExtraRequestBody
---@field temperature number
---@field max_completion_tokens number
---@field reasoning_effort? string
---@class AvanteAzureProvider: AvanteDefaultBaseProvider
---@field deployment string
---@field api_version string
---@field extra_request_body AvanteAzureExtraRequestBody
local Utils = require("avante.utils")
local P = require("avante.providers")
local O = require("avante.providers").openai

View File

@@ -27,7 +27,6 @@
local curl = require("plenary.curl")
local Config = require("avante.config")
local Path = require("plenary.path")
local Utils = require("avante.utils")
local Providers = require("avante.providers")
@@ -156,14 +155,16 @@ function H.refresh_token(async, force)
return false
end
local provider_conf = Providers.get_config("copilot")
local curl_opts = {
headers = {
["Authorization"] = "token " .. M.state.oauth_token,
["Accept"] = "application/json",
},
timeout = Config.copilot.timeout,
proxy = Config.copilot.proxy,
insecure = Config.copilot.allow_insecure,
timeout = provider_conf.timeout,
proxy = provider_conf.proxy,
insecure = provider_conf.allow_insecure,
}
local function handle_response(response)
@@ -219,6 +220,7 @@ function M:models_list()
-- refresh token synchronously, only if it has expired
-- (this should rarely happen, as we refresh the token in the background)
H.refresh_token(false, false)
local provider_conf = Providers.parse_config(self)
local curl_opts = {
headers = {
["Content-Type"] = "application/json",
@@ -226,9 +228,9 @@ function M:models_list()
["Copilot-Integration-Id"] = "vscode-chat",
["Editor-Version"] = ("Neovim/%s.%s.%s"):format(vim.version().major, vim.version().minor, vim.version().patch),
},
timeout = Config.copilot.timeout,
proxy = Config.copilot.proxy,
insecure = Config.copilot.allow_insecure,
timeout = provider_conf.timeout,
proxy = provider_conf.proxy,
insecure = provider_conf.allow_insecure,
}
local function handle_response(response)

View File

@@ -157,32 +157,22 @@ M = setmetatable(M, {
---@param t avante.Providers
---@param k avante.ProviderName
__index = function(t, k)
if Config.providers[k] == nil then error("Failed to find provider: " .. k, 2) end
local provider_config = M.get_config(k)
if Config.vendors[k] ~= nil and k == "ollama" then
Utils.warn(
"ollama is now a first-class provider in avante.nvim, please stop using vendors to define ollama, for migration guide please refer to: https://github.com/yetone/avante.nvim/wiki/Custom-providers#ollama"
)
end
---@diagnostic disable: undefined-field,no-unknown,inject-field
if Config.vendors[k] ~= nil and k ~= "ollama" then
if provider_config.parse_response_data ~= nil then
Utils.error("parse_response_data is not supported for avante.nvim vendors")
end
if provider_config.__inherited_from ~= nil then
local base_provider_config = M.get_config(provider_config.__inherited_from)
local ok, module = pcall(require, "avante.providers." .. provider_config.__inherited_from)
if not ok then error("Failed to load provider: " .. provider_config.__inherited_from) end
t[k] = Utils.deep_extend_with_metatable("keep", provider_config, base_provider_config, module)
else
t[k] = provider_config
end
if provider_config.__inherited_from ~= nil then
local base_provider_config = M.get_config(provider_config.__inherited_from)
local ok, module = pcall(require, "avante.providers." .. provider_config.__inherited_from)
if not ok then error("Failed to load provider: " .. provider_config.__inherited_from, 2) end
provider_config = Utils.deep_extend_with_metatable("force", module, base_provider_config, provider_config)
else
local ok, module = pcall(require, "avante.providers." .. k)
if not ok then error("Failed to load provider: " .. k) end
t[k] = Utils.deep_extend_with_metatable("keep", provider_config, module)
if ok then provider_config = Utils.deep_extend_with_metatable("force", module, provider_config) end
end
t[k] = provider_config
if t[k].parse_api_key == nil then t[k].parse_api_key = function() return E.parse_envvar(t[k]) end end
-- default to gpt-4o as tokenizer
@@ -241,29 +231,17 @@ end
function M.parse_config(opts)
---@type AvanteDefaultBaseProvider
local provider_opts = {}
---@type table<string, any>
local request_body = {}
for key, value in pairs(opts) do
if vim.tbl_contains(Config.BASE_PROVIDER_KEYS, key) then
provider_opts[key] = value
else
request_body[key] = value
end
if key ~= "extra_request_body" then provider_opts[key] = value end
end
request_body = vim
.iter(request_body)
:filter(function(_, v) return type(v) ~= "function" and type(v) ~= "userdata" end)
:fold({}, function(acc, k, v)
acc[k] = v
return acc
end)
---@type table<string, any>
local request_body = opts.extra_request_body or {}
return provider_opts, request_body
end
---@private
---@param provider_name avante.ProviderName
function M.get_config(provider_name)
provider_name = provider_name or Config.provider

View File

@@ -48,7 +48,7 @@ function M:parse_curl_args(prompt_opts)
request_body = vim.tbl_deep_extend("force", request_body, {
anthropic_version = "vertex-2023-10-16",
temperature = 0,
temperature = 0.75,
max_tokens = 4096,
stream = true,
messages = messages,