fix: bedrock exception could be found at end of a stream (#2654)
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -5,7 +5,6 @@
|
||||
|
||||
.venv
|
||||
__pycache__/
|
||||
data/
|
||||
|
||||
# Neovim plugin specific files
|
||||
plugin/packer_compiled.lua
|
||||
|
||||
@@ -79,6 +79,16 @@ function M:build_bedrock_payload(prompt_opts, request_body)
|
||||
return model_handler.build_bedrock_payload(self, prompt_opts, request_body)
|
||||
end
|
||||
|
||||
local function parse_exception(data)
|
||||
local exceptions_found = {}
|
||||
local bedrock_match = data:gmatch("exception(%b{})")
|
||||
for bedrock_data_match in bedrock_match do
|
||||
local jsn = vim.json.decode(bedrock_data_match)
|
||||
table.insert(exceptions_found, "- " .. jsn.message)
|
||||
end
|
||||
return exceptions_found
|
||||
end
|
||||
|
||||
function M:parse_stream_data(ctx, data, opts)
|
||||
-- @NOTE: Decode and process Bedrock response
|
||||
-- Each response contains a Base64-encoded `bytes` field, which is decoded into JSON.
|
||||
@@ -90,15 +100,23 @@ function M:parse_stream_data(ctx, data, opts)
|
||||
local json = vim.json.decode(data_stream)
|
||||
self:parse_response(ctx, data_stream, json.type, opts)
|
||||
end
|
||||
local exceptions = parse_exception(data)
|
||||
if #exceptions > 0 then
|
||||
Utils.debug("Bedrock exceptions: ", vim.fn.json_encode(exceptions))
|
||||
if opts.on_chunk then
|
||||
opts.on_chunk("\n**Exception caught**\n\n")
|
||||
opts.on_chunk(table.concat(exceptions, "\n"))
|
||||
end
|
||||
vim.schedule(function() opts.on_stop({ reason = "error" }) end)
|
||||
end
|
||||
end
|
||||
|
||||
function M:parse_response_without_stream(data, event_state, opts)
|
||||
if opts.on_chunk == nil then return end
|
||||
local bedrock_match = data:gmatch("exception(%b{})")
|
||||
opts.on_chunk("\n**Exception caught**\n\n")
|
||||
for bedrock_data_match in bedrock_match do
|
||||
local jsn = vim.json.decode(bedrock_data_match)
|
||||
opts.on_chunk("- " .. jsn.message .. "\n")
|
||||
local exceptions = parse_exception(data)
|
||||
if #exceptions > 0 then
|
||||
opts.on_chunk("\n**Exception caught**\n\n")
|
||||
opts.on_chunk(table.concat(exceptions, "\n"))
|
||||
end
|
||||
vim.schedule(function() opts.on_stop({ reason = "complete" }) end)
|
||||
end
|
||||
|
||||
15
lua/avante/utils/test.lua
Normal file
15
lua/avante/utils/test.lua
Normal file
@@ -0,0 +1,15 @@
|
||||
-- This is a helper for unit tests.
|
||||
local M = {}
|
||||
|
||||
function M.read_file(fn)
|
||||
fn = vim.uv.cwd() .. "/" .. fn
|
||||
local file = io.open(fn, "r")
|
||||
if file then
|
||||
local data = file:read("*all")
|
||||
file:close()
|
||||
return data
|
||||
end
|
||||
return fn
|
||||
end
|
||||
|
||||
return M
|
||||
BIN
tests/data/bedrock_response_stream.bin
Normal file
BIN
tests/data/bedrock_response_stream.bin
Normal file
Binary file not shown.
BIN
tests/data/bedrock_response_stream_with_exception.bin
Normal file
BIN
tests/data/bedrock_response_stream_with_exception.bin
Normal file
Binary file not shown.
@@ -1,6 +1,37 @@
|
||||
local bedrock_provider = require("avante.providers.bedrock")
|
||||
|
||||
local test_util = require("avante.utils.test")
|
||||
local Config = require("avante.config")
|
||||
Config.setup({})
|
||||
|
||||
describe("bedrock_provider", function()
|
||||
describe("parse_stream_data", function()
|
||||
it("should parse response in a stream.", function()
|
||||
local data = test_util.read_file("tests/data/bedrock_response_stream.bin")
|
||||
local message = ""
|
||||
bedrock_provider:parse_stream_data({}, data, {
|
||||
on_chunk = function(msg) message = message .. msg end,
|
||||
on_stop = function() end,
|
||||
})
|
||||
assert.equals(
|
||||
"I'll help you fix errors in the HelloLog4j.java file. Let me first understand what errors might be present by examining the code and related files.",
|
||||
message
|
||||
)
|
||||
end)
|
||||
|
||||
it("should parse exception inside a stream.", function()
|
||||
local data = test_util.read_file("tests/data/bedrock_response_stream_with_exception.bin")
|
||||
local message = ""
|
||||
bedrock_provider:parse_stream_data({}, data, {
|
||||
on_chunk = function(msg) message = msg end,
|
||||
})
|
||||
assert.equals(
|
||||
"- Too many requests, please wait before trying again. You have sent too many requests. Wait before trying again.",
|
||||
message
|
||||
)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("check_curl_version_supports_aws_sig", function()
|
||||
it(
|
||||
"should return true for curl version 8.10.0",
|
||||
|
||||
Reference in New Issue
Block a user