fix: bedrock (#1524)
* fix: bedrock * fix: bad variable name * fix: missing metatable
This commit is contained in:
@@ -20,7 +20,7 @@ local M = {}
|
|||||||
---@field custom_tools AvanteLLMToolPublic[]
|
---@field custom_tools AvanteLLMToolPublic[]
|
||||||
M._defaults = {
|
M._defaults = {
|
||||||
debug = false,
|
debug = false,
|
||||||
---@alias ProviderName "claude" | "openai" | "azure" | "gemini" | "vertex" | "cohere" | "copilot" | string
|
---@alias ProviderName "claude" | "openai" | "azure" | "gemini" | "vertex" | "cohere" | "copilot" | "bedrock" | string
|
||||||
provider = "claude",
|
provider = "claude",
|
||||||
-- WARNING: Since auto-suggestions are a high-frequency operation and therefore expensive,
|
-- WARNING: Since auto-suggestions are a high-frequency operation and therefore expensive,
|
||||||
-- currently designating it as `copilot` provider is dangerous because: https://github.com/yetone/avante.nvim/issues/1048
|
-- currently designating it as `copilot` provider is dangerous because: https://github.com/yetone/avante.nvim/issues/1048
|
||||||
@@ -225,7 +225,7 @@ M._defaults = {
|
|||||||
},
|
},
|
||||||
---@type AvanteSupportedProvider
|
---@type AvanteSupportedProvider
|
||||||
bedrock = {
|
bedrock = {
|
||||||
model = "anthropic.claude-3-5-sonnet-20240620-v1:0",
|
model = "anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||||
timeout = 30000, -- Timeout in milliseconds
|
timeout = 30000, -- Timeout in milliseconds
|
||||||
temperature = 0,
|
temperature = 0,
|
||||||
max_tokens = 8000,
|
max_tokens = 8000,
|
||||||
@@ -474,7 +474,7 @@ function M.setup(opts)
|
|||||||
M._options = merged
|
M._options = merged
|
||||||
M.provider_names = vim
|
M.provider_names = vim
|
||||||
.iter(M._defaults)
|
.iter(M._defaults)
|
||||||
:filter(function(_, value) return type(value) == "table" and value.endpoint ~= nil end)
|
:filter(function(_, value) return type(value) == "table" and (value.endpoint ~= nil or value.model ~= nil) end)
|
||||||
:fold({}, function(acc, k)
|
:fold({}, function(acc, k)
|
||||||
acc = vim.list_extend({}, acc)
|
acc = vim.list_extend({}, acc)
|
||||||
acc = vim.list_extend(acc, { k })
|
acc = vim.list_extend(acc, { k })
|
||||||
@@ -519,12 +519,12 @@ function M.get_window_width() return math.ceil(vim.o.columns * (M.windows.width
|
|||||||
|
|
||||||
---@param provider_name ProviderName
|
---@param provider_name ProviderName
|
||||||
---@return boolean
|
---@return boolean
|
||||||
function M.has_provider(provider_name) return M._options[provider_name] ~= nil or M.vendors[provider_name] ~= nil end
|
function M.has_provider(provider_name) return vim.list_contains(M.provider_names, provider_name) end
|
||||||
|
|
||||||
---get supported providers
|
---get supported providers
|
||||||
---@param provider_name ProviderName
|
---@param provider_name ProviderName
|
||||||
---@return AvanteProviderFunctor
|
function M.get_provider_config(provider_name)
|
||||||
function M.get_provider(provider_name)
|
if not M.has_provider(provider_name) then error("No provider found: " .. provider_name, 2) end
|
||||||
if M._options[provider_name] ~= nil then
|
if M._options[provider_name] ~= nil then
|
||||||
return vim.deepcopy(M._options[provider_name], true)
|
return vim.deepcopy(M._options[provider_name], true)
|
||||||
elseif M.vendors and M.vendors[provider_name] ~= nil then
|
elseif M.vendors and M.vendors[provider_name] ~= nil then
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ function M.open()
|
|||||||
|
|
||||||
-- Collect models from main providers and vendors
|
-- Collect models from main providers and vendors
|
||||||
for _, provider_name in ipairs(Config.provider_names) do
|
for _, provider_name in ipairs(Config.provider_names) do
|
||||||
local entry = create_model_entry(provider_name, Config.get_provider(provider_name))
|
local entry = create_model_entry(provider_name, Config.get_provider_config(provider_name))
|
||||||
if entry then table.insert(models, entry) end
|
if entry then table.insert(models, entry) end
|
||||||
end
|
end
|
||||||
|
|
||||||
@@ -43,7 +43,7 @@ function M.open()
|
|||||||
Config.override({
|
Config.override({
|
||||||
[choice.provider_name] = vim.tbl_deep_extend(
|
[choice.provider_name] = vim.tbl_deep_extend(
|
||||||
"force",
|
"force",
|
||||||
Config.get_provider(choice.provider_name),
|
Config.get_provider_config(choice.provider_name),
|
||||||
{ model = choice.model }
|
{ model = choice.model }
|
||||||
),
|
),
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -35,9 +35,9 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
|||||||
return model_handler.parse_response(self, ctx, data_stream, event_state, opts)
|
return model_handler.parse_response(self, ctx, data_stream, event_state, opts)
|
||||||
end
|
end
|
||||||
|
|
||||||
function M:build_bedrock_payload(prompt_opts, body_opts)
|
function M:build_bedrock_payload(prompt_opts, request_body)
|
||||||
local model_handler = M.load_model_handler()
|
local model_handler = M.load_model_handler()
|
||||||
return model_handler.build_bedrock_payload(self, prompt_opts, body_opts)
|
return model_handler.build_bedrock_payload(self, prompt_opts, request_body)
|
||||||
end
|
end
|
||||||
|
|
||||||
function M:parse_stream_data(ctx, data, opts)
|
function M:parse_stream_data(ctx, data, opts)
|
||||||
@@ -66,7 +66,7 @@ end
|
|||||||
---@param prompt_opts AvantePromptOptions
|
---@param prompt_opts AvantePromptOptions
|
||||||
---@return table
|
---@return table
|
||||||
function M:parse_curl_args(prompt_opts)
|
function M:parse_curl_args(prompt_opts)
|
||||||
local base, body_opts = P.parse_config(self)
|
local provider_conf, request_body = P.parse_config(self)
|
||||||
|
|
||||||
local api_key = self.parse_api_key()
|
local api_key = self.parse_api_key()
|
||||||
if api_key == nil then error("Cannot get the bedrock api key!") end
|
if api_key == nil then error("Cannot get the bedrock api key!") end
|
||||||
@@ -79,7 +79,7 @@ function M:parse_curl_args(prompt_opts)
|
|||||||
local endpoint = string.format(
|
local endpoint = string.format(
|
||||||
"https://bedrock-runtime.%s.amazonaws.com/model/%s/invoke-with-response-stream",
|
"https://bedrock-runtime.%s.amazonaws.com/model/%s/invoke-with-response-stream",
|
||||||
aws_region,
|
aws_region,
|
||||||
base.model
|
provider_conf.model
|
||||||
)
|
)
|
||||||
|
|
||||||
local headers = {
|
local headers = {
|
||||||
@@ -88,7 +88,7 @@ function M:parse_curl_args(prompt_opts)
|
|||||||
|
|
||||||
if aws_session_token and aws_session_token ~= "" then headers["x-amz-security-token"] = aws_session_token end
|
if aws_session_token and aws_session_token ~= "" then headers["x-amz-security-token"] = aws_session_token end
|
||||||
|
|
||||||
local body_payload = self:build_bedrock_payload(prompt_opts, body_opts)
|
local body_payload = self:build_bedrock_payload(prompt_opts, request_body)
|
||||||
|
|
||||||
local rawArgs = {
|
local rawArgs = {
|
||||||
"--aws-sigv4",
|
"--aws-sigv4",
|
||||||
@@ -99,8 +99,8 @@ function M:parse_curl_args(prompt_opts)
|
|||||||
|
|
||||||
return {
|
return {
|
||||||
url = endpoint,
|
url = endpoint,
|
||||||
proxy = base.proxy,
|
proxy = provider_conf.proxy,
|
||||||
insecure = base.allow_insecure,
|
insecure = provider_conf.allow_insecure,
|
||||||
headers = headers,
|
headers = headers,
|
||||||
body = body_payload,
|
body = body_payload,
|
||||||
rawArgs = rawArgs,
|
rawArgs = rawArgs,
|
||||||
|
|||||||
@@ -23,19 +23,19 @@ M.parse_response = Claude.parse_response
|
|||||||
|
|
||||||
---@param provider AvanteProviderFunctor
|
---@param provider AvanteProviderFunctor
|
||||||
---@param prompt_opts AvantePromptOptions
|
---@param prompt_opts AvantePromptOptions
|
||||||
---@param body_opts table
|
---@param request_body table
|
||||||
---@return table
|
---@return table
|
||||||
function M.build_bedrock_payload(provider, prompt_opts, body_opts)
|
function M.build_bedrock_payload(provider, prompt_opts, request_body)
|
||||||
local system_prompt = prompt_opts.system_prompt or ""
|
local system_prompt = prompt_opts.system_prompt or ""
|
||||||
local messages = provider:parse_messages(prompt_opts)
|
local messages = provider:parse_messages(prompt_opts)
|
||||||
local max_tokens = body_opts.max_tokens or 2000
|
local max_tokens = request_body.max_tokens or 2000
|
||||||
local payload = {
|
local payload = {
|
||||||
anthropic_version = "bedrock-2023-05-31",
|
anthropic_version = "bedrock-2023-05-31",
|
||||||
max_tokens = max_tokens,
|
max_tokens = max_tokens,
|
||||||
messages = messages,
|
messages = messages,
|
||||||
system = system_prompt,
|
system = system_prompt,
|
||||||
}
|
}
|
||||||
return vim.tbl_deep_extend("force", payload, body_opts or {})
|
return vim.tbl_deep_extend("force", payload, request_body or {})
|
||||||
end
|
end
|
||||||
|
|
||||||
return M
|
return M
|
||||||
|
|||||||
@@ -198,26 +198,25 @@ M = setmetatable(M, {
|
|||||||
---@param t avante.Providers
|
---@param t avante.Providers
|
||||||
---@param k ProviderName
|
---@param k ProviderName
|
||||||
__index = function(t, k)
|
__index = function(t, k)
|
||||||
---@type AvanteProviderFunctor | AvanteBedrockProviderFunctor
|
local provider_config = M.get_config(k)
|
||||||
local Opts = M.get_config(k)
|
|
||||||
|
|
||||||
---@diagnostic disable: undefined-field,no-unknown,inject-field
|
---@diagnostic disable: undefined-field,no-unknown,inject-field
|
||||||
if Config.vendors[k] ~= nil then
|
if Config.vendors[k] ~= nil then
|
||||||
if Opts.parse_response_data ~= nil then
|
if provider_config.parse_response_data ~= nil then
|
||||||
Utils.error("parse_response_data is not supported for avante.nvim vendors")
|
Utils.error("parse_response_data is not supported for avante.nvim vendors")
|
||||||
end
|
end
|
||||||
if Opts.__inherited_from ~= nil then
|
if provider_config.__inherited_from ~= nil then
|
||||||
local BaseOpts = M.get_config(Opts.__inherited_from)
|
local base_provider_config = M.get_config(provider_config.__inherited_from)
|
||||||
local ok, module = pcall(require, "avante.providers." .. Opts.__inherited_from)
|
local ok, module = pcall(require, "avante.providers." .. provider_config.__inherited_from)
|
||||||
if not ok then error("Failed to load provider: " .. Opts.__inherited_from) end
|
if not ok then error("Failed to load provider: " .. provider_config.__inherited_from) end
|
||||||
t[k] = vim.tbl_deep_extend("keep", Opts, BaseOpts, module)
|
t[k] = Utils.deep_extend_with_metatable("keep", provider_config, base_provider_config, module)
|
||||||
else
|
else
|
||||||
t[k] = Opts
|
t[k] = provider_config
|
||||||
end
|
end
|
||||||
else
|
else
|
||||||
local ok, module = pcall(require, "avante.providers." .. k)
|
local ok, module = pcall(require, "avante.providers." .. k)
|
||||||
if not ok then error("Failed to load provider: " .. k) end
|
if not ok then error("Failed to load provider: " .. k) end
|
||||||
t[k] = vim.tbl_deep_extend("keep", Opts, module)
|
t[k] = Utils.deep_extend_with_metatable("keep", provider_config, module)
|
||||||
end
|
end
|
||||||
|
|
||||||
t[k].parse_api_key = function() return E.parse_envvar(t[k]) end
|
t[k].parse_api_key = function() return E.parse_envvar(t[k]) end
|
||||||
@@ -310,10 +309,9 @@ end
|
|||||||
|
|
||||||
---@private
|
---@private
|
||||||
---@param provider_name ProviderName
|
---@param provider_name ProviderName
|
||||||
---@return AvanteProviderFunctor | AvanteBedrockProviderFunctor
|
|
||||||
function M.get_config(provider_name)
|
function M.get_config(provider_name)
|
||||||
provider_name = provider_name or Config.provider
|
provider_name = provider_name or Config.provider
|
||||||
local cur = Config.get_provider(provider_name)
|
local cur = Config.get_provider_config(provider_name)
|
||||||
return type(cur) == "function" and cur() or cur
|
return type(cur) == "function" and cur() or cur
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|||||||
@@ -2177,7 +2177,7 @@ function Sidebar:reset_memory(args, cb)
|
|||||||
table.insert(chat_history, {
|
table.insert(chat_history, {
|
||||||
timestamp = get_timestamp(),
|
timestamp = get_timestamp(),
|
||||||
provider = Config.provider,
|
provider = Config.provider,
|
||||||
model = Config.get_provider(Config.provider).model,
|
model = Config.get_provider_config(Config.provider).model,
|
||||||
request = "",
|
request = "",
|
||||||
response = "",
|
response = "",
|
||||||
original_response = "",
|
original_response = "",
|
||||||
@@ -2414,7 +2414,8 @@ function Sidebar:create_input_container(opts)
|
|||||||
|
|
||||||
---@param request string
|
---@param request string
|
||||||
local function handle_submit(request)
|
local function handle_submit(request)
|
||||||
local model = Config.has_provider(Config.provider) and Config.get_provider(Config.provider).model or "default"
|
local model = Config.has_provider(Config.provider) and Config.get_provider_config(Config.provider).model
|
||||||
|
or "default"
|
||||||
|
|
||||||
local timestamp = get_timestamp()
|
local timestamp = get_timestamp()
|
||||||
|
|
||||||
|
|||||||
@@ -272,7 +272,7 @@ vim.g.avante_login = vim.g.avante_login
|
|||||||
---@field parse_stream_data? AvanteStreamParser
|
---@field parse_stream_data? AvanteStreamParser
|
||||||
---@field on_error? fun(result: table<string, any>): nil
|
---@field on_error? fun(result: table<string, any>): nil
|
||||||
---
|
---
|
||||||
---@alias AvanteBedrockPayloadBuilder fun(self: AvanteBedrockModelHandler | AvanteBedrockProviderFunctor, prompt_opts: AvantePromptOptions, body_opts: table<string, any>): table<string, any>
|
---@alias AvanteBedrockPayloadBuilder fun(self: AvanteBedrockModelHandler | AvanteBedrockProviderFunctor, prompt_opts: AvantePromptOptions, request_body: table<string, any>): table<string, any>
|
||||||
---
|
---
|
||||||
---@class AvanteBedrockProviderFunctor: AvanteProviderFunctor
|
---@class AvanteBedrockProviderFunctor: AvanteProviderFunctor
|
||||||
---@field load_model_handler fun(): AvanteBedrockModelHandler
|
---@field load_model_handler fun(): AvanteBedrockModelHandler
|
||||||
|
|||||||
@@ -1015,4 +1015,17 @@ function M.icon(string_with_icon, utf8_fallback)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function M.deep_extend_with_metatable(behavior, ...)
|
||||||
|
local tables = { ... }
|
||||||
|
local base = tables[1]
|
||||||
|
if behavior == "keep" then base = tables[#tables] end
|
||||||
|
local mt = getmetatable(base)
|
||||||
|
|
||||||
|
local result = vim.tbl_deep_extend(behavior, ...)
|
||||||
|
|
||||||
|
if mt then setmetatable(result, mt) end
|
||||||
|
|
||||||
|
return result
|
||||||
|
end
|
||||||
|
|
||||||
return M
|
return M
|
||||||
|
|||||||
Reference in New Issue
Block a user