refactor: providers config (#2117)
This commit is contained in:
@@ -1,10 +1,13 @@
|
||||
---@class AvanteAzureProvider: AvanteDefaultBaseProvider
|
||||
---@field deployment string
|
||||
---@field api_version string
|
||||
---@class AvanteAzureExtraRequestBody
|
||||
---@field temperature number
|
||||
---@field max_completion_tokens number
|
||||
---@field reasoning_effort? string
|
||||
|
||||
---@class AvanteAzureProvider: AvanteDefaultBaseProvider
|
||||
---@field deployment string
|
||||
---@field api_version string
|
||||
---@field extra_request_body AvanteAzureExtraRequestBody
|
||||
|
||||
local Utils = require("avante.utils")
|
||||
local P = require("avante.providers")
|
||||
local O = require("avante.providers").openai
|
||||
|
||||
@@ -27,7 +27,6 @@
|
||||
|
||||
local curl = require("plenary.curl")
|
||||
|
||||
local Config = require("avante.config")
|
||||
local Path = require("plenary.path")
|
||||
local Utils = require("avante.utils")
|
||||
local Providers = require("avante.providers")
|
||||
@@ -156,14 +155,16 @@ function H.refresh_token(async, force)
|
||||
return false
|
||||
end
|
||||
|
||||
local provider_conf = Providers.get_config("copilot")
|
||||
|
||||
local curl_opts = {
|
||||
headers = {
|
||||
["Authorization"] = "token " .. M.state.oauth_token,
|
||||
["Accept"] = "application/json",
|
||||
},
|
||||
timeout = Config.copilot.timeout,
|
||||
proxy = Config.copilot.proxy,
|
||||
insecure = Config.copilot.allow_insecure,
|
||||
timeout = provider_conf.timeout,
|
||||
proxy = provider_conf.proxy,
|
||||
insecure = provider_conf.allow_insecure,
|
||||
}
|
||||
|
||||
local function handle_response(response)
|
||||
@@ -219,6 +220,7 @@ function M:models_list()
|
||||
-- refresh token synchronously, only if it has expired
|
||||
-- (this should rarely happen, as we refresh the token in the background)
|
||||
H.refresh_token(false, false)
|
||||
local provider_conf = Providers.parse_config(self)
|
||||
local curl_opts = {
|
||||
headers = {
|
||||
["Content-Type"] = "application/json",
|
||||
@@ -226,9 +228,9 @@ function M:models_list()
|
||||
["Copilot-Integration-Id"] = "vscode-chat",
|
||||
["Editor-Version"] = ("Neovim/%s.%s.%s"):format(vim.version().major, vim.version().minor, vim.version().patch),
|
||||
},
|
||||
timeout = Config.copilot.timeout,
|
||||
proxy = Config.copilot.proxy,
|
||||
insecure = Config.copilot.allow_insecure,
|
||||
timeout = provider_conf.timeout,
|
||||
proxy = provider_conf.proxy,
|
||||
insecure = provider_conf.allow_insecure,
|
||||
}
|
||||
|
||||
local function handle_response(response)
|
||||
|
||||
@@ -157,32 +157,22 @@ M = setmetatable(M, {
|
||||
---@param t avante.Providers
|
||||
---@param k avante.ProviderName
|
||||
__index = function(t, k)
|
||||
if Config.providers[k] == nil then error("Failed to find provider: " .. k, 2) end
|
||||
|
||||
local provider_config = M.get_config(k)
|
||||
|
||||
if Config.vendors[k] ~= nil and k == "ollama" then
|
||||
Utils.warn(
|
||||
"ollama is now a first-class provider in avante.nvim, please stop using vendors to define ollama, for migration guide please refer to: https://github.com/yetone/avante.nvim/wiki/Custom-providers#ollama"
|
||||
)
|
||||
end
|
||||
---@diagnostic disable: undefined-field,no-unknown,inject-field
|
||||
if Config.vendors[k] ~= nil and k ~= "ollama" then
|
||||
if provider_config.parse_response_data ~= nil then
|
||||
Utils.error("parse_response_data is not supported for avante.nvim vendors")
|
||||
end
|
||||
if provider_config.__inherited_from ~= nil then
|
||||
local base_provider_config = M.get_config(provider_config.__inherited_from)
|
||||
local ok, module = pcall(require, "avante.providers." .. provider_config.__inherited_from)
|
||||
if not ok then error("Failed to load provider: " .. provider_config.__inherited_from) end
|
||||
t[k] = Utils.deep_extend_with_metatable("keep", provider_config, base_provider_config, module)
|
||||
else
|
||||
t[k] = provider_config
|
||||
end
|
||||
if provider_config.__inherited_from ~= nil then
|
||||
local base_provider_config = M.get_config(provider_config.__inherited_from)
|
||||
local ok, module = pcall(require, "avante.providers." .. provider_config.__inherited_from)
|
||||
if not ok then error("Failed to load provider: " .. provider_config.__inherited_from, 2) end
|
||||
provider_config = Utils.deep_extend_with_metatable("force", module, base_provider_config, provider_config)
|
||||
else
|
||||
local ok, module = pcall(require, "avante.providers." .. k)
|
||||
if not ok then error("Failed to load provider: " .. k) end
|
||||
t[k] = Utils.deep_extend_with_metatable("keep", provider_config, module)
|
||||
if ok then provider_config = Utils.deep_extend_with_metatable("force", module, provider_config) end
|
||||
end
|
||||
|
||||
t[k] = provider_config
|
||||
|
||||
if t[k].parse_api_key == nil then t[k].parse_api_key = function() return E.parse_envvar(t[k]) end end
|
||||
|
||||
-- default to gpt-4o as tokenizer
|
||||
@@ -241,29 +231,17 @@ end
|
||||
function M.parse_config(opts)
|
||||
---@type AvanteDefaultBaseProvider
|
||||
local provider_opts = {}
|
||||
---@type table<string, any>
|
||||
local request_body = {}
|
||||
|
||||
for key, value in pairs(opts) do
|
||||
if vim.tbl_contains(Config.BASE_PROVIDER_KEYS, key) then
|
||||
provider_opts[key] = value
|
||||
else
|
||||
request_body[key] = value
|
||||
end
|
||||
if key ~= "extra_request_body" then provider_opts[key] = value end
|
||||
end
|
||||
|
||||
request_body = vim
|
||||
.iter(request_body)
|
||||
:filter(function(_, v) return type(v) ~= "function" and type(v) ~= "userdata" end)
|
||||
:fold({}, function(acc, k, v)
|
||||
acc[k] = v
|
||||
return acc
|
||||
end)
|
||||
---@type table<string, any>
|
||||
local request_body = opts.extra_request_body or {}
|
||||
|
||||
return provider_opts, request_body
|
||||
end
|
||||
|
||||
---@private
|
||||
---@param provider_name avante.ProviderName
|
||||
function M.get_config(provider_name)
|
||||
provider_name = provider_name or Config.provider
|
||||
|
||||
@@ -48,7 +48,7 @@ function M:parse_curl_args(prompt_opts)
|
||||
|
||||
request_body = vim.tbl_deep_extend("force", request_body, {
|
||||
anthropic_version = "vertex-2023-10-16",
|
||||
temperature = 0,
|
||||
temperature = 0.75,
|
||||
max_tokens = 4096,
|
||||
stream = true,
|
||||
messages = messages,
|
||||
|
||||
Reference in New Issue
Block a user