fix: tool calling (#2297)

This commit is contained in:
yetone
2025-06-22 21:50:26 +08:00
committed by GitHub
parent db39f5fe1b
commit 3033556d5b
12 changed files with 127 additions and 99 deletions

View File

@@ -166,7 +166,7 @@ function M:parse_response(ctx, data_stream, event_state, opts)
content = content_block.text,
}, {
state = "generating",
session_id = ctx.session_id,
turn_id = ctx.turn_id,
})
content_block.uuid = msg.uuid
if opts.on_messages_add then opts.on_messages_add({ msg }) end
@@ -185,7 +185,7 @@ function M:parse_response(ctx, data_stream, event_state, opts)
},
}, {
state = "generating",
session_id = ctx.session_id,
turn_id = ctx.turn_id,
})
content_block.uuid = msg.uuid
opts.on_messages_add({ msg })
@@ -205,11 +205,11 @@ function M:parse_response(ctx, data_stream, event_state, opts)
},
}, {
state = "generating",
session_id = ctx.session_id,
turn_id = ctx.turn_id,
})
content_block.uuid = msg.uuid
opts.on_messages_add({ msg })
opts.on_stop({ reason = "tool_use", streaming_tool_use = true })
-- opts.on_stop({ reason = "tool_use", streaming_tool_use = true })
end
elseif event_state == "content_block_delta" then
local ok, jsn = pcall(vim.json.decode, data_stream)
@@ -234,7 +234,7 @@ function M:parse_response(ctx, data_stream, event_state, opts)
}, {
state = "generating",
uuid = content_block.uuid,
session_id = ctx.session_id,
turn_id = ctx.turn_id,
})
if opts.on_messages_add then opts.on_messages_add({ msg }) end
elseif jsn.delta.type == "text_delta" then
@@ -246,7 +246,7 @@ function M:parse_response(ctx, data_stream, event_state, opts)
}, {
state = "generating",
uuid = content_block.uuid,
session_id = ctx.session_id,
turn_id = ctx.turn_id,
})
if opts.on_messages_add then opts.on_messages_add({ msg }) end
elseif jsn.delta.type == "signature_delta" then
@@ -265,7 +265,7 @@ function M:parse_response(ctx, data_stream, event_state, opts)
}, {
state = "generated",
uuid = content_block.uuid,
session_id = ctx.session_id,
turn_id = ctx.turn_id,
})
if opts.on_messages_add then opts.on_messages_add({ msg }) end
end
@@ -284,7 +284,7 @@ function M:parse_response(ctx, data_stream, event_state, opts)
}, {
state = "generated",
uuid = content_block.uuid,
session_id = ctx.session_id,
turn_id = ctx.turn_id,
})
if opts.on_messages_add then opts.on_messages_add({ msg }) end
end
@@ -308,7 +308,7 @@ function M:parse_response(ctx, data_stream, event_state, opts)
}, {
state = "generated",
uuid = content_block.uuid,
session_id = ctx.session_id,
turn_id = ctx.turn_id,
})
if opts.on_messages_add then opts.on_messages_add({ msg }) end
end
@@ -317,6 +317,8 @@ function M:parse_response(ctx, data_stream, event_state, opts)
if not ok then return end
if jsn.delta.stop_reason == "end_turn" then
opts.on_stop({ reason = "complete", usage = jsn.usage })
elseif jsn.delta.stop_reason == "max_tokens" then
opts.on_stop({ reason = "max_tokens", usage = jsn.usage })
elseif jsn.delta.stop_reason == "tool_use" then
opts.on_stop({
reason = "tool_use",