feat: allow overriding provider headers (#2161)

This commit is contained in:
Avinash Thakur
2025-06-07 23:34:00 +05:30
committed by GitHub
parent 86489ef2be
commit 8396cc77e4
12 changed files with 24 additions and 23 deletions

View File

@@ -58,7 +58,7 @@ function M:parse_curl_args(prompt_opts)
),
proxy = provider_conf.proxy,
insecure = provider_conf.allow_insecure,
headers = headers,
headers = Utils.tbl_override(headers, self.extra_headers),
body = vim.tbl_deep_extend("force", {
messages = self:parse_messages(prompt_opts),
stream = true,

View File

@@ -152,7 +152,7 @@ function M:parse_curl_args(prompt_opts)
url = endpoint,
proxy = provider_conf.proxy,
insecure = provider_conf.allow_insecure,
headers = headers,
headers = Utils.tbl_override(headers, self.extra_headers),
body = body_payload,
rawArgs = rawArgs,
}

View File

@@ -396,7 +396,7 @@ function M:parse_curl_args(prompt_opts)
url = Utils.url_join(provider_conf.endpoint, "/v1/messages"),
proxy = provider_conf.proxy,
insecure = provider_conf.allow_insecure,
headers = headers,
headers = Utils.tbl_override(headers, self.extra_headers),
body = vim.tbl_deep_extend("force", {
model = provider_conf.model,
system = {

View File

@@ -95,7 +95,7 @@ function M:parse_curl_args(prompt_opts)
url = Utils.url_join(provider_conf.endpoint, "/chat"),
proxy = provider_conf.proxy,
insecure = provider_conf.allow_insecure,
headers = headers,
headers = Utils.tbl_override(headers, self.extra_headers),
body = vim.tbl_deep_extend("force", {
model = provider_conf.model,
stream = true,

View File

@@ -222,12 +222,12 @@ function M:models_list()
H.refresh_token(false, false)
local provider_conf = Providers.parse_config(self)
local curl_opts = {
headers = {
headers = Utils.tbl_override({
["Content-Type"] = "application/json",
["Authorization"] = "Bearer " .. M.state.github_token.token,
["Copilot-Integration-Id"] = "vscode-chat",
["Editor-Version"] = ("Neovim/%s.%s.%s"):format(vim.version().major, vim.version().minor, vim.version().patch),
},
}, self.extra_headers),
timeout = provider_conf.timeout,
proxy = provider_conf.proxy,
insecure = provider_conf.allow_insecure,
@@ -288,12 +288,11 @@ function M:parse_curl_args(prompt_opts)
timeout = provider_conf.timeout,
proxy = provider_conf.proxy,
insecure = provider_conf.allow_insecure,
headers = {
["Content-Type"] = "application/json",
headers = Utils.tbl_override({
["Authorization"] = "Bearer " .. M.state.github_token.token,
["Copilot-Integration-Id"] = "vscode-chat",
["Editor-Version"] = ("Neovim/%s.%s.%s"):format(vim.version().major, vim.version().minor, vim.version().patch),
},
}, self.extra_headers),
body = vim.tbl_deep_extend("force", {
model = provider_conf.model,
messages = self:parse_messages(prompt_opts),

View File

@@ -309,7 +309,7 @@ function M:parse_curl_args(prompt_opts)
),
proxy = provider_conf.proxy,
insecure = provider_conf.allow_insecure,
headers = { ["Content-Type"] = "application/json" },
headers = Utils.tbl_override({ ["Content-Type"] = "application/json" }, self.extra_headers),
body = M.prepare_request_body(self, prompt_opts, provider_conf, request_body),
}
end

View File

@@ -213,7 +213,7 @@ function M:parse_curl_args(prompt_opts)
return {
url = Utils.url_join(provider_conf.endpoint, "/api/chat"),
headers = headers,
headers = Utils.tbl_override(headers, self.extra_headers),
body = vim.tbl_deep_extend("force", {
model = provider_conf.model,
messages = self:parse_messages(prompt_opts),

View File

@@ -497,12 +497,6 @@ function M:parse_curl_args(prompt_opts)
["Content-Type"] = "application/json",
}
if provider_conf.extra_headers then
for key, value in pairs(provider_conf.extra_headers) do
headers[key] = value
end
end
if Providers.env.require_api_key(provider_conf) then
local api_key = self.parse_api_key()
if api_key == nil then
@@ -536,7 +530,7 @@ function M:parse_curl_args(prompt_opts)
url = Utils.url_join(provider_conf.endpoint, "/chat/completions"),
proxy = provider_conf.proxy,
insecure = provider_conf.allow_insecure,
headers = headers,
headers = Utils.tbl_override(headers, self.extra_headers),
body = vim.tbl_deep_extend("force", {
model = provider_conf.model,
messages = self:parse_messages(prompt_opts),

View File

@@ -1,4 +1,5 @@
local P = require("avante.providers")
local Utils = require("avante.utils")
local Gemini = require("avante.providers.gemini")
---@class AvanteProviderFunctor
@@ -50,10 +51,10 @@ function M:parse_curl_args(prompt_opts)
return {
url = url,
headers = {
headers = Utils.tbl_override({
["Authorization"] = "Bearer " .. bearer_token,
["Content-Type"] = "application/json; charset=utf-8",
},
}, self.extra_headers),
proxy = provider_conf.proxy,
insecure = provider_conf.allow_insecure,
body = Gemini.prepare_request_body(self, prompt_opts, provider_conf, request_body),

View File

@@ -1,4 +1,5 @@
local P = require("avante.providers")
local Utils = require("avante.utils")
local Vertex = require("avante.providers.vertex")
---@class AvanteProviderFunctor
@@ -64,10 +65,10 @@ function M:parse_curl_args(prompt_opts)
return {
url = url,
headers = {
headers = Utils.tbl_override({
["Authorization"] = "Bearer " .. Vertex.parse_api_key(),
["Content-Type"] = "application/json; charset=utf-8",
},
}, self.extra_headers),
body = vim.tbl_deep_extend("force", {}, request_body),
}
end

View File

@@ -221,7 +221,6 @@ vim.g.avante_login = vim.g.avante_login
---
---@class AvanteDefaultBaseProvider: table<string, any>
---@field endpoint? string
---@field extra_headers? table<string, any>
---@field extra_request_body? table<string, any>
---@field model? string
---@field local? boolean
@@ -291,6 +290,7 @@ vim.g.avante_login = vim.g.avante_login
---
---@class AvanteProviderFunctor
---@field _model_list_cache table
---@field extra_headers function(table) -> table | table | nil
---@field support_prompt_caching boolean | nil
---@field role_map table<"user" | "assistant", string>
---@field parse_messages AvanteMessagesParser

View File

@@ -1662,4 +1662,10 @@ function M.count_lines(str)
return count
end
function M.tbl_override(value, override)
override = override or {}
if type(override) == "function" then return override(value) or value end
return vim.tbl_extend("force", value, override)
end
return M