From 2f806ca34223a50435018dd21d443795f619828d Mon Sep 17 00:00:00 2001 From: msvechla Date: Fri, 16 May 2025 05:13:40 +0200 Subject: [PATCH] Resolve AWS credentials using default credentials provider chain for Bedrock (#1752) --- README.md | 17 ++- lua/avante/config.lua | 4 + lua/avante/providers/bedrock.lua | 223 +++++++++++++++++++++++++++++-- tests/providers/bedrock_spec.lua | 43 ++++++ 4 files changed, 276 insertions(+), 11 deletions(-) create mode 100644 tests/providers/bedrock_spec.lua diff --git a/README.md b/README.md index 78c676b..9e0dd6d 100644 --- a/README.md +++ b/README.md @@ -576,12 +576,27 @@ Given its early stage, `avante.nvim` currently supports the following basic func > > For Amazon Bedrock: > +> You can specify the `BEDROCK_KEYS` environment variable to set credentials. When this variable is not specified, bedrock will use the default AWS credentials chain (see below). +> > ```sh > export BEDROCK_KEYS=aws_access_key_id,aws_secret_access_key,aws_region[,aws_session_token] -> > ``` > > Note: The aws_session_token is optional and only needed when using temporary AWS credentials +> +> Alternatively Bedrock tries to resolve AWS credentials using the [Default Credentials Provider Chain](https://docs.aws.amazon.com/cli/v1/userguide/cli-chap-authentication.html). +> This means you can have credentials e.g. configured via the AWS CLI, stored in your ~/.aws/profile, use AWS SSO etc. +> In this case `aws_region` and optionally `aws_profile` should be specified via the bedrock config, e.g.: +> +> ```lua +> bedrock = { +> model = "us.anthropic.claude-3-5-sonnet-20241022-v2:0", +> aws_profile = "bedrock", +> aws_region = "us-east-1", +>}, +> ``` +> +> Note: Bedrock requires the [AWS CLI](https://aws.amazon.com/cli/) to be installed on your system. 1. Open a code file in Neovim. 2. Use the `:AvanteAsk` command to query the AI about the code. diff --git a/lua/avante/config.lua b/lua/avante/config.lua index 336e90f..87eeaca 100644 --- a/lua/avante/config.lua +++ b/lua/avante/config.lua @@ -258,6 +258,8 @@ M._defaults = { timeout = 30000, -- Timeout in milliseconds temperature = 0, max_tokens = 20480, + aws_region = "", -- AWS region to use for authentication and bedrock API + aws_profile = "", -- AWS profile to use for authentication, if unspecified uses default credentials chain }, ---@type AvanteSupportedProvider gemini = { @@ -642,6 +644,8 @@ M.BASE_PROVIDER_KEYS = { "api_key_name", "timeout", "display_name", + "aws_region", + "aws_profile", -- internal "local", "_shellenv", diff --git a/lua/avante/providers/bedrock.lua b/lua/avante/providers/bedrock.lua index f820fc2..f7a04fe 100644 --- a/lua/avante/providers/bedrock.lua +++ b/lua/avante/providers/bedrock.lua @@ -6,6 +6,12 @@ local M = {} M.api_key_name = "BEDROCK_KEYS" +---@class AWSCreds +---@field access_key_id string +---@field secret_access_key string +---@field session_token string +local AWSCreds = {} + M = setmetatable(M, { __index = function(_, k) local model_handler = M.load_model_handler() @@ -13,6 +19,28 @@ M = setmetatable(M, { end, }) +function M.setup() + -- Check if AWS CLI is installed + if not M.check_aws_cli_installed() then + Utils.error( + "AWS CLI not found. Please install it to use the Bedrock provider: https://aws.amazon.com/cli/", + { once = true, title = "Avante Bedrock" } + ) + return false + end + + -- Check if curl supports AWS signature v4 + if not M.check_curl_supports_aws_sig() then + Utils.error( + "Your curl version doesn't support AWS signature v4 properly. Please upgrade to curl 8.10.0 or newer.", + { once = true, title = "Avante Bedrock" } + ) + return false + end + + return true +end + function M.load_model_handler() local provider_conf, _ = P.parse_config(P["bedrock"]) local bedrock_model = provider_conf.model @@ -73,17 +101,35 @@ end function M:parse_curl_args(prompt_opts) local provider_conf, request_body = P.parse_config(self) + local access_key_id, secret_access_key, session_token, region + + -- try to parse credentials from api key local api_key = self.parse_api_key() - if api_key == nil then error("Cannot get the bedrock api key!") end - local parts = vim.split(api_key, ",") - local aws_access_key_id = parts[1] - local aws_secret_access_key = parts[2] - local aws_region = parts[3] - local aws_session_token = parts[4] + if api_key ~= nil then + local parts = vim.split(api_key, ",") + access_key_id = parts[1] + secret_access_key = parts[2] + region = parts[3] + session_token = parts[4] + else + -- alternatively parse credentials from default AWS credentials provider chain + + ---@diagnostic disable-next-line: undefined-field + region = provider_conf.aws_region + ---@diagnostic disable-next-line: undefined-field + local profile = provider_conf.aws_profile + + local awsCreds = M:get_aws_credentials(region, profile) + if not region or region == "" then error("No aws_region specified in bedrock config") end + + access_key_id = awsCreds.access_key_id + secret_access_key = awsCreds.secret_access_key + session_token = awsCreds.session_token + end local endpoint = string.format( "https://bedrock-runtime.%s.amazonaws.com/model/%s/invoke-with-response-stream", - aws_region, + region, provider_conf.model ) @@ -91,15 +137,15 @@ function M:parse_curl_args(prompt_opts) ["Content-Type"] = "application/json", } - if aws_session_token and aws_session_token ~= "" then headers["x-amz-security-token"] = aws_session_token end + if session_token and session_token ~= "" then headers["x-amz-security-token"] = session_token end local body_payload = self:build_bedrock_payload(prompt_opts, request_body) local rawArgs = { "--aws-sigv4", - string.format("aws:amz:%s:bedrock", aws_region), + string.format("aws:amz:%s:bedrock", region), "--user", - string.format("%s:%s", aws_access_key_id, aws_secret_access_key), + string.format("%s:%s", access_key_id, secret_access_key), } return { @@ -127,4 +173,161 @@ function M.on_error(result) Utils.error(error_msg, { once = true, title = "Avante" }) end +--- Run a command and capture its output +---@param cmd string The command to run +---@param args table The command arguments +---@return string output The command output +---@return number exit_code The command exit code +local function run_command(cmd, args) + local stdout = vim.loop.new_pipe(false) + local stderr = vim.loop.new_pipe(false) + local output = "" + local error_output = "" + local exit_code = -1 + + local handle + handle = vim.loop.spawn(cmd, { + args = args, + stdio = { nil, stdout, stderr }, + }, function(code) + -- Safely close all handles + if stdout then + stdout:read_stop() + stdout:close() + end + if stderr then + stderr:read_stop() + stderr:close() + end + if handle then handle:close() end + exit_code = code + end) + + if not handle then + -- Clean up if spawn failed + if stdout then stdout:close() end + if stderr then stderr:close() end + return "", -1 + end + + if stdout then + stdout:read_start(function(err, data) + if err then + Utils.error("Error reading stdout: " .. err) + return + end + if data then output = output .. data end + end) + end + + if stderr then + stderr:read_start(function(err, data) + if err then + Utils.error("Error reading stderr: " .. err) + return + end + if data then error_output = error_output .. data end + end) + end + + -- Wait for the command to complete + vim.wait(10000, function() return exit_code ~= -1 end) + + -- If we timed out, clean up + if exit_code == -1 then + if stdout then + stdout:read_stop() + stdout:close() + end + if stderr then + stderr:read_stop() + stderr:close() + end + if handle then handle:close() end + end + + return output, exit_code +end + +--- get_aws_credentials returns aws credentials using the aws cli +---@param region string +---@param profile string +---@return AWSCreds +function M:get_aws_credentials(region, profile) + local awsCreds = { + access_key_id = "", + secret_access_key = "", + session_token = "", + } + + local args = { "configure", "export-credentials" } + + if profile and profile ~= "" then + table.insert(args, "--profile") + table.insert(args, profile) + end + + if region and region ~= "" then + table.insert(args, "--region") + table.insert(args, region) + end + + -- run aws configure export-credentials and capture the json output + local start_time = vim.loop.hrtime() + local output, exit_code = run_command("aws", args) + + if exit_code == 0 then + local credentials = vim.json.decode(output) + awsCreds.access_key_id = credentials.AccessKeyId + awsCreds.secret_access_key = credentials.SecretAccessKey + awsCreds.session_token = credentials.SessionToken + else + print("Failed to run AWS command") + end + + local end_time = vim.loop.hrtime() + local duration_ms = (end_time - start_time) / 1000000 + Utils.debug(string.format("AWS credentials fetch took %.2f ms", duration_ms)) + + return awsCreds +end + +--- check_aws_cli_installed returns true when the aws cli is installed +--- @return boolean +function M.check_aws_cli_installed() + local _, exit_code = run_command("aws", { "--version" }) + return exit_code == 0 +end + +--- check_curl_version_supports_aws_sig checks if the given curl version supports aws sigv4 correctly +--- we require at least version 8.10.0 because it contains critical fixes for aws sigv4 support +--- https://curl.se/ch/8.10.0.html +--- @param version_string string The curl version string to check +--- @return boolean +function M.check_curl_version_supports_aws_sig(version_string) + -- Extract the version number + local major, minor = version_string:match("curl (%d+)%.(%d+)") + + if major and minor then + major = tonumber(major) + minor = tonumber(minor) + + -- Check if the version is at least 8.10 + if major > 8 or (major == 8 and minor >= 10) then return true end + end + + return false +end + +--- check_curl_supports_aws_sig returns true when the installed curl version supports aws sigv4 +--- @return boolean +function M.check_curl_supports_aws_sig() + local output, exit_code = run_command("curl", { "--version" }) + if exit_code ~= 0 then return false end + + -- Get first line of output which contains version info + local version_string = output:match("^[^\n]+") + return M.check_curl_version_supports_aws_sig(version_string) +end + return M diff --git a/tests/providers/bedrock_spec.lua b/tests/providers/bedrock_spec.lua new file mode 100644 index 0000000..9128cb1 --- /dev/null +++ b/tests/providers/bedrock_spec.lua @@ -0,0 +1,43 @@ +local bedrock_provider = require("avante.providers.bedrock") + +describe("bedrock_provider", function() + describe("check_curl_version_supports_aws_sig", function() + it( + "should return true for curl version 8.10.0", + function() + assert.is_true( + bedrock_provider.check_curl_version_supports_aws_sig( + "curl 8.10.0 (x86_64-pc-linux-gnu) libcurl/7.68.0 OpenSSL/1.1.1f zlib/1.2.11 brotli/1.0.7 libidn2/2.2.0 libpsl/0.21.0 (+libidn2/2.2.0) libssh2/1.8.0 nghttp2/1.40.0 librtmp/2.3" + ) + ) + end + ) + + it( + "should return true for curl version higher than 8.10.0", + function() + assert.is_true( + bedrock_provider.check_curl_version_supports_aws_sig( + "curl 8.11.0 (aarch64-apple-darwin23.6.0) libcurl/8.11.0 OpenSSL/3.4.0 (SecureTransport) zlib/1.2.12 brotli/1.1.0 zstd/1.5.6 AppleIDN libssh2/1.11.1 nghttp2/1.64.0 librtmp/2.3" + ) + ) + end + ) + + it( + "should return false for curl version lower than 8.10.0", + function() + assert.is_false( + bedrock_provider.check_curl_version_supports_aws_sig( + "curl 7.68.0 (x86_64-pc-linux-gnu) libcurl/7.68.0 OpenSSL/1.1.1f zlib/1.2.11 brotli/1.0.7 libidn2/2.2.0 libpsl/0.21.0 (+libidn2/2.2.0) libssh2/1.8.0 nghttp2/1.40.0 librtmp/2.3" + ) + ) + end + ) + + it( + "should return false for invalid version string", + function() assert.is_false(bedrock_provider.check_curl_version_supports_aws_sig("Invalid version string")) end + ) + end) +end)