Resolve AWS credentials using default credentials provider chain for Bedrock (#1752)
This commit is contained in:
17
README.md
17
README.md
@@ -576,12 +576,27 @@ Given its early stage, `avante.nvim` currently supports the following basic func
|
||||
>
|
||||
> For Amazon Bedrock:
|
||||
>
|
||||
> You can specify the `BEDROCK_KEYS` environment variable to set credentials. When this variable is not specified, bedrock will use the default AWS credentials chain (see below).
|
||||
>
|
||||
> ```sh
|
||||
> export BEDROCK_KEYS=aws_access_key_id,aws_secret_access_key,aws_region[,aws_session_token]
|
||||
>
|
||||
> ```
|
||||
>
|
||||
> Note: The aws_session_token is optional and only needed when using temporary AWS credentials
|
||||
>
|
||||
> Alternatively Bedrock tries to resolve AWS credentials using the [Default Credentials Provider Chain](https://docs.aws.amazon.com/cli/v1/userguide/cli-chap-authentication.html).
|
||||
> This means you can have credentials e.g. configured via the AWS CLI, stored in your ~/.aws/profile, use AWS SSO etc.
|
||||
> In this case `aws_region` and optionally `aws_profile` should be specified via the bedrock config, e.g.:
|
||||
>
|
||||
> ```lua
|
||||
> bedrock = {
|
||||
> model = "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
> aws_profile = "bedrock",
|
||||
> aws_region = "us-east-1",
|
||||
>},
|
||||
> ```
|
||||
>
|
||||
> Note: Bedrock requires the [AWS CLI](https://aws.amazon.com/cli/) to be installed on your system.
|
||||
|
||||
1. Open a code file in Neovim.
|
||||
2. Use the `:AvanteAsk` command to query the AI about the code.
|
||||
|
||||
@@ -258,6 +258,8 @@ M._defaults = {
|
||||
timeout = 30000, -- Timeout in milliseconds
|
||||
temperature = 0,
|
||||
max_tokens = 20480,
|
||||
aws_region = "", -- AWS region to use for authentication and bedrock API
|
||||
aws_profile = "", -- AWS profile to use for authentication, if unspecified uses default credentials chain
|
||||
},
|
||||
---@type AvanteSupportedProvider
|
||||
gemini = {
|
||||
@@ -642,6 +644,8 @@ M.BASE_PROVIDER_KEYS = {
|
||||
"api_key_name",
|
||||
"timeout",
|
||||
"display_name",
|
||||
"aws_region",
|
||||
"aws_profile",
|
||||
-- internal
|
||||
"local",
|
||||
"_shellenv",
|
||||
|
||||
@@ -6,6 +6,12 @@ local M = {}
|
||||
|
||||
M.api_key_name = "BEDROCK_KEYS"
|
||||
|
||||
---@class AWSCreds
|
||||
---@field access_key_id string
|
||||
---@field secret_access_key string
|
||||
---@field session_token string
|
||||
local AWSCreds = {}
|
||||
|
||||
M = setmetatable(M, {
|
||||
__index = function(_, k)
|
||||
local model_handler = M.load_model_handler()
|
||||
@@ -13,6 +19,28 @@ M = setmetatable(M, {
|
||||
end,
|
||||
})
|
||||
|
||||
function M.setup()
|
||||
-- Check if AWS CLI is installed
|
||||
if not M.check_aws_cli_installed() then
|
||||
Utils.error(
|
||||
"AWS CLI not found. Please install it to use the Bedrock provider: https://aws.amazon.com/cli/",
|
||||
{ once = true, title = "Avante Bedrock" }
|
||||
)
|
||||
return false
|
||||
end
|
||||
|
||||
-- Check if curl supports AWS signature v4
|
||||
if not M.check_curl_supports_aws_sig() then
|
||||
Utils.error(
|
||||
"Your curl version doesn't support AWS signature v4 properly. Please upgrade to curl 8.10.0 or newer.",
|
||||
{ once = true, title = "Avante Bedrock" }
|
||||
)
|
||||
return false
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
function M.load_model_handler()
|
||||
local provider_conf, _ = P.parse_config(P["bedrock"])
|
||||
local bedrock_model = provider_conf.model
|
||||
@@ -73,17 +101,35 @@ end
|
||||
function M:parse_curl_args(prompt_opts)
|
||||
local provider_conf, request_body = P.parse_config(self)
|
||||
|
||||
local access_key_id, secret_access_key, session_token, region
|
||||
|
||||
-- try to parse credentials from api key
|
||||
local api_key = self.parse_api_key()
|
||||
if api_key == nil then error("Cannot get the bedrock api key!") end
|
||||
local parts = vim.split(api_key, ",")
|
||||
local aws_access_key_id = parts[1]
|
||||
local aws_secret_access_key = parts[2]
|
||||
local aws_region = parts[3]
|
||||
local aws_session_token = parts[4]
|
||||
if api_key ~= nil then
|
||||
local parts = vim.split(api_key, ",")
|
||||
access_key_id = parts[1]
|
||||
secret_access_key = parts[2]
|
||||
region = parts[3]
|
||||
session_token = parts[4]
|
||||
else
|
||||
-- alternatively parse credentials from default AWS credentials provider chain
|
||||
|
||||
---@diagnostic disable-next-line: undefined-field
|
||||
region = provider_conf.aws_region
|
||||
---@diagnostic disable-next-line: undefined-field
|
||||
local profile = provider_conf.aws_profile
|
||||
|
||||
local awsCreds = M:get_aws_credentials(region, profile)
|
||||
if not region or region == "" then error("No aws_region specified in bedrock config") end
|
||||
|
||||
access_key_id = awsCreds.access_key_id
|
||||
secret_access_key = awsCreds.secret_access_key
|
||||
session_token = awsCreds.session_token
|
||||
end
|
||||
|
||||
local endpoint = string.format(
|
||||
"https://bedrock-runtime.%s.amazonaws.com/model/%s/invoke-with-response-stream",
|
||||
aws_region,
|
||||
region,
|
||||
provider_conf.model
|
||||
)
|
||||
|
||||
@@ -91,15 +137,15 @@ function M:parse_curl_args(prompt_opts)
|
||||
["Content-Type"] = "application/json",
|
||||
}
|
||||
|
||||
if aws_session_token and aws_session_token ~= "" then headers["x-amz-security-token"] = aws_session_token end
|
||||
if session_token and session_token ~= "" then headers["x-amz-security-token"] = session_token end
|
||||
|
||||
local body_payload = self:build_bedrock_payload(prompt_opts, request_body)
|
||||
|
||||
local rawArgs = {
|
||||
"--aws-sigv4",
|
||||
string.format("aws:amz:%s:bedrock", aws_region),
|
||||
string.format("aws:amz:%s:bedrock", region),
|
||||
"--user",
|
||||
string.format("%s:%s", aws_access_key_id, aws_secret_access_key),
|
||||
string.format("%s:%s", access_key_id, secret_access_key),
|
||||
}
|
||||
|
||||
return {
|
||||
@@ -127,4 +173,161 @@ function M.on_error(result)
|
||||
Utils.error(error_msg, { once = true, title = "Avante" })
|
||||
end
|
||||
|
||||
--- Run a command and capture its output
|
||||
---@param cmd string The command to run
|
||||
---@param args table The command arguments
|
||||
---@return string output The command output
|
||||
---@return number exit_code The command exit code
|
||||
local function run_command(cmd, args)
|
||||
local stdout = vim.loop.new_pipe(false)
|
||||
local stderr = vim.loop.new_pipe(false)
|
||||
local output = ""
|
||||
local error_output = ""
|
||||
local exit_code = -1
|
||||
|
||||
local handle
|
||||
handle = vim.loop.spawn(cmd, {
|
||||
args = args,
|
||||
stdio = { nil, stdout, stderr },
|
||||
}, function(code)
|
||||
-- Safely close all handles
|
||||
if stdout then
|
||||
stdout:read_stop()
|
||||
stdout:close()
|
||||
end
|
||||
if stderr then
|
||||
stderr:read_stop()
|
||||
stderr:close()
|
||||
end
|
||||
if handle then handle:close() end
|
||||
exit_code = code
|
||||
end)
|
||||
|
||||
if not handle then
|
||||
-- Clean up if spawn failed
|
||||
if stdout then stdout:close() end
|
||||
if stderr then stderr:close() end
|
||||
return "", -1
|
||||
end
|
||||
|
||||
if stdout then
|
||||
stdout:read_start(function(err, data)
|
||||
if err then
|
||||
Utils.error("Error reading stdout: " .. err)
|
||||
return
|
||||
end
|
||||
if data then output = output .. data end
|
||||
end)
|
||||
end
|
||||
|
||||
if stderr then
|
||||
stderr:read_start(function(err, data)
|
||||
if err then
|
||||
Utils.error("Error reading stderr: " .. err)
|
||||
return
|
||||
end
|
||||
if data then error_output = error_output .. data end
|
||||
end)
|
||||
end
|
||||
|
||||
-- Wait for the command to complete
|
||||
vim.wait(10000, function() return exit_code ~= -1 end)
|
||||
|
||||
-- If we timed out, clean up
|
||||
if exit_code == -1 then
|
||||
if stdout then
|
||||
stdout:read_stop()
|
||||
stdout:close()
|
||||
end
|
||||
if stderr then
|
||||
stderr:read_stop()
|
||||
stderr:close()
|
||||
end
|
||||
if handle then handle:close() end
|
||||
end
|
||||
|
||||
return output, exit_code
|
||||
end
|
||||
|
||||
--- get_aws_credentials returns aws credentials using the aws cli
|
||||
---@param region string
|
||||
---@param profile string
|
||||
---@return AWSCreds
|
||||
function M:get_aws_credentials(region, profile)
|
||||
local awsCreds = {
|
||||
access_key_id = "",
|
||||
secret_access_key = "",
|
||||
session_token = "",
|
||||
}
|
||||
|
||||
local args = { "configure", "export-credentials" }
|
||||
|
||||
if profile and profile ~= "" then
|
||||
table.insert(args, "--profile")
|
||||
table.insert(args, profile)
|
||||
end
|
||||
|
||||
if region and region ~= "" then
|
||||
table.insert(args, "--region")
|
||||
table.insert(args, region)
|
||||
end
|
||||
|
||||
-- run aws configure export-credentials and capture the json output
|
||||
local start_time = vim.loop.hrtime()
|
||||
local output, exit_code = run_command("aws", args)
|
||||
|
||||
if exit_code == 0 then
|
||||
local credentials = vim.json.decode(output)
|
||||
awsCreds.access_key_id = credentials.AccessKeyId
|
||||
awsCreds.secret_access_key = credentials.SecretAccessKey
|
||||
awsCreds.session_token = credentials.SessionToken
|
||||
else
|
||||
print("Failed to run AWS command")
|
||||
end
|
||||
|
||||
local end_time = vim.loop.hrtime()
|
||||
local duration_ms = (end_time - start_time) / 1000000
|
||||
Utils.debug(string.format("AWS credentials fetch took %.2f ms", duration_ms))
|
||||
|
||||
return awsCreds
|
||||
end
|
||||
|
||||
--- check_aws_cli_installed returns true when the aws cli is installed
|
||||
--- @return boolean
|
||||
function M.check_aws_cli_installed()
|
||||
local _, exit_code = run_command("aws", { "--version" })
|
||||
return exit_code == 0
|
||||
end
|
||||
|
||||
--- check_curl_version_supports_aws_sig checks if the given curl version supports aws sigv4 correctly
|
||||
--- we require at least version 8.10.0 because it contains critical fixes for aws sigv4 support
|
||||
--- https://curl.se/ch/8.10.0.html
|
||||
--- @param version_string string The curl version string to check
|
||||
--- @return boolean
|
||||
function M.check_curl_version_supports_aws_sig(version_string)
|
||||
-- Extract the version number
|
||||
local major, minor = version_string:match("curl (%d+)%.(%d+)")
|
||||
|
||||
if major and minor then
|
||||
major = tonumber(major)
|
||||
minor = tonumber(minor)
|
||||
|
||||
-- Check if the version is at least 8.10
|
||||
if major > 8 or (major == 8 and minor >= 10) then return true end
|
||||
end
|
||||
|
||||
return false
|
||||
end
|
||||
|
||||
--- check_curl_supports_aws_sig returns true when the installed curl version supports aws sigv4
|
||||
--- @return boolean
|
||||
function M.check_curl_supports_aws_sig()
|
||||
local output, exit_code = run_command("curl", { "--version" })
|
||||
if exit_code ~= 0 then return false end
|
||||
|
||||
-- Get first line of output which contains version info
|
||||
local version_string = output:match("^[^\n]+")
|
||||
return M.check_curl_version_supports_aws_sig(version_string)
|
||||
end
|
||||
|
||||
return M
|
||||
|
||||
43
tests/providers/bedrock_spec.lua
Normal file
43
tests/providers/bedrock_spec.lua
Normal file
@@ -0,0 +1,43 @@
|
||||
local bedrock_provider = require("avante.providers.bedrock")
|
||||
|
||||
describe("bedrock_provider", function()
|
||||
describe("check_curl_version_supports_aws_sig", function()
|
||||
it(
|
||||
"should return true for curl version 8.10.0",
|
||||
function()
|
||||
assert.is_true(
|
||||
bedrock_provider.check_curl_version_supports_aws_sig(
|
||||
"curl 8.10.0 (x86_64-pc-linux-gnu) libcurl/7.68.0 OpenSSL/1.1.1f zlib/1.2.11 brotli/1.0.7 libidn2/2.2.0 libpsl/0.21.0 (+libidn2/2.2.0) libssh2/1.8.0 nghttp2/1.40.0 librtmp/2.3"
|
||||
)
|
||||
)
|
||||
end
|
||||
)
|
||||
|
||||
it(
|
||||
"should return true for curl version higher than 8.10.0",
|
||||
function()
|
||||
assert.is_true(
|
||||
bedrock_provider.check_curl_version_supports_aws_sig(
|
||||
"curl 8.11.0 (aarch64-apple-darwin23.6.0) libcurl/8.11.0 OpenSSL/3.4.0 (SecureTransport) zlib/1.2.12 brotli/1.1.0 zstd/1.5.6 AppleIDN libssh2/1.11.1 nghttp2/1.64.0 librtmp/2.3"
|
||||
)
|
||||
)
|
||||
end
|
||||
)
|
||||
|
||||
it(
|
||||
"should return false for curl version lower than 8.10.0",
|
||||
function()
|
||||
assert.is_false(
|
||||
bedrock_provider.check_curl_version_supports_aws_sig(
|
||||
"curl 7.68.0 (x86_64-pc-linux-gnu) libcurl/7.68.0 OpenSSL/1.1.1f zlib/1.2.11 brotli/1.0.7 libidn2/2.2.0 libpsl/0.21.0 (+libidn2/2.2.0) libssh2/1.8.0 nghttp2/1.40.0 librtmp/2.3"
|
||||
)
|
||||
)
|
||||
end
|
||||
)
|
||||
|
||||
it(
|
||||
"should return false for invalid version string",
|
||||
function() assert.is_false(bedrock_provider.check_curl_version_supports_aws_sig("Invalid version string")) end
|
||||
)
|
||||
end)
|
||||
end)
|
||||
Reference in New Issue
Block a user