local Utils = require("avante.utils")
local P = require("avante.providers")

---@class AvanteBedrockProviderFunctor
local M = {}

M.api_key_name = "BEDROCK_KEYS"

---@class AWSCreds
---@field access_key_id string
---@field secret_access_key string
---@field session_token string
local AWSCreds = {}

-- Fields that are not defined on M itself are looked up on the model handler
-- module returned by M.load_model_handler().
M = setmetatable(M, {
  __index = function(_, k)
    local model_handler = M.load_model_handler()
    return model_handler[k]
  end,
})

--- Returns true when `value` is nil, not a string, or the empty string.
---@param value any
---@return boolean
local function is_blank(value) return type(value) ~= "string" or value == "" end

--- Returns `value` when it is a string, otherwise the empty string.
--- Used to normalise decoded JSON fields, which may be absent or non-string.
---@param value any
---@return string
local function string_or_empty(value)
  if type(value) == "string" then return value end
  return ""
end

--- Verify the environment can talk to Bedrock; reports an error when curl is too old.
---@return boolean ok false when the installed curl lacks usable AWS signature v4 support
function M.setup()
  -- Check if curl supports AWS signature v4
  if not M.check_curl_supports_aws_sig() then
    Utils.error(
      "Your curl version doesn't support AWS signature v4 properly. Please upgrade to curl 8.10.0 or newer.",
      { once = true, title = "Avante Bedrock" }
    )
    return false
  end
  return true
end

--- Load the handler module for the configured Bedrock model.
--- Any model name containing "anthropic" is mapped to the "claude" handler.
--- Raises an error when no matching handler module can be required.
function M.load_model_handler()
  local provider_conf, _ = P.parse_config(P["bedrock"])
  local bedrock_model = provider_conf.model
  if provider_conf.model:match("anthropic") then bedrock_model = "claude" end

  local ok, model_module = pcall(require, "avante.providers.bedrock." .. bedrock_model)
  if ok then return model_module end
  local error_msg = "Bedrock model handler not found: " .. bedrock_model
  error(error_msg)
end

--- Report whether credentials for Bedrock are available.
--- With `aws_profile` configured this only requires the AWS CLI to be runnable;
--- otherwise the environment value named by `M.api_key_name` must be set.
---@return boolean
function M.is_env_set()
  local provider_conf, _ = P.parse_config(P["bedrock"])
  ---@diagnostic disable-next-line: undefined-field
  local profile = provider_conf.aws_profile
  if profile ~= nil and profile ~= "" then
    -- AWS CLI is only needed if aws_profile is used for authorization
    if not M.check_aws_cli_installed() then
      Utils.error(
        "AWS CLI not found. Please install it to use the Bedrock provider: https://aws.amazon.com/cli/",
        { once = true, title = "Avante Bedrock" }
      )
      return false
    end
    return true
  end
  local value = Utils.environment.parse(M.api_key_name)
  return value ~= nil
end

--- Delegate message parsing to the active model handler.
function M:parse_messages(prompt_opts)
  local model_handler = M.load_model_handler()
  return model_handler.parse_messages(self, prompt_opts)
end

--- Delegate response parsing to the active model handler.
function M:parse_response(ctx, data_stream, event_state, opts)
  local model_handler = M.load_model_handler()
  return model_handler.parse_response(self, ctx, data_stream, event_state, opts)
end

--- Delegate tool transformation to the active model handler.
function M:transform_tool(tool)
  local model_handler = M.load_model_handler()
  return model_handler.transform_tool(self, tool)
end

--- Delegate request payload construction to the active model handler.
function M:build_bedrock_payload(prompt_opts, request_body)
  local model_handler = M.load_model_handler()
  return model_handler.build_bedrock_payload(self, prompt_opts, request_body)
end

--- Collect the messages of every `exception{...}` object found in `data`.
--- A payload that is not valid JSON, or has no string `message`, is reported
--- verbatim so the failure is still surfaced instead of raising here.
---@param data string
---@return string[] exceptions one "- <message>" entry per exception found
local function parse_exception(data)
  local exceptions_found = {}
  for payload in data:gmatch("exception(%b{})") do
    local ok, decoded = pcall(vim.json.decode, payload)
    local message = payload
    if ok and type(decoded) == "table" and type(decoded.message) == "string" then message = decoded.message end
    table.insert(exceptions_found, "- " .. message)
  end
  return exceptions_found
end

--- Process one chunk of the streamed Bedrock response.
function M:parse_stream_data(ctx, data, opts)
  -- @NOTE: Decode and process Bedrock response
  -- Each response contains a Base64-encoded `bytes` field, which is decoded into JSON.
  -- The `type` field in the decoded JSON determines how the response is handled.
  local bedrock_match = data:gmatch("event(%b{})")
  for bedrock_data_match in bedrock_match do
    local jsn = vim.json.decode(bedrock_data_match)
    local data_stream = vim.base64.decode(jsn.bytes)
    local json = vim.json.decode(data_stream)
    self:parse_response(ctx, data_stream, json.type, opts)
  end
  local exceptions = parse_exception(data)
  if #exceptions > 0 then
    Utils.debug("Bedrock exceptions: ", vim.fn.json_encode(exceptions))
    if opts.on_chunk then
      opts.on_chunk("\n**Exception caught**\n\n")
      opts.on_chunk(table.concat(exceptions, "\n"))
    end
    vim.schedule(function() opts.on_stop({ reason = "error" }) end)
  end
end

--- Handle a non-streamed response: forward any exceptions to `opts.on_chunk`
--- and then schedule `opts.on_stop`. Does nothing when `opts.on_chunk` is nil.
--- `event_state` is accepted but not used here.
function M:parse_response_without_stream(data, event_state, opts)
  if opts.on_chunk == nil then return end
  local exceptions = parse_exception(data)
  if #exceptions > 0 then
    opts.on_chunk("\n**Exception caught**\n\n")
    opts.on_chunk(table.concat(exceptions, "\n"))
  end
  vim.schedule(function() opts.on_stop({ reason = "complete" }) end)
end

--- Build the curl arguments for a Bedrock request.
--- Credentials come from the AWS CLI when `aws_profile` is configured, otherwise
--- from the API key, which is split on "," into
--- access_key_id, secret_access_key, region and an optional session_token.
---@param prompt_opts AvantePromptOptions
---@return AvanteCurlOutput|nil nil when region or credentials are unavailable
function M:parse_curl_args(prompt_opts)
  local provider_conf, request_body = P.parse_config(self)

  local access_key_id, secret_access_key, session_token, region

  ---@diagnostic disable-next-line: undefined-field
  local profile = provider_conf.aws_profile
  if profile ~= nil and profile ~= "" then
    ---@diagnostic disable-next-line: undefined-field
    region = provider_conf.aws_region
    if not region or region == "" then
      Utils.error("Bedrock: no aws_region specified in bedrock config")
      return nil
    end
    local awsCreds = M:get_aws_credentials(region, profile)
    access_key_id = awsCreds.access_key_id
    secret_access_key = awsCreds.secret_access_key
    session_token = awsCreds.session_token
    -- get_aws_credentials has already reported the failure; do not send a
    -- request signed with empty credentials.
    if is_blank(access_key_id) or is_blank(secret_access_key) then return nil end
  else
    -- try to parse credentials from api key
    local api_key = self.parse_api_key()
    if api_key == nil then
      Utils.error("Bedrock: API key not set correctly")
      return nil
    end
    local parts = vim.split(api_key, ",")
    access_key_id = parts[1]
    secret_access_key = parts[2]
    region = parts[3]
    session_token = parts[4]
    -- The first three parts are mandatory: they feed the signature and the endpoint.
    if is_blank(access_key_id) or is_blank(secret_access_key) or is_blank(region) then
      Utils.error("Bedrock: API key not set correctly")
      return nil
    end
  end

  local endpoint
  if provider_conf.endpoint then
    -- Use custom endpoint if provided
    endpoint = provider_conf.endpoint
  else
    -- Default to AWS Bedrock endpoint
    endpoint = string.format(
      "https://bedrock-runtime.%s.amazonaws.com/model/%s/invoke-with-response-stream",
      region,
      provider_conf.model
    )
  end

  local headers = {
    ["Content-Type"] = "application/json",
  }

  if session_token and session_token ~= "" then headers["x-amz-security-token"] = session_token end

  local body_payload = self:build_bedrock_payload(prompt_opts, request_body)

  local rawArgs = {
    "--aws-sigv4",
    string.format("aws:amz:%s:bedrock", region),
    "--user",
    string.format("%s:%s", access_key_id, secret_access_key),
  }

  return {
    url = endpoint,
    proxy = provider_conf.proxy,
    insecure = provider_conf.allow_insecure,
    headers = Utils.tbl_override(headers, self.extra_headers),
    body = body_payload,
    rawArgs = rawArgs,
  }
end

--- Report a failed API request to the user.
---@param result table response with optional `body` (JSON text) and `status`
function M.on_error(result)
  if not result.body then
    -- tostring() keeps the report working even when no status was provided
    return Utils.error(
      "API request failed with status " .. tostring(result.status),
      { once = true, title = "Avante" }
    )
  end

  local ok, body = pcall(vim.json.decode, result.body)
  if not (ok and body and body.error) then
    return Utils.error("Failed to parse error response", { once = true, title = "Avante" })
  end

  local error_msg = body.error.message

  Utils.error(error_msg, { once = true, title = "Avante" })
end

--- Run a command and capture its output. Time out after 10 seconds
---@param ... string Command and its arguments
---@return string stdout
---@return integer exit code (0 for success, 124 for timeout, 127 when the command could not be started)
local function run_command(...)
  local args = { ... }
  -- vim.system raises when the executable cannot be spawned (e.g. it is not
  -- installed); report that as a non-zero exit code so callers can test for it.
  local ok, process = pcall(vim.system, args, { text = true })
  if not ok then return "", 127 end
  local result = process:wait(10000) -- Wait up to 10 seconds
  -- result.code will be 124 if the command times out.
  return result.stdout or "", result.code
end

--- get_aws_credentials returns aws credentials using the aws cli
--- On any failure an error is reported and every field is the empty string.
---@param region string
---@param profile string
---@return AWSCreds
function M:get_aws_credentials(region, profile)
  local awsCreds = {
    access_key_id = "",
    secret_access_key = "",
    session_token = "",
  }

  local args = { "aws", "configure", "export-credentials" }

  if profile and profile ~= "" then
    table.insert(args, "--profile")
    table.insert(args, profile)
  end

  if region and region ~= "" then
    table.insert(args, "--region")
    table.insert(args, region)
  end

  -- run aws configure export-credentials and capture the json output
  local start_time = vim.uv.hrtime()
  local output, exit_code = run_command(unpack(args))

  if exit_code == 0 then
    local ok, credentials = pcall(vim.json.decode, output)
    if ok and type(credentials) == "table" then
      -- Only string values are accepted: absent or null fields become "".
      awsCreds.access_key_id = string_or_empty(credentials.AccessKeyId)
      awsCreds.secret_access_key = string_or_empty(credentials.SecretAccessKey)
      awsCreds.session_token = string_or_empty(credentials.SessionToken)
    else
      Utils.error("Bedrock: could not parse AWS CLI credentials output", { title = "Avante Bedrock" })
    end
  else
    Utils.error(
      "Bedrock: failed to run AWS command (exit code " .. tostring(exit_code) .. ")",
      { title = "Avante Bedrock" }
    )
  end

  local end_time = vim.uv.hrtime()
  local duration_ms = (end_time - start_time) / 1000000
  Utils.debug(string.format("AWS credentials fetch took %.2f ms", duration_ms))

  return awsCreds
end

--- check_aws_cli_installed returns true when the aws cli is installed
--- @return boolean
function M.check_aws_cli_installed()
  local _, exit_code = run_command("aws", "--version")
  return exit_code == 0
end

--- check_curl_version_supports_aws_sig checks if the given curl version supports aws sigv4 correctly
--- we require at least version 8.10.0 because it contains critical fixes for aws sigv4 support
--- https://curl.se/ch/8.10.0.html
--- @param version_string string The curl version string to check
--- @return boolean
function M.check_curl_version_supports_aws_sig(version_string)
  -- No version text at all means support cannot be confirmed
  if type(version_string) ~= "string" then return false end

  -- Extract the version number
  local major, minor = version_string:match("curl (%d+)%.(%d+)")

  if major and minor then
    major = tonumber(major)
    minor = tonumber(minor)

    -- Check if the version is at least 8.10
    if major > 8 or (major == 8 and minor >= 10) then return true end
  end
  return false
end

--- check_curl_supports_aws_sig returns true when the installed curl version supports aws sigv4
--- @return boolean
function M.check_curl_supports_aws_sig()
  local output, exit_code = run_command("curl", "--version")
  if exit_code ~= 0 or type(output) ~= "string" then return false end

  -- Get first line of output which contains version info
  local version_string = output:match("^[^\n]+")
  if version_string == nil then return false end

  return M.check_curl_version_supports_aws_sig(version_string)
end

return M