From e9ab2ca2fd7b8df4bed0963f490f59d8ed119ecb Mon Sep 17 00:00:00 2001 From: yetone Date: Tue, 3 Jun 2025 04:01:01 +0800 Subject: [PATCH] refactor: providers config (#2117) --- README.md | 33 +- README_zh.md | 33 +- cursor-planning-mode.md | 4 +- lua/avante/config.lua | 399 +++++++++++++++---------- lua/avante/model_selector.lua | 26 +- lua/avante/providers/azure.lua | 9 +- lua/avante/providers/copilot.lua | 16 +- lua/avante/providers/init.lua | 48 +-- lua/avante/providers/vertex_claude.lua | 2 +- lua/avante/sidebar.lua | 5 - lua/avante/types.lua | 1 + plugin/avante.lua | 2 +- 12 files changed, 327 insertions(+), 251 deletions(-) diff --git a/README.md b/README.md index 9e0dd6d..7ed9a6a 100644 --- a/README.md +++ b/README.md @@ -79,13 +79,17 @@ For building binary if you wish to build from source, then `cargo` is required. -- add any opts here -- for example provider = "openai", - openai = { - endpoint = "https://api.openai.com/v1", - model = "gpt-4o", -- your desired model (or use gpt-4o, etc.) - timeout = 30000, -- Timeout in milliseconds, increase this for reasoning models - temperature = 0, - max_completion_tokens = 8192, -- Increase this to include reasoning tokens (for reasoning models) - --reasoning_effort = "medium", -- low|medium|high, only used for reasoning models + providers = { + openai = { + endpoint = "https://api.openai.com/v1", + model = "gpt-4o", -- your desired model (or use gpt-4o, etc.) + extra_request_body = { + timeout = 30000, -- Timeout in milliseconds, increase this for reasoning models + temperature = 0.75, + max_completion_tokens = 8192, -- Increase this to include reasoning tokens (for reasoning models) + --reasoning_effort = "medium", -- low|medium|high, only used for reasoning models + }, + }, }, }, -- if you want to build from source then do `make BUILD_FROM_SOURCE=true` @@ -326,12 +330,15 @@ _See [config.lua#L9](./lua/avante/config.lua) for the full config_ -- currently designating it as `copilot` provider is dangerous because: https://github.com/yetone/avante.nvim/issues/1048 -- Of course, you can reduce the request frequency by increasing `suggestion.debounce`. auto_suggestions_provider = "claude", - cursor_applying_provider = nil, -- The provider used in the applying phase of Cursor Planning Mode, defaults to nil, when nil uses Config.provider as the provider for the applying phase - claude = { - endpoint = "https://api.anthropic.com", - model = "claude-3-5-sonnet-20241022", - temperature = 0, - max_tokens = 4096, + providers = { + claude = { + endpoint = "https://api.anthropic.com", + model = "claude-3-5-sonnet-20241022", + extra_request_body = { + temperature = 0, + max_tokens = 4096, + }, + }, }, ---Specify the special dual_boost mode ---1. enabled: Whether to enable dual_boost mode. Default to false. diff --git a/README_zh.md b/README_zh.md index f3bd4c6..1d1e68b 100644 --- a/README_zh.md +++ b/README_zh.md @@ -64,13 +64,17 @@ -- 在此处添加任何选项 -- 例如 provider = "openai", - openai = { - endpoint = "https://api.openai.com/v1", - model = "gpt-4o", -- 您想要的模型(或使用 gpt-4o 等) - timeout = 30000, -- 超时时间(毫秒),增加此值以适应推理模型 - temperature = 0, - max_tokens = 8192, -- 增加此值以包括推理模型的推理令牌 - --reasoning_effort = "medium", -- low|medium|high,仅用于推理模型 + providers = { + openai = { + endpoint = "https://api.openai.com/v1", + model = "gpt-4o", -- 您想要的模型(或使用 gpt-4o 等) + extra_request_body = { + timeout = 30000, -- 超时时间(毫秒),增加此值以适应推理模型 + temperature = 0, + max_tokens = 8192, -- 增加此值以包括推理模型的推理令牌 + --reasoning_effort = "medium", -- low|medium|high,仅用于推理模型 + }, + }, }, }, -- 如果您想从源代码构建,请执行 `make BUILD_FROM_SOURCE=true` @@ -309,12 +313,15 @@ _请参见 [config.lua#L9](./lua/avante/config.lua) 以获取完整配置_ -- 目前将其指定为 `copilot` 提供者是危险的,因为:https://github.com/yetone/avante.nvim/issues/1048 -- 当然,您可以通过增加 `suggestion.debounce` 来减少请求频率。 auto_suggestions_provider = "claude", - cursor_applying_provider = nil, -- Cursor 规划模式应用阶段使用的提供者,默认为 nil,当为 nil 时使用 Config.provider 作为应用阶段的提供者 - claude = { - endpoint = "https://api.anthropic.com", - model = "claude-3-5-sonnet-20241022", - temperature = 0, - max_tokens = 4096, + providers = { + claude = { + endpoint = "https://api.anthropic.com", + model = "claude-3-5-sonnet-20241022", + extra_request_body = { + temperature = 0, + max_tokens = 4096, + }, + }, }, ---指定特殊的 dual_boost 模式 ---1. enabled: 是否启用 dual_boost 模式。默认为 false。 diff --git a/cursor-planning-mode.md b/cursor-planning-mode.md index a78d3d9..84733a1 100644 --- a/cursor-planning-mode.md +++ b/cursor-planning-mode.md @@ -28,8 +28,8 @@ Then enable it in avante.nvim: --- ... existing behaviours enable_cursor_planning_mode = true, -- enable cursor planning mode! }, - vendors = { - --- ... existing vendors + providers = { + --- ... existing providers groq = { -- define groq provider __inherited_from = 'openai', api_key_name = 'GROQ_API_KEY', diff --git a/lua/avante/config.lua b/lua/avante/config.lua index 50d70ab..157c966 100644 --- a/lua/avante/config.lua +++ b/lua/avante/config.lua @@ -27,7 +27,6 @@ M._defaults = { -- currently designating it as `copilot` provider is dangerous because: https://github.com/yetone/avante.nvim/issues/1048 -- Of course, you can reduce the request frequency by increasing `suggestion.debounce`. auto_suggestions_provider = nil, - cursor_applying_provider = nil, memory_summary_provider = nil, ---@alias Tokenizer "tiktoken" | "hf" -- Used for counting tokens and encoding text. @@ -215,113 +214,137 @@ M._defaults = { }, }, }, - ---@type AvanteSupportedProvider - openai = { - endpoint = "https://api.openai.com/v1", - model = "gpt-4o", - timeout = 30000, -- Timeout in milliseconds, increase this for reasoning models - temperature = 0.75, - max_completion_tokens = 16384, -- Increase this to include reasoning tokens (for reasoning models) - reasoning_effort = "medium", -- low|medium|high, only used for reasoning models - }, - ---@type AvanteSupportedProvider - copilot = { - endpoint = "https://api.githubcopilot.com", - model = "gpt-4o-2024-11-20", - proxy = nil, -- [protocol://]host[:port] Use this proxy - allow_insecure = false, -- Allow insecure server connections - timeout = 30000, -- Timeout in milliseconds - temperature = 0.75, - max_tokens = 20480, - }, - ---@type AvanteAzureProvider - azure = { - endpoint = "", -- example: "https://.openai.azure.com" - deployment = "", -- Azure deployment name (e.g., "gpt-4o", "my-gpt-4o-deployment") - api_version = "2024-12-01-preview", - timeout = 30000, -- Timeout in milliseconds, increase this for reasoning models - temperature = 0.75, - max_completion_tokens = 20480, -- Increase this to include reasoning tokens (for reasoning models) - reasoning_effort = "medium", -- low|medium|high, only used for reasoning models - }, - ---@type AvanteSupportedProvider - claude = { - endpoint = "https://api.anthropic.com", - model = "claude-3-7-sonnet-20250219", - timeout = 30000, -- Timeout in milliseconds - temperature = 0.75, - max_tokens = 20480, - }, - ---@type AvanteSupportedProvider - bedrock = { - model = "anthropic.claude-3-5-sonnet-20241022-v2:0", - timeout = 30000, -- Timeout in milliseconds - temperature = 0.75, - max_tokens = 20480, - aws_region = "", -- AWS region to use for authentication and bedrock API - aws_profile = "", -- AWS profile to use for authentication, if unspecified uses default credentials chain - }, - ---@type AvanteSupportedProvider - gemini = { - endpoint = "https://generativelanguage.googleapis.com/v1beta/models", - model = "gemini-2.0-flash", - timeout = 30000, -- Timeout in milliseconds - temperature = 0.75, - max_tokens = 8192, - }, - ---@type AvanteSupportedProvider - vertex = { - endpoint = "https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION/publishers/google/models", - model = "gemini-1.5-flash-002", - timeout = 30000, -- Timeout in milliseconds - temperature = 0.75, - max_tokens = 20480, - }, - ---@type AvanteSupportedProvider - cohere = { - endpoint = "https://api.cohere.com/v2", - model = "command-r-plus-08-2024", - timeout = 30000, -- Timeout in milliseconds - temperature = 0.75, - max_tokens = 20480, - }, - ---@type AvanteSupportedProvider - ollama = { - endpoint = "http://127.0.0.1:11434", - timeout = 30000, -- Timeout in milliseconds - options = { - temperature = 0.75, - num_ctx = 20480, - keep_alive = "5m", - }, - }, - ---@type AvanteSupportedProvider - vertex_claude = { - endpoint = "https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION/publishers/antrhopic/models", - model = "claude-3-5-sonnet-v2@20241022", - timeout = 30000, -- Timeout in milliseconds - temperature = 0.75, - max_tokens = 20480, - }, ---To add support for custom provider, follow the format below ---See https://github.com/yetone/avante.nvim/wiki#custom-providers for more details ---@type {[string]: AvanteProvider} - vendors = { + providers = { + ---@type AvanteSupportedProvider + openai = { + endpoint = "https://api.openai.com/v1", + model = "gpt-4o", + timeout = 30000, -- Timeout in milliseconds, increase this for reasoning models + extra_request_body = { + temperature = 0.75, + max_completion_tokens = 16384, -- Increase this to include reasoning tokens (for reasoning models) + reasoning_effort = "medium", -- low|medium|high, only used for reasoning models + }, + }, + ---@type AvanteSupportedProvider + copilot = { + endpoint = "https://api.githubcopilot.com", + model = "gpt-4o-2024-11-20", + proxy = nil, -- [protocol://]host[:port] Use this proxy + allow_insecure = false, -- Allow insecure server connections + timeout = 30000, -- Timeout in milliseconds + extra_request_body = { + temperature = 0.75, + max_tokens = 20480, + }, + }, + ---@type AvanteAzureProvider + azure = { + endpoint = "", -- example: "https://.openai.azure.com" + deployment = "", -- Azure deployment name (e.g., "gpt-4o", "my-gpt-4o-deployment") + api_version = "2024-12-01-preview", + timeout = 30000, -- Timeout in milliseconds, increase this for reasoning models + extra_request_body = { + temperature = 0.75, + max_completion_tokens = 20480, -- Increase this to include reasoning tokens (for reasoning models) + reasoning_effort = "medium", -- low|medium|high, only used for reasoning models + }, + }, + ---@type AvanteSupportedProvider + claude = { + endpoint = "https://api.anthropic.com", + model = "claude-3-7-sonnet-20250219", + timeout = 30000, -- Timeout in milliseconds + extra_request_body = { + temperature = 0.75, + max_tokens = 20480, + }, + }, + ---@type AvanteSupportedProvider + bedrock = { + model = "anthropic.claude-3-5-sonnet-20241022-v2:0", + timeout = 30000, -- Timeout in milliseconds + extra_request_body = { + temperature = 0.75, + max_tokens = 20480, + }, + aws_region = "", -- AWS region to use for authentication and bedrock API + aws_profile = "", -- AWS profile to use for authentication, if unspecified uses default credentials chain + }, + ---@type AvanteSupportedProvider + gemini = { + endpoint = "https://generativelanguage.googleapis.com/v1beta/models", + model = "gemini-2.0-flash", + timeout = 30000, -- Timeout in milliseconds + extra_request_body = { + temperature = 0.75, + max_tokens = 8192, + }, + }, + ---@type AvanteSupportedProvider + vertex = { + endpoint = "https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION/publishers/google/models", + model = "gemini-1.5-flash-002", + timeout = 30000, -- Timeout in milliseconds + extra_request_body = { + temperature = 0.75, + max_tokens = 20480, + }, + }, + ---@type AvanteSupportedProvider + cohere = { + endpoint = "https://api.cohere.com/v2", + model = "command-r-plus-08-2024", + timeout = 30000, -- Timeout in milliseconds + extra_request_body = { + temperature = 0.75, + max_tokens = 20480, + }, + }, + ---@type AvanteSupportedProvider + ollama = { + endpoint = "http://127.0.0.1:11434", + timeout = 30000, -- Timeout in milliseconds + extra_request_body = { + options = { + temperature = 0.75, + num_ctx = 20480, + keep_alive = "5m", + }, + }, + }, + ---@type AvanteSupportedProvider + vertex_claude = { + endpoint = "https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION/publishers/antrhopic/models", + model = "claude-3-5-sonnet-v2@20241022", + timeout = 30000, -- Timeout in milliseconds + extra_request_body = { + temperature = 0.75, + max_tokens = 20480, + }, + }, ---@type AvanteSupportedProvider ["claude-haiku"] = { __inherited_from = "claude", model = "claude-3-5-haiku-20241022", timeout = 30000, -- Timeout in milliseconds - temperature = 0.75, - max_tokens = 8192, + extra_request_body = { + temperature = 0.75, + max_tokens = 8192, + }, }, ---@type AvanteSupportedProvider ["claude-opus"] = { __inherited_from = "claude", model = "claude-3-opus-20240229", timeout = 30000, -- Timeout in milliseconds - temperature = 0.75, - max_tokens = 20480, + extra_request_body = { + temperature = 0.75, + max_tokens = 20480, + }, }, ["openai-gpt-4o-mini"] = { __inherited_from = "openai", @@ -342,7 +365,9 @@ M._defaults = { ["bedrock-claude-3.7-sonnet"] = { __inherited_from = "bedrock", model = "us.anthropic.claude-3-7-sonnet-20250219-v1:0", - max_tokens = 4096, + extra_request_body = { + max_tokens = 4096, + }, }, }, ---Specify the special dual_boost mode @@ -542,17 +567,129 @@ M._defaults = { ---@diagnostic disable-next-line: missing-fields M._options = {} ----@type avante.ProviderName[] -M.provider_names = {} - ---@param opts? avante.Config function M.setup(opts) vim.validate({ opts = { opts, "table", true } }) + opts = opts or {} + + local migration_url = "https://github.com/yetone/avante.nvim/wiki/Provider-configuration-migration-guide" + + if opts.providers ~= nil then + for k, v in pairs(opts.providers) do + local extra_request_body + if type(v) == "table" then + if M._defaults.providers[k] ~= nil then + extra_request_body = M._defaults.providers[k].extra_request_body + elseif v.__inherited_from ~= nil then + if M._defaults.providers[v.__inherited_from] ~= nil then + extra_request_body = M._defaults.providers[v.__inherited_from].extra_request_body + end + end + end + if extra_request_body ~= nil then + for k_, v_ in pairs(v) do + if extra_request_body[k_] ~= nil then + opts.providers[k].extra_request_body = opts.providers[k].extra_request_body or {} + opts.providers[k].extra_request_body[k_] = v_ + Utils.warn( + string.format( + "[DEPRECATED] The configuration of `providers.%s.%s` should be placed in `providers.%s.extra_request_body.%s`; for detailed migration instructions, please visit: %s", + k, + k_, + k, + k_, + migration_url + ), + { title = "Avante" } + ) + end + end + end + end + end + + for k, v in pairs(opts) do + if M._defaults.providers[k] ~= nil then + opts.providers = opts.providers or {} + opts.providers[k] = v + Utils.warn( + string.format( + "[DEPRACATED] The configuration of `%s` should be placed in `providers.%s`. For detailed migration instructions, please visit: %s", + k, + k, + migration_url + ), + { title = "Avante" } + ) + local extra_request_body = M._defaults.providers[k].extra_request_body + if type(v) == "table" and extra_request_body ~= nil then + for k_, v_ in pairs(v) do + if extra_request_body[k_] ~= nil then + opts.providers[k].extra_request_body = opts.providers[k].extra_request_body or {} + opts.providers[k].extra_request_body[k_] = v_ + Utils.warn( + string.format( + "[DEPRECATED] The configuration of `%s.%s` should be placed in `providers.%s.extra_request_body.%s`; for detailed migration instructions, please visit: %s", + k, + k_, + k, + k_, + migration_url + ), + { title = "Avante" } + ) + end + end + end + end + if k == "vendors" and v ~= nil then + for k2, v2 in pairs(v) do + opts.providers = opts.providers or {} + opts.providers[k2] = v2 + Utils.warn( + string.format( + "[DEPRECATED] The configuration of `vendors.%s` should be placed in `providers.%s`. For detailed migration instructions, please visit: %s", + k2, + k2, + migration_url + ), + { title = "Avante" } + ) + if + type(v2) == "table" + and v2.__inherited_from ~= nil + and M._defaults.providers[v2.__inherited_from] ~= nil + then + local extra_request_body = M._defaults.providers[v2.__inherited_from].extra_request_body + if extra_request_body ~= nil then + for k2_, v2_ in pairs(v2) do + if extra_request_body[k2_] ~= nil then + opts.providers[k2].extra_request_body = opts.providers[k2].extra_request_body or {} + opts.providers[k2].extra_request_body[k2_] = v2_ + Utils.warn( + string.format( + "[DEPRECATED] The configuration of `vendors.%s.%s` should be placed in `providers.%s.extra_request_body.%s`; for detailed migration instructions, please visit: %s", + k2, + k2_, + k2, + k2_, + migration_url + ), + { title = "Avante" } + ) + end + end + end + end + end + end + end + local merged = vim.tbl_deep_extend( "force", M._defaults, - opts or {}, + opts, ---@type avante.Config { behaviour = { @@ -562,14 +699,6 @@ function M.setup(opts) ) M._options = merged - M.provider_names = vim - .iter(M._defaults) - :filter(function(_, value) return type(value) == "table" and (value.endpoint ~= nil or value.model ~= nil) end) - :fold({}, function(acc, k) - acc = vim.list_extend({}, acc) - acc = vim.list_extend(acc, { k }) - return acc - end) ---@diagnostic disable-next-line: undefined-field if M._options.disable_tools ~= nil then @@ -588,12 +717,8 @@ function M.setup(opts) vim.validate({ provider = { M._options.provider, "string", false } }) - if next(M._options.vendors) ~= nil then - for k, v in pairs(M._options.vendors) do - M._options.vendors[k] = type(v) == "function" and v() or v - end - vim.validate({ vendors = { M._options.vendors, "table", true } }) - M.provider_names = vim.list_extend(M.provider_names, vim.tbl_keys(M._options.vendors)) + for k, v in pairs(M._options.providers) do + M._options.providers[k] = type(v) == "function" and v() or v end end @@ -603,12 +728,8 @@ function M.override(opts) M._options = vim.tbl_deep_extend("force", M._options, opts or {}) - if next(M._options.vendors) ~= nil then - for k, v in pairs(M._options.vendors) do - M._options.vendors[k] = type(v) == "function" and v() or v - if not vim.tbl_contains(M.provider_names, k) then M.provider_names = vim.list_extend(M.provider_names, { k }) end - end - vim.validate({ vendors = { M._options.vendors, "table", true } }) + for k, v in pairs(M._options.providers) do + M._options.providers[k] = type(v) == "function" and v() or v end end @@ -622,25 +743,15 @@ function M.support_paste_image() return Utils.has("img-clip.nvim") or Utils.has( function M.get_window_width() return math.ceil(vim.o.columns * (M.windows.width / 100)) end ----@param provider_name avante.ProviderName ----@return boolean -function M.has_provider(provider_name) return vim.list_contains(M.provider_names, provider_name) end - ---get supported providers ---@param provider_name avante.ProviderName function M.get_provider_config(provider_name) - if not M.has_provider(provider_name) then error("No provider found: " .. provider_name, 2) end local found = false local config = {} - if M.vendors and M.vendors[provider_name] ~= nil then + if M.providers[provider_name] ~= nil then found = true - config = vim.tbl_deep_extend("force", config, vim.deepcopy(M.vendors[provider_name], true)) - end - - if M._options[provider_name] ~= nil then - found = true - config = vim.tbl_deep_extend("force", config, vim.deepcopy(M._options[provider_name], true)) + config = vim.tbl_deep_extend("force", config, vim.deepcopy(M.providers[provider_name], true)) end if not found then error("Failed to find provider: " .. provider_name, 2) end @@ -648,30 +759,4 @@ function M.get_provider_config(provider_name) return config end -M.BASE_PROVIDER_KEYS = { - "endpoint", - "extra_headers", - "model", - "deployment", - "api_version", - "proxy", - "allow_insecure", - "api_key_name", - "timeout", - "display_name", - "aws_region", - "aws_profile", - -- internal - "local", - "_shellenv", - "tokenizer_id", - "role_map", - "support_prompt_caching", - "__inherited_from", - "disable_tools", - "entra", - "hide_in_model_selector", - "use_ReAct_prompt", -} - return M diff --git a/lua/avante/model_selector.lua b/lua/avante/model_selector.lua index dc2f1e6..0cf468f 100644 --- a/lua/avante/model_selector.lua +++ b/lua/avante/model_selector.lua @@ -15,7 +15,7 @@ local models_list_cached_result = {} ---@param provider_cfg table ---@return table local function create_model_entries(provider_name, provider_cfg) - if provider_cfg.models_list then + if provider_cfg.models_list and provider_cfg.__inherited_from == nil then local models_list if type(provider_cfg.models_list) == "function" then if M.models_list_invoked[provider_cfg.models_list] then return {} end @@ -24,7 +24,7 @@ local function create_model_entries(provider_name, provider_cfg) if cached_result then models_list = cached_result else - models_list = provider_cfg.models_list() + models_list = provider_cfg:models_list() models_list_cached_result[provider_cfg.models_list] = models_list end else @@ -65,13 +65,9 @@ function M.open() M.models_list_returned = {} local models = {} - -- Collect models from main providers and vendors - for _, provider_name in ipairs(Config.provider_names) do - local ok, provider_cfg = pcall(function() return Providers[provider_name] end) - if not ok then - Utils.warn("Failed to load provider: " .. provider_name) - goto continue - end + -- Collect models from providers + for provider_name, _ in pairs(Config.providers) do + local provider_cfg = Providers[provider_name] if provider_cfg.hide_in_model_selector then goto continue end if not provider_cfg.is_env_set() then goto continue end local entries = create_model_entries(provider_name, provider_cfg) @@ -106,11 +102,13 @@ function M.open() -- Update config with new model Config.override({ - [choice.provider_name] = vim.tbl_deep_extend( - "force", - Config.get_provider_config(choice.provider_name), - { model = choice.model } - ), + providers = { + [choice.provider_name] = vim.tbl_deep_extend( + "force", + Config.get_provider_config(choice.provider_name), + { model = choice.model } + ), + }, }) Utils.info("Switched to model: " .. choice.name) diff --git a/lua/avante/providers/azure.lua b/lua/avante/providers/azure.lua index 12f2713..3e60a96 100644 --- a/lua/avante/providers/azure.lua +++ b/lua/avante/providers/azure.lua @@ -1,10 +1,13 @@ ----@class AvanteAzureProvider: AvanteDefaultBaseProvider ----@field deployment string ----@field api_version string +---@class AvanteAzureExtraRequestBody ---@field temperature number ---@field max_completion_tokens number ---@field reasoning_effort? string +---@class AvanteAzureProvider: AvanteDefaultBaseProvider +---@field deployment string +---@field api_version string +---@field extra_request_body AvanteAzureExtraRequestBody + local Utils = require("avante.utils") local P = require("avante.providers") local O = require("avante.providers").openai diff --git a/lua/avante/providers/copilot.lua b/lua/avante/providers/copilot.lua index 00e48a4..4a2eb27 100644 --- a/lua/avante/providers/copilot.lua +++ b/lua/avante/providers/copilot.lua @@ -27,7 +27,6 @@ local curl = require("plenary.curl") -local Config = require("avante.config") local Path = require("plenary.path") local Utils = require("avante.utils") local Providers = require("avante.providers") @@ -156,14 +155,16 @@ function H.refresh_token(async, force) return false end + local provider_conf = Providers.get_config("copilot") + local curl_opts = { headers = { ["Authorization"] = "token " .. M.state.oauth_token, ["Accept"] = "application/json", }, - timeout = Config.copilot.timeout, - proxy = Config.copilot.proxy, - insecure = Config.copilot.allow_insecure, + timeout = provider_conf.timeout, + proxy = provider_conf.proxy, + insecure = provider_conf.allow_insecure, } local function handle_response(response) @@ -219,6 +220,7 @@ function M:models_list() -- refresh token synchronously, only if it has expired -- (this should rarely happen, as we refresh the token in the background) H.refresh_token(false, false) + local provider_conf = Providers.parse_config(self) local curl_opts = { headers = { ["Content-Type"] = "application/json", @@ -226,9 +228,9 @@ function M:models_list() ["Copilot-Integration-Id"] = "vscode-chat", ["Editor-Version"] = ("Neovim/%s.%s.%s"):format(vim.version().major, vim.version().minor, vim.version().patch), }, - timeout = Config.copilot.timeout, - proxy = Config.copilot.proxy, - insecure = Config.copilot.allow_insecure, + timeout = provider_conf.timeout, + proxy = provider_conf.proxy, + insecure = provider_conf.allow_insecure, } local function handle_response(response) diff --git a/lua/avante/providers/init.lua b/lua/avante/providers/init.lua index 7e1680c..210c660 100644 --- a/lua/avante/providers/init.lua +++ b/lua/avante/providers/init.lua @@ -157,32 +157,22 @@ M = setmetatable(M, { ---@param t avante.Providers ---@param k avante.ProviderName __index = function(t, k) + if Config.providers[k] == nil then error("Failed to find provider: " .. k, 2) end + local provider_config = M.get_config(k) - if Config.vendors[k] ~= nil and k == "ollama" then - Utils.warn( - "ollama is now a first-class provider in avante.nvim, please stop using vendors to define ollama, for migration guide please refer to: https://github.com/yetone/avante.nvim/wiki/Custom-providers#ollama" - ) - end - ---@diagnostic disable: undefined-field,no-unknown,inject-field - if Config.vendors[k] ~= nil and k ~= "ollama" then - if provider_config.parse_response_data ~= nil then - Utils.error("parse_response_data is not supported for avante.nvim vendors") - end - if provider_config.__inherited_from ~= nil then - local base_provider_config = M.get_config(provider_config.__inherited_from) - local ok, module = pcall(require, "avante.providers." .. provider_config.__inherited_from) - if not ok then error("Failed to load provider: " .. provider_config.__inherited_from) end - t[k] = Utils.deep_extend_with_metatable("keep", provider_config, base_provider_config, module) - else - t[k] = provider_config - end + if provider_config.__inherited_from ~= nil then + local base_provider_config = M.get_config(provider_config.__inherited_from) + local ok, module = pcall(require, "avante.providers." .. provider_config.__inherited_from) + if not ok then error("Failed to load provider: " .. provider_config.__inherited_from, 2) end + provider_config = Utils.deep_extend_with_metatable("force", module, base_provider_config, provider_config) else local ok, module = pcall(require, "avante.providers." .. k) - if not ok then error("Failed to load provider: " .. k) end - t[k] = Utils.deep_extend_with_metatable("keep", provider_config, module) + if ok then provider_config = Utils.deep_extend_with_metatable("force", module, provider_config) end end + t[k] = provider_config + if t[k].parse_api_key == nil then t[k].parse_api_key = function() return E.parse_envvar(t[k]) end end -- default to gpt-4o as tokenizer @@ -241,29 +231,17 @@ end function M.parse_config(opts) ---@type AvanteDefaultBaseProvider local provider_opts = {} - ---@type table - local request_body = {} for key, value in pairs(opts) do - if vim.tbl_contains(Config.BASE_PROVIDER_KEYS, key) then - provider_opts[key] = value - else - request_body[key] = value - end + if key ~= "extra_request_body" then provider_opts[key] = value end end - request_body = vim - .iter(request_body) - :filter(function(_, v) return type(v) ~= "function" and type(v) ~= "userdata" end) - :fold({}, function(acc, k, v) - acc[k] = v - return acc - end) + ---@type table + local request_body = opts.extra_request_body or {} return provider_opts, request_body end ----@private ---@param provider_name avante.ProviderName function M.get_config(provider_name) provider_name = provider_name or Config.provider diff --git a/lua/avante/providers/vertex_claude.lua b/lua/avante/providers/vertex_claude.lua index cb16bce..04d399a 100644 --- a/lua/avante/providers/vertex_claude.lua +++ b/lua/avante/providers/vertex_claude.lua @@ -48,7 +48,7 @@ function M:parse_curl_args(prompt_opts) request_body = vim.tbl_deep_extend("force", request_body, { anthropic_version = "vertex-2023-10-16", - temperature = 0, + temperature = 0.75, max_tokens = 4096, stream = true, messages = messages, diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua index 4eee22c..1433fe0 100644 --- a/lua/avante/sidebar.lua +++ b/lua/avante/sidebar.lua @@ -2356,11 +2356,6 @@ function Sidebar:create_input_container() end end - -- local model = Config.has_provider(Config.provider) and Config.get_provider_config(Config.provider).model - -- or "default" - -- - -- local timestamp = Utils.get_timestamp() - local selected_filepaths = self.file_selector:get_selected_filepaths() ---@type AvanteSelectedCode | nil diff --git a/lua/avante/types.lua b/lua/avante/types.lua index 1c17592..f958c13 100644 --- a/lua/avante/types.lua +++ b/lua/avante/types.lua @@ -220,6 +220,7 @@ vim.g.avante_login = vim.g.avante_login ---@class AvanteDefaultBaseProvider: table ---@field endpoint? string ---@field extra_headers? table +---@field extra_request_body? table ---@field model? string ---@field local? boolean ---@field proxy? string diff --git a/plugin/avante.lua b/plugin/avante.lua index 869637f..c93ce1e 100644 --- a/plugin/avante.lua +++ b/plugin/avante.lua @@ -118,7 +118,7 @@ cmd("SwitchProvider", function(opts) require("avante.api").switch_provider(vim.t complete = function(_, line, _) local prefix = line:match("AvanteSwitchProvider%s*(.*)$") or "" ---@param key string - return vim.tbl_filter(function(key) return key:find(prefix, 1, true) == 1 end, Config.provider_names) + return vim.tbl_filter(function(key) return key:find(prefix, 1, true) == 1 end, vim.tbl_keys(Config.providers)) end, }) cmd(