feat: allow overriding provider headers (#2161)
This commit is contained in:
@@ -58,7 +58,7 @@ function M:parse_curl_args(prompt_opts)
|
||||
),
|
||||
proxy = provider_conf.proxy,
|
||||
insecure = provider_conf.allow_insecure,
|
||||
headers = headers,
|
||||
headers = Utils.tbl_override(headers, self.extra_headers),
|
||||
body = vim.tbl_deep_extend("force", {
|
||||
messages = self:parse_messages(prompt_opts),
|
||||
stream = true,
|
||||
|
||||
@@ -152,7 +152,7 @@ function M:parse_curl_args(prompt_opts)
|
||||
url = endpoint,
|
||||
proxy = provider_conf.proxy,
|
||||
insecure = provider_conf.allow_insecure,
|
||||
headers = headers,
|
||||
headers = Utils.tbl_override(headers, self.extra_headers),
|
||||
body = body_payload,
|
||||
rawArgs = rawArgs,
|
||||
}
|
||||
|
||||
@@ -396,7 +396,7 @@ function M:parse_curl_args(prompt_opts)
|
||||
url = Utils.url_join(provider_conf.endpoint, "/v1/messages"),
|
||||
proxy = provider_conf.proxy,
|
||||
insecure = provider_conf.allow_insecure,
|
||||
headers = headers,
|
||||
headers = Utils.tbl_override(headers, self.extra_headers),
|
||||
body = vim.tbl_deep_extend("force", {
|
||||
model = provider_conf.model,
|
||||
system = {
|
||||
|
||||
@@ -95,7 +95,7 @@ function M:parse_curl_args(prompt_opts)
|
||||
url = Utils.url_join(provider_conf.endpoint, "/chat"),
|
||||
proxy = provider_conf.proxy,
|
||||
insecure = provider_conf.allow_insecure,
|
||||
headers = headers,
|
||||
headers = Utils.tbl_override(headers, self.extra_headers),
|
||||
body = vim.tbl_deep_extend("force", {
|
||||
model = provider_conf.model,
|
||||
stream = true,
|
||||
|
||||
@@ -222,12 +222,12 @@ function M:models_list()
|
||||
H.refresh_token(false, false)
|
||||
local provider_conf = Providers.parse_config(self)
|
||||
local curl_opts = {
|
||||
headers = {
|
||||
headers = Utils.tbl_override({
|
||||
["Content-Type"] = "application/json",
|
||||
["Authorization"] = "Bearer " .. M.state.github_token.token,
|
||||
["Copilot-Integration-Id"] = "vscode-chat",
|
||||
["Editor-Version"] = ("Neovim/%s.%s.%s"):format(vim.version().major, vim.version().minor, vim.version().patch),
|
||||
},
|
||||
}, self.extra_headers),
|
||||
timeout = provider_conf.timeout,
|
||||
proxy = provider_conf.proxy,
|
||||
insecure = provider_conf.allow_insecure,
|
||||
@@ -288,12 +288,11 @@ function M:parse_curl_args(prompt_opts)
|
||||
timeout = provider_conf.timeout,
|
||||
proxy = provider_conf.proxy,
|
||||
insecure = provider_conf.allow_insecure,
|
||||
headers = {
|
||||
["Content-Type"] = "application/json",
|
||||
headers = Utils.tbl_override({
|
||||
["Authorization"] = "Bearer " .. M.state.github_token.token,
|
||||
["Copilot-Integration-Id"] = "vscode-chat",
|
||||
["Editor-Version"] = ("Neovim/%s.%s.%s"):format(vim.version().major, vim.version().minor, vim.version().patch),
|
||||
},
|
||||
}, self.extra_headers),
|
||||
body = vim.tbl_deep_extend("force", {
|
||||
model = provider_conf.model,
|
||||
messages = self:parse_messages(prompt_opts),
|
||||
|
||||
@@ -309,7 +309,7 @@ function M:parse_curl_args(prompt_opts)
|
||||
),
|
||||
proxy = provider_conf.proxy,
|
||||
insecure = provider_conf.allow_insecure,
|
||||
headers = { ["Content-Type"] = "application/json" },
|
||||
headers = Utils.tbl_override({ ["Content-Type"] = "application/json" }, self.extra_headers),
|
||||
body = M.prepare_request_body(self, prompt_opts, provider_conf, request_body),
|
||||
}
|
||||
end
|
||||
|
||||
@@ -213,7 +213,7 @@ function M:parse_curl_args(prompt_opts)
|
||||
|
||||
return {
|
||||
url = Utils.url_join(provider_conf.endpoint, "/api/chat"),
|
||||
headers = headers,
|
||||
headers = Utils.tbl_override(headers, self.extra_headers),
|
||||
body = vim.tbl_deep_extend("force", {
|
||||
model = provider_conf.model,
|
||||
messages = self:parse_messages(prompt_opts),
|
||||
|
||||
@@ -497,12 +497,6 @@ function M:parse_curl_args(prompt_opts)
|
||||
["Content-Type"] = "application/json",
|
||||
}
|
||||
|
||||
if provider_conf.extra_headers then
|
||||
for key, value in pairs(provider_conf.extra_headers) do
|
||||
headers[key] = value
|
||||
end
|
||||
end
|
||||
|
||||
if Providers.env.require_api_key(provider_conf) then
|
||||
local api_key = self.parse_api_key()
|
||||
if api_key == nil then
|
||||
@@ -536,7 +530,7 @@ function M:parse_curl_args(prompt_opts)
|
||||
url = Utils.url_join(provider_conf.endpoint, "/chat/completions"),
|
||||
proxy = provider_conf.proxy,
|
||||
insecure = provider_conf.allow_insecure,
|
||||
headers = headers,
|
||||
headers = Utils.tbl_override(headers, self.extra_headers),
|
||||
body = vim.tbl_deep_extend("force", {
|
||||
model = provider_conf.model,
|
||||
messages = self:parse_messages(prompt_opts),
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
local P = require("avante.providers")
|
||||
local Utils = require("avante.utils")
|
||||
local Gemini = require("avante.providers.gemini")
|
||||
|
||||
---@class AvanteProviderFunctor
|
||||
@@ -50,10 +51,10 @@ function M:parse_curl_args(prompt_opts)
|
||||
|
||||
return {
|
||||
url = url,
|
||||
headers = {
|
||||
headers = Utils.tbl_override({
|
||||
["Authorization"] = "Bearer " .. bearer_token,
|
||||
["Content-Type"] = "application/json; charset=utf-8",
|
||||
},
|
||||
}, self.extra_headers),
|
||||
proxy = provider_conf.proxy,
|
||||
insecure = provider_conf.allow_insecure,
|
||||
body = Gemini.prepare_request_body(self, prompt_opts, provider_conf, request_body),
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
local P = require("avante.providers")
|
||||
local Utils = require("avante.utils")
|
||||
local Vertex = require("avante.providers.vertex")
|
||||
|
||||
---@class AvanteProviderFunctor
|
||||
@@ -64,10 +65,10 @@ function M:parse_curl_args(prompt_opts)
|
||||
|
||||
return {
|
||||
url = url,
|
||||
headers = {
|
||||
headers = Utils.tbl_override({
|
||||
["Authorization"] = "Bearer " .. Vertex.parse_api_key(),
|
||||
["Content-Type"] = "application/json; charset=utf-8",
|
||||
},
|
||||
}, self.extra_headers),
|
||||
body = vim.tbl_deep_extend("force", {}, request_body),
|
||||
}
|
||||
end
|
||||
|
||||
@@ -221,7 +221,6 @@ vim.g.avante_login = vim.g.avante_login
|
||||
---
|
||||
---@class AvanteDefaultBaseProvider: table<string, any>
|
||||
---@field endpoint? string
|
||||
---@field extra_headers? table<string, any>
|
||||
---@field extra_request_body? table<string, any>
|
||||
---@field model? string
|
||||
---@field local? boolean
|
||||
@@ -291,6 +290,7 @@ vim.g.avante_login = vim.g.avante_login
|
||||
---
|
||||
---@class AvanteProviderFunctor
|
||||
---@field _model_list_cache table
|
||||
---@field extra_headers function(table) -> table | table | nil
|
||||
---@field support_prompt_caching boolean | nil
|
||||
---@field role_map table<"user" | "assistant", string>
|
||||
---@field parse_messages AvanteMessagesParser
|
||||
|
||||
@@ -1662,4 +1662,10 @@ function M.count_lines(str)
|
||||
return count
|
||||
end
|
||||
|
||||
function M.tbl_override(value, override)
|
||||
override = override or {}
|
||||
if type(override) == "function" then return override(value) or value end
|
||||
return vim.tbl_extend("force", value, override)
|
||||
end
|
||||
|
||||
return M
|
||||
|
||||
Reference in New Issue
Block a user