diff --git a/.gitignore b/.gitignore index 039cd59..e65c79b 100644 --- a/.gitignore +++ b/.gitignore @@ -5,7 +5,6 @@ .venv __pycache__/ -data/ # Neovim plugin specific files plugin/packer_compiled.lua diff --git a/lua/avante/providers/bedrock.lua b/lua/avante/providers/bedrock.lua index eb4ef84..2faea64 100644 --- a/lua/avante/providers/bedrock.lua +++ b/lua/avante/providers/bedrock.lua @@ -79,6 +79,16 @@ function M:build_bedrock_payload(prompt_opts, request_body) return model_handler.build_bedrock_payload(self, prompt_opts, request_body) end +local function parse_exception(data) + local exceptions_found = {} + local bedrock_match = data:gmatch("exception(%b{})") + for bedrock_data_match in bedrock_match do + local jsn = vim.json.decode(bedrock_data_match) + table.insert(exceptions_found, "- " .. jsn.message) + end + return exceptions_found +end + function M:parse_stream_data(ctx, data, opts) -- @NOTE: Decode and process Bedrock response -- Each response contains a Base64-encoded `bytes` field, which is decoded into JSON. @@ -90,15 +100,23 @@ function M:parse_stream_data(ctx, data, opts) local json = vim.json.decode(data_stream) self:parse_response(ctx, data_stream, json.type, opts) end + local exceptions = parse_exception(data) + if #exceptions > 0 then + Utils.debug("Bedrock exceptions: ", vim.fn.json_encode(exceptions)) + if opts.on_chunk then + opts.on_chunk("\n**Exception caught**\n\n") + opts.on_chunk(table.concat(exceptions, "\n")) + end + vim.schedule(function() opts.on_stop({ reason = "error" }) end) + end end function M:parse_response_without_stream(data, event_state, opts) if opts.on_chunk == nil then return end - local bedrock_match = data:gmatch("exception(%b{})") - opts.on_chunk("\n**Exception caught**\n\n") - for bedrock_data_match in bedrock_match do - local jsn = vim.json.decode(bedrock_data_match) - opts.on_chunk("- " .. jsn.message .. "\n") + local exceptions = parse_exception(data) + if #exceptions > 0 then + opts.on_chunk("\n**Exception caught**\n\n") + opts.on_chunk(table.concat(exceptions, "\n")) end vim.schedule(function() opts.on_stop({ reason = "complete" }) end) end diff --git a/lua/avante/utils/test.lua b/lua/avante/utils/test.lua new file mode 100644 index 0000000..c273b69 --- /dev/null +++ b/lua/avante/utils/test.lua @@ -0,0 +1,15 @@ +-- This is a helper for unit tests. +local M = {} + +function M.read_file(fn) + fn = vim.uv.cwd() .. "/" .. fn + local file = io.open(fn, "r") + if file then + local data = file:read("*all") + file:close() + return data + end + return fn +end + +return M diff --git a/tests/data/bedrock_response_stream.bin b/tests/data/bedrock_response_stream.bin new file mode 100644 index 0000000..91f9ad4 Binary files /dev/null and b/tests/data/bedrock_response_stream.bin differ diff --git a/tests/data/bedrock_response_stream_with_exception.bin b/tests/data/bedrock_response_stream_with_exception.bin new file mode 100644 index 0000000..dd20426 Binary files /dev/null and b/tests/data/bedrock_response_stream_with_exception.bin differ diff --git a/tests/providers/bedrock_spec.lua b/tests/providers/bedrock_spec.lua index 9128cb1..11fca92 100644 --- a/tests/providers/bedrock_spec.lua +++ b/tests/providers/bedrock_spec.lua @@ -1,6 +1,37 @@ local bedrock_provider = require("avante.providers.bedrock") +local test_util = require("avante.utils.test") +local Config = require("avante.config") +Config.setup({}) + describe("bedrock_provider", function() + describe("parse_stream_data", function() + it("should parse response in a stream.", function() + local data = test_util.read_file("tests/data/bedrock_response_stream.bin") + local message = "" + bedrock_provider:parse_stream_data({}, data, { + on_chunk = function(msg) message = message .. msg end, + on_stop = function() end, + }) + assert.equals( + "I'll help you fix errors in the HelloLog4j.java file. Let me first understand what errors might be present by examining the code and related files.", + message + ) + end) + + it("should parse exception inside a stream.", function() + local data = test_util.read_file("tests/data/bedrock_response_stream_with_exception.bin") + local message = "" + bedrock_provider:parse_stream_data({}, data, { + on_chunk = function(msg) message = msg end, + }) + assert.equals( + "- Too many requests, please wait before trying again. You have sent too many requests. Wait before trying again.", + message + ) + end) + end) + describe("check_curl_version_supports_aws_sig", function() it( "should return true for curl version 8.10.0",