fix: bedrock (#1524)
* fix: bedrock * fix: bad variable name * fix: missing metatable
This commit is contained in:
@@ -35,9 +35,9 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
||||
return model_handler.parse_response(self, ctx, data_stream, event_state, opts)
|
||||
end
|
||||
|
||||
function M:build_bedrock_payload(prompt_opts, body_opts)
|
||||
function M:build_bedrock_payload(prompt_opts, request_body)
|
||||
local model_handler = M.load_model_handler()
|
||||
return model_handler.build_bedrock_payload(self, prompt_opts, body_opts)
|
||||
return model_handler.build_bedrock_payload(self, prompt_opts, request_body)
|
||||
end
|
||||
|
||||
function M:parse_stream_data(ctx, data, opts)
|
||||
@@ -66,7 +66,7 @@ end
|
||||
---@param prompt_opts AvantePromptOptions
|
||||
---@return table
|
||||
function M:parse_curl_args(prompt_opts)
|
||||
local base, body_opts = P.parse_config(self)
|
||||
local provider_conf, request_body = P.parse_config(self)
|
||||
|
||||
local api_key = self.parse_api_key()
|
||||
if api_key == nil then error("Cannot get the bedrock api key!") end
|
||||
@@ -79,7 +79,7 @@ function M:parse_curl_args(prompt_opts)
|
||||
local endpoint = string.format(
|
||||
"https://bedrock-runtime.%s.amazonaws.com/model/%s/invoke-with-response-stream",
|
||||
aws_region,
|
||||
base.model
|
||||
provider_conf.model
|
||||
)
|
||||
|
||||
local headers = {
|
||||
@@ -88,7 +88,7 @@ function M:parse_curl_args(prompt_opts)
|
||||
|
||||
if aws_session_token and aws_session_token ~= "" then headers["x-amz-security-token"] = aws_session_token end
|
||||
|
||||
local body_payload = self:build_bedrock_payload(prompt_opts, body_opts)
|
||||
local body_payload = self:build_bedrock_payload(prompt_opts, request_body)
|
||||
|
||||
local rawArgs = {
|
||||
"--aws-sigv4",
|
||||
@@ -99,8 +99,8 @@ function M:parse_curl_args(prompt_opts)
|
||||
|
||||
return {
|
||||
url = endpoint,
|
||||
proxy = base.proxy,
|
||||
insecure = base.allow_insecure,
|
||||
proxy = provider_conf.proxy,
|
||||
insecure = provider_conf.allow_insecure,
|
||||
headers = headers,
|
||||
body = body_payload,
|
||||
rawArgs = rawArgs,
|
||||
|
||||
@@ -23,19 +23,19 @@ M.parse_response = Claude.parse_response
|
||||
|
||||
---@param provider AvanteProviderFunctor
|
||||
---@param prompt_opts AvantePromptOptions
|
||||
---@param body_opts table
|
||||
---@param request_body table
|
||||
---@return table
|
||||
function M.build_bedrock_payload(provider, prompt_opts, body_opts)
|
||||
function M.build_bedrock_payload(provider, prompt_opts, request_body)
|
||||
local system_prompt = prompt_opts.system_prompt or ""
|
||||
local messages = provider:parse_messages(prompt_opts)
|
||||
local max_tokens = body_opts.max_tokens or 2000
|
||||
local max_tokens = request_body.max_tokens or 2000
|
||||
local payload = {
|
||||
anthropic_version = "bedrock-2023-05-31",
|
||||
max_tokens = max_tokens,
|
||||
messages = messages,
|
||||
system = system_prompt,
|
||||
}
|
||||
return vim.tbl_deep_extend("force", payload, body_opts or {})
|
||||
return vim.tbl_deep_extend("force", payload, request_body or {})
|
||||
end
|
||||
|
||||
return M
|
||||
|
||||
@@ -198,26 +198,25 @@ M = setmetatable(M, {
|
||||
---@param t avante.Providers
|
||||
---@param k ProviderName
|
||||
__index = function(t, k)
|
||||
---@type AvanteProviderFunctor | AvanteBedrockProviderFunctor
|
||||
local Opts = M.get_config(k)
|
||||
local provider_config = M.get_config(k)
|
||||
|
||||
---@diagnostic disable: undefined-field,no-unknown,inject-field
|
||||
if Config.vendors[k] ~= nil then
|
||||
if Opts.parse_response_data ~= nil then
|
||||
if provider_config.parse_response_data ~= nil then
|
||||
Utils.error("parse_response_data is not supported for avante.nvim vendors")
|
||||
end
|
||||
if Opts.__inherited_from ~= nil then
|
||||
local BaseOpts = M.get_config(Opts.__inherited_from)
|
||||
local ok, module = pcall(require, "avante.providers." .. Opts.__inherited_from)
|
||||
if not ok then error("Failed to load provider: " .. Opts.__inherited_from) end
|
||||
t[k] = vim.tbl_deep_extend("keep", Opts, BaseOpts, module)
|
||||
if provider_config.__inherited_from ~= nil then
|
||||
local base_provider_config = M.get_config(provider_config.__inherited_from)
|
||||
local ok, module = pcall(require, "avante.providers." .. provider_config.__inherited_from)
|
||||
if not ok then error("Failed to load provider: " .. provider_config.__inherited_from) end
|
||||
t[k] = Utils.deep_extend_with_metatable("keep", provider_config, base_provider_config, module)
|
||||
else
|
||||
t[k] = Opts
|
||||
t[k] = provider_config
|
||||
end
|
||||
else
|
||||
local ok, module = pcall(require, "avante.providers." .. k)
|
||||
if not ok then error("Failed to load provider: " .. k) end
|
||||
t[k] = vim.tbl_deep_extend("keep", Opts, module)
|
||||
t[k] = Utils.deep_extend_with_metatable("keep", provider_config, module)
|
||||
end
|
||||
|
||||
t[k].parse_api_key = function() return E.parse_envvar(t[k]) end
|
||||
@@ -310,10 +309,9 @@ end
|
||||
|
||||
---@private
|
||||
---@param provider_name ProviderName
|
||||
---@return AvanteProviderFunctor | AvanteBedrockProviderFunctor
|
||||
function M.get_config(provider_name)
|
||||
provider_name = provider_name or Config.provider
|
||||
local cur = Config.get_provider(provider_name)
|
||||
local cur = Config.get_provider_config(provider_name)
|
||||
return type(cur) == "function" and cur() or cur
|
||||
end
|
||||
|
||||
|
||||
Reference in New Issue
Block a user