feat: allow overriding provider headers (#2161)
This commit is contained in:
@@ -58,7 +58,7 @@ function M:parse_curl_args(prompt_opts)
|
|||||||
),
|
),
|
||||||
proxy = provider_conf.proxy,
|
proxy = provider_conf.proxy,
|
||||||
insecure = provider_conf.allow_insecure,
|
insecure = provider_conf.allow_insecure,
|
||||||
headers = headers,
|
headers = Utils.tbl_override(headers, self.extra_headers),
|
||||||
body = vim.tbl_deep_extend("force", {
|
body = vim.tbl_deep_extend("force", {
|
||||||
messages = self:parse_messages(prompt_opts),
|
messages = self:parse_messages(prompt_opts),
|
||||||
stream = true,
|
stream = true,
|
||||||
|
|||||||
@@ -152,7 +152,7 @@ function M:parse_curl_args(prompt_opts)
|
|||||||
url = endpoint,
|
url = endpoint,
|
||||||
proxy = provider_conf.proxy,
|
proxy = provider_conf.proxy,
|
||||||
insecure = provider_conf.allow_insecure,
|
insecure = provider_conf.allow_insecure,
|
||||||
headers = headers,
|
headers = Utils.tbl_override(headers, self.extra_headers),
|
||||||
body = body_payload,
|
body = body_payload,
|
||||||
rawArgs = rawArgs,
|
rawArgs = rawArgs,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -396,7 +396,7 @@ function M:parse_curl_args(prompt_opts)
|
|||||||
url = Utils.url_join(provider_conf.endpoint, "/v1/messages"),
|
url = Utils.url_join(provider_conf.endpoint, "/v1/messages"),
|
||||||
proxy = provider_conf.proxy,
|
proxy = provider_conf.proxy,
|
||||||
insecure = provider_conf.allow_insecure,
|
insecure = provider_conf.allow_insecure,
|
||||||
headers = headers,
|
headers = Utils.tbl_override(headers, self.extra_headers),
|
||||||
body = vim.tbl_deep_extend("force", {
|
body = vim.tbl_deep_extend("force", {
|
||||||
model = provider_conf.model,
|
model = provider_conf.model,
|
||||||
system = {
|
system = {
|
||||||
|
|||||||
@@ -95,7 +95,7 @@ function M:parse_curl_args(prompt_opts)
|
|||||||
url = Utils.url_join(provider_conf.endpoint, "/chat"),
|
url = Utils.url_join(provider_conf.endpoint, "/chat"),
|
||||||
proxy = provider_conf.proxy,
|
proxy = provider_conf.proxy,
|
||||||
insecure = provider_conf.allow_insecure,
|
insecure = provider_conf.allow_insecure,
|
||||||
headers = headers,
|
headers = Utils.tbl_override(headers, self.extra_headers),
|
||||||
body = vim.tbl_deep_extend("force", {
|
body = vim.tbl_deep_extend("force", {
|
||||||
model = provider_conf.model,
|
model = provider_conf.model,
|
||||||
stream = true,
|
stream = true,
|
||||||
|
|||||||
@@ -222,12 +222,12 @@ function M:models_list()
|
|||||||
H.refresh_token(false, false)
|
H.refresh_token(false, false)
|
||||||
local provider_conf = Providers.parse_config(self)
|
local provider_conf = Providers.parse_config(self)
|
||||||
local curl_opts = {
|
local curl_opts = {
|
||||||
headers = {
|
headers = Utils.tbl_override({
|
||||||
["Content-Type"] = "application/json",
|
["Content-Type"] = "application/json",
|
||||||
["Authorization"] = "Bearer " .. M.state.github_token.token,
|
["Authorization"] = "Bearer " .. M.state.github_token.token,
|
||||||
["Copilot-Integration-Id"] = "vscode-chat",
|
["Copilot-Integration-Id"] = "vscode-chat",
|
||||||
["Editor-Version"] = ("Neovim/%s.%s.%s"):format(vim.version().major, vim.version().minor, vim.version().patch),
|
["Editor-Version"] = ("Neovim/%s.%s.%s"):format(vim.version().major, vim.version().minor, vim.version().patch),
|
||||||
},
|
}, self.extra_headers),
|
||||||
timeout = provider_conf.timeout,
|
timeout = provider_conf.timeout,
|
||||||
proxy = provider_conf.proxy,
|
proxy = provider_conf.proxy,
|
||||||
insecure = provider_conf.allow_insecure,
|
insecure = provider_conf.allow_insecure,
|
||||||
@@ -288,12 +288,11 @@ function M:parse_curl_args(prompt_opts)
|
|||||||
timeout = provider_conf.timeout,
|
timeout = provider_conf.timeout,
|
||||||
proxy = provider_conf.proxy,
|
proxy = provider_conf.proxy,
|
||||||
insecure = provider_conf.allow_insecure,
|
insecure = provider_conf.allow_insecure,
|
||||||
headers = {
|
headers = Utils.tbl_override({
|
||||||
["Content-Type"] = "application/json",
|
|
||||||
["Authorization"] = "Bearer " .. M.state.github_token.token,
|
["Authorization"] = "Bearer " .. M.state.github_token.token,
|
||||||
["Copilot-Integration-Id"] = "vscode-chat",
|
["Copilot-Integration-Id"] = "vscode-chat",
|
||||||
["Editor-Version"] = ("Neovim/%s.%s.%s"):format(vim.version().major, vim.version().minor, vim.version().patch),
|
["Editor-Version"] = ("Neovim/%s.%s.%s"):format(vim.version().major, vim.version().minor, vim.version().patch),
|
||||||
},
|
}, self.extra_headers),
|
||||||
body = vim.tbl_deep_extend("force", {
|
body = vim.tbl_deep_extend("force", {
|
||||||
model = provider_conf.model,
|
model = provider_conf.model,
|
||||||
messages = self:parse_messages(prompt_opts),
|
messages = self:parse_messages(prompt_opts),
|
||||||
|
|||||||
@@ -309,7 +309,7 @@ function M:parse_curl_args(prompt_opts)
|
|||||||
),
|
),
|
||||||
proxy = provider_conf.proxy,
|
proxy = provider_conf.proxy,
|
||||||
insecure = provider_conf.allow_insecure,
|
insecure = provider_conf.allow_insecure,
|
||||||
headers = { ["Content-Type"] = "application/json" },
|
headers = Utils.tbl_override({ ["Content-Type"] = "application/json" }, self.extra_headers),
|
||||||
body = M.prepare_request_body(self, prompt_opts, provider_conf, request_body),
|
body = M.prepare_request_body(self, prompt_opts, provider_conf, request_body),
|
||||||
}
|
}
|
||||||
end
|
end
|
||||||
|
|||||||
@@ -213,7 +213,7 @@ function M:parse_curl_args(prompt_opts)
|
|||||||
|
|
||||||
return {
|
return {
|
||||||
url = Utils.url_join(provider_conf.endpoint, "/api/chat"),
|
url = Utils.url_join(provider_conf.endpoint, "/api/chat"),
|
||||||
headers = headers,
|
headers = Utils.tbl_override(headers, self.extra_headers),
|
||||||
body = vim.tbl_deep_extend("force", {
|
body = vim.tbl_deep_extend("force", {
|
||||||
model = provider_conf.model,
|
model = provider_conf.model,
|
||||||
messages = self:parse_messages(prompt_opts),
|
messages = self:parse_messages(prompt_opts),
|
||||||
|
|||||||
@@ -497,12 +497,6 @@ function M:parse_curl_args(prompt_opts)
|
|||||||
["Content-Type"] = "application/json",
|
["Content-Type"] = "application/json",
|
||||||
}
|
}
|
||||||
|
|
||||||
if provider_conf.extra_headers then
|
|
||||||
for key, value in pairs(provider_conf.extra_headers) do
|
|
||||||
headers[key] = value
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
if Providers.env.require_api_key(provider_conf) then
|
if Providers.env.require_api_key(provider_conf) then
|
||||||
local api_key = self.parse_api_key()
|
local api_key = self.parse_api_key()
|
||||||
if api_key == nil then
|
if api_key == nil then
|
||||||
@@ -536,7 +530,7 @@ function M:parse_curl_args(prompt_opts)
|
|||||||
url = Utils.url_join(provider_conf.endpoint, "/chat/completions"),
|
url = Utils.url_join(provider_conf.endpoint, "/chat/completions"),
|
||||||
proxy = provider_conf.proxy,
|
proxy = provider_conf.proxy,
|
||||||
insecure = provider_conf.allow_insecure,
|
insecure = provider_conf.allow_insecure,
|
||||||
headers = headers,
|
headers = Utils.tbl_override(headers, self.extra_headers),
|
||||||
body = vim.tbl_deep_extend("force", {
|
body = vim.tbl_deep_extend("force", {
|
||||||
model = provider_conf.model,
|
model = provider_conf.model,
|
||||||
messages = self:parse_messages(prompt_opts),
|
messages = self:parse_messages(prompt_opts),
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
local P = require("avante.providers")
|
local P = require("avante.providers")
|
||||||
|
local Utils = require("avante.utils")
|
||||||
local Gemini = require("avante.providers.gemini")
|
local Gemini = require("avante.providers.gemini")
|
||||||
|
|
||||||
---@class AvanteProviderFunctor
|
---@class AvanteProviderFunctor
|
||||||
@@ -50,10 +51,10 @@ function M:parse_curl_args(prompt_opts)
|
|||||||
|
|
||||||
return {
|
return {
|
||||||
url = url,
|
url = url,
|
||||||
headers = {
|
headers = Utils.tbl_override({
|
||||||
["Authorization"] = "Bearer " .. bearer_token,
|
["Authorization"] = "Bearer " .. bearer_token,
|
||||||
["Content-Type"] = "application/json; charset=utf-8",
|
["Content-Type"] = "application/json; charset=utf-8",
|
||||||
},
|
}, self.extra_headers),
|
||||||
proxy = provider_conf.proxy,
|
proxy = provider_conf.proxy,
|
||||||
insecure = provider_conf.allow_insecure,
|
insecure = provider_conf.allow_insecure,
|
||||||
body = Gemini.prepare_request_body(self, prompt_opts, provider_conf, request_body),
|
body = Gemini.prepare_request_body(self, prompt_opts, provider_conf, request_body),
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
local P = require("avante.providers")
|
local P = require("avante.providers")
|
||||||
|
local Utils = require("avante.utils")
|
||||||
local Vertex = require("avante.providers.vertex")
|
local Vertex = require("avante.providers.vertex")
|
||||||
|
|
||||||
---@class AvanteProviderFunctor
|
---@class AvanteProviderFunctor
|
||||||
@@ -64,10 +65,10 @@ function M:parse_curl_args(prompt_opts)
|
|||||||
|
|
||||||
return {
|
return {
|
||||||
url = url,
|
url = url,
|
||||||
headers = {
|
headers = Utils.tbl_override({
|
||||||
["Authorization"] = "Bearer " .. Vertex.parse_api_key(),
|
["Authorization"] = "Bearer " .. Vertex.parse_api_key(),
|
||||||
["Content-Type"] = "application/json; charset=utf-8",
|
["Content-Type"] = "application/json; charset=utf-8",
|
||||||
},
|
}, self.extra_headers),
|
||||||
body = vim.tbl_deep_extend("force", {}, request_body),
|
body = vim.tbl_deep_extend("force", {}, request_body),
|
||||||
}
|
}
|
||||||
end
|
end
|
||||||
|
|||||||
@@ -221,7 +221,6 @@ vim.g.avante_login = vim.g.avante_login
|
|||||||
---
|
---
|
||||||
---@class AvanteDefaultBaseProvider: table<string, any>
|
---@class AvanteDefaultBaseProvider: table<string, any>
|
||||||
---@field endpoint? string
|
---@field endpoint? string
|
||||||
---@field extra_headers? table<string, any>
|
|
||||||
---@field extra_request_body? table<string, any>
|
---@field extra_request_body? table<string, any>
|
||||||
---@field model? string
|
---@field model? string
|
||||||
---@field local? boolean
|
---@field local? boolean
|
||||||
@@ -291,6 +290,7 @@ vim.g.avante_login = vim.g.avante_login
|
|||||||
---
|
---
|
||||||
---@class AvanteProviderFunctor
|
---@class AvanteProviderFunctor
|
||||||
---@field _model_list_cache table
|
---@field _model_list_cache table
|
||||||
|
---@field extra_headers function(table) -> table | table | nil
|
||||||
---@field support_prompt_caching boolean | nil
|
---@field support_prompt_caching boolean | nil
|
||||||
---@field role_map table<"user" | "assistant", string>
|
---@field role_map table<"user" | "assistant", string>
|
||||||
---@field parse_messages AvanteMessagesParser
|
---@field parse_messages AvanteMessagesParser
|
||||||
|
|||||||
@@ -1662,4 +1662,10 @@ function M.count_lines(str)
|
|||||||
return count
|
return count
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function M.tbl_override(value, override)
|
||||||
|
override = override or {}
|
||||||
|
if type(override) == "function" then return override(value) or value end
|
||||||
|
return vim.tbl_extend("force", value, override)
|
||||||
|
end
|
||||||
|
|
||||||
return M
|
return M
|
||||||
|
|||||||
Reference in New Issue
Block a user