diff --git a/lua/avante/providers/azure.lua b/lua/avante/providers/azure.lua index 3e60a96..b9cb336 100644 --- a/lua/avante/providers/azure.lua +++ b/lua/avante/providers/azure.lua @@ -58,7 +58,7 @@ function M:parse_curl_args(prompt_opts) ), proxy = provider_conf.proxy, insecure = provider_conf.allow_insecure, - headers = headers, + headers = Utils.tbl_override(headers, self.extra_headers), body = vim.tbl_deep_extend("force", { messages = self:parse_messages(prompt_opts), stream = true, diff --git a/lua/avante/providers/bedrock.lua b/lua/avante/providers/bedrock.lua index f7a04fe..54a1211 100644 --- a/lua/avante/providers/bedrock.lua +++ b/lua/avante/providers/bedrock.lua @@ -152,7 +152,7 @@ function M:parse_curl_args(prompt_opts) url = endpoint, proxy = provider_conf.proxy, insecure = provider_conf.allow_insecure, - headers = headers, + headers = Utils.tbl_override(headers, self.extra_headers), body = body_payload, rawArgs = rawArgs, } diff --git a/lua/avante/providers/claude.lua b/lua/avante/providers/claude.lua index 8d9c3c8..f0498fe 100644 --- a/lua/avante/providers/claude.lua +++ b/lua/avante/providers/claude.lua @@ -396,7 +396,7 @@ function M:parse_curl_args(prompt_opts) url = Utils.url_join(provider_conf.endpoint, "/v1/messages"), proxy = provider_conf.proxy, insecure = provider_conf.allow_insecure, - headers = headers, + headers = Utils.tbl_override(headers, self.extra_headers), body = vim.tbl_deep_extend("force", { model = provider_conf.model, system = { diff --git a/lua/avante/providers/cohere.lua b/lua/avante/providers/cohere.lua index 6946d5b..8a0852a 100644 --- a/lua/avante/providers/cohere.lua +++ b/lua/avante/providers/cohere.lua @@ -95,7 +95,7 @@ function M:parse_curl_args(prompt_opts) url = Utils.url_join(provider_conf.endpoint, "/chat"), proxy = provider_conf.proxy, insecure = provider_conf.allow_insecure, - headers = headers, + headers = Utils.tbl_override(headers, self.extra_headers), body = vim.tbl_deep_extend("force", { model = provider_conf.model, stream = true, diff --git a/lua/avante/providers/copilot.lua b/lua/avante/providers/copilot.lua index 4a2eb27..7f59646 100644 --- a/lua/avante/providers/copilot.lua +++ b/lua/avante/providers/copilot.lua @@ -222,12 +222,12 @@ function M:models_list() H.refresh_token(false, false) local provider_conf = Providers.parse_config(self) local curl_opts = { - headers = { + headers = Utils.tbl_override({ ["Content-Type"] = "application/json", ["Authorization"] = "Bearer " .. M.state.github_token.token, ["Copilot-Integration-Id"] = "vscode-chat", ["Editor-Version"] = ("Neovim/%s.%s.%s"):format(vim.version().major, vim.version().minor, vim.version().patch), - }, + }, self.extra_headers), timeout = provider_conf.timeout, proxy = provider_conf.proxy, insecure = provider_conf.allow_insecure, @@ -288,12 +288,11 @@ function M:parse_curl_args(prompt_opts) timeout = provider_conf.timeout, proxy = provider_conf.proxy, insecure = provider_conf.allow_insecure, - headers = { - ["Content-Type"] = "application/json", + headers = Utils.tbl_override({ ["Authorization"] = "Bearer " .. M.state.github_token.token, ["Copilot-Integration-Id"] = "vscode-chat", ["Editor-Version"] = ("Neovim/%s.%s.%s"):format(vim.version().major, vim.version().minor, vim.version().patch), - }, + }, self.extra_headers), body = vim.tbl_deep_extend("force", { model = provider_conf.model, messages = self:parse_messages(prompt_opts), diff --git a/lua/avante/providers/gemini.lua b/lua/avante/providers/gemini.lua index 50ad9d7..866ece6 100644 --- a/lua/avante/providers/gemini.lua +++ b/lua/avante/providers/gemini.lua @@ -309,7 +309,7 @@ function M:parse_curl_args(prompt_opts) ), proxy = provider_conf.proxy, insecure = provider_conf.allow_insecure, - headers = { ["Content-Type"] = "application/json" }, + headers = Utils.tbl_override({ ["Content-Type"] = "application/json" }, self.extra_headers), body = M.prepare_request_body(self, prompt_opts, provider_conf, request_body), } end diff --git a/lua/avante/providers/ollama.lua b/lua/avante/providers/ollama.lua index a3569fb..ac18fcd 100644 --- a/lua/avante/providers/ollama.lua +++ b/lua/avante/providers/ollama.lua @@ -213,7 +213,7 @@ function M:parse_curl_args(prompt_opts) return { url = Utils.url_join(provider_conf.endpoint, "/api/chat"), - headers = headers, + headers = Utils.tbl_override(headers, self.extra_headers), body = vim.tbl_deep_extend("force", { model = provider_conf.model, messages = self:parse_messages(prompt_opts), diff --git a/lua/avante/providers/openai.lua b/lua/avante/providers/openai.lua index 25b658f..002bc64 100644 --- a/lua/avante/providers/openai.lua +++ b/lua/avante/providers/openai.lua @@ -497,12 +497,6 @@ function M:parse_curl_args(prompt_opts) ["Content-Type"] = "application/json", } - if provider_conf.extra_headers then - for key, value in pairs(provider_conf.extra_headers) do - headers[key] = value - end - end - if Providers.env.require_api_key(provider_conf) then local api_key = self.parse_api_key() if api_key == nil then @@ -536,7 +530,7 @@ function M:parse_curl_args(prompt_opts) url = Utils.url_join(provider_conf.endpoint, "/chat/completions"), proxy = provider_conf.proxy, insecure = provider_conf.allow_insecure, - headers = headers, + headers = Utils.tbl_override(headers, self.extra_headers), body = vim.tbl_deep_extend("force", { model = provider_conf.model, messages = self:parse_messages(prompt_opts), diff --git a/lua/avante/providers/vertex.lua b/lua/avante/providers/vertex.lua index 499ed05..44d48df 100644 --- a/lua/avante/providers/vertex.lua +++ b/lua/avante/providers/vertex.lua @@ -1,4 +1,5 @@ local P = require("avante.providers") +local Utils = require("avante.utils") local Gemini = require("avante.providers.gemini") ---@class AvanteProviderFunctor @@ -50,10 +51,10 @@ function M:parse_curl_args(prompt_opts) return { url = url, - headers = { + headers = Utils.tbl_override({ ["Authorization"] = "Bearer " .. bearer_token, ["Content-Type"] = "application/json; charset=utf-8", - }, + }, self.extra_headers), proxy = provider_conf.proxy, insecure = provider_conf.allow_insecure, body = Gemini.prepare_request_body(self, prompt_opts, provider_conf, request_body), diff --git a/lua/avante/providers/vertex_claude.lua b/lua/avante/providers/vertex_claude.lua index 04d399a..5b72df3 100644 --- a/lua/avante/providers/vertex_claude.lua +++ b/lua/avante/providers/vertex_claude.lua @@ -1,4 +1,5 @@ local P = require("avante.providers") +local Utils = require("avante.utils") local Vertex = require("avante.providers.vertex") ---@class AvanteProviderFunctor @@ -64,10 +65,10 @@ function M:parse_curl_args(prompt_opts) return { url = url, - headers = { + headers = Utils.tbl_override({ ["Authorization"] = "Bearer " .. Vertex.parse_api_key(), ["Content-Type"] = "application/json; charset=utf-8", - }, + }, self.extra_headers), body = vim.tbl_deep_extend("force", {}, request_body), } end diff --git a/lua/avante/types.lua b/lua/avante/types.lua index 2eabb90..863a394 100644 --- a/lua/avante/types.lua +++ b/lua/avante/types.lua @@ -221,7 +221,6 @@ vim.g.avante_login = vim.g.avante_login --- ---@class AvanteDefaultBaseProvider: table ---@field endpoint? string ----@field extra_headers? table ---@field extra_request_body? table ---@field model? string ---@field local? boolean @@ -291,6 +290,7 @@ vim.g.avante_login = vim.g.avante_login --- ---@class AvanteProviderFunctor ---@field _model_list_cache table +---@field extra_headers function(table) -> table | table | nil ---@field support_prompt_caching boolean | nil ---@field role_map table<"user" | "assistant", string> ---@field parse_messages AvanteMessagesParser diff --git a/lua/avante/utils/init.lua b/lua/avante/utils/init.lua index e0ae1ed..e841ec2 100644 --- a/lua/avante/utils/init.lua +++ b/lua/avante/utils/init.lua @@ -1662,4 +1662,10 @@ function M.count_lines(str) return count end +function M.tbl_override(value, override) + override = override or {} + if type(override) == "function" then return override(value) or value end + return vim.tbl_extend("force", value, override) +end + return M