fix: limit tool use count (#2294)
This commit is contained in:
@@ -166,6 +166,7 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
||||
content = content_block.text,
|
||||
}, {
|
||||
state = "generating",
|
||||
session_id = ctx.session_id,
|
||||
})
|
||||
content_block.uuid = msg.uuid
|
||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||
@@ -184,6 +185,7 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
||||
},
|
||||
}, {
|
||||
state = "generating",
|
||||
session_id = ctx.session_id,
|
||||
})
|
||||
content_block.uuid = msg.uuid
|
||||
opts.on_messages_add({ msg })
|
||||
@@ -203,6 +205,7 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
||||
},
|
||||
}, {
|
||||
state = "generating",
|
||||
session_id = ctx.session_id,
|
||||
})
|
||||
content_block.uuid = msg.uuid
|
||||
opts.on_messages_add({ msg })
|
||||
@@ -231,6 +234,7 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
||||
}, {
|
||||
state = "generating",
|
||||
uuid = content_block.uuid,
|
||||
session_id = ctx.session_id,
|
||||
})
|
||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||
elseif jsn.delta.type == "text_delta" then
|
||||
@@ -242,6 +246,7 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
||||
}, {
|
||||
state = "generating",
|
||||
uuid = content_block.uuid,
|
||||
session_id = ctx.session_id,
|
||||
})
|
||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||
elseif jsn.delta.type == "signature_delta" then
|
||||
@@ -260,6 +265,7 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
||||
}, {
|
||||
state = "generated",
|
||||
uuid = content_block.uuid,
|
||||
session_id = ctx.session_id,
|
||||
})
|
||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||
end
|
||||
@@ -278,6 +284,7 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
||||
}, {
|
||||
state = "generated",
|
||||
uuid = content_block.uuid,
|
||||
session_id = ctx.session_id,
|
||||
})
|
||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||
end
|
||||
@@ -301,6 +308,7 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
||||
}, {
|
||||
state = "generated",
|
||||
uuid = content_block.uuid,
|
||||
session_id = ctx.session_id,
|
||||
})
|
||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||
end
|
||||
|
||||
@@ -248,7 +248,7 @@ function M:parse_response(ctx, data_stream, _, opts)
|
||||
input_json = vim.json.encode(part.functionCall.args),
|
||||
}
|
||||
table.insert(ctx.tool_use_list, tool_use)
|
||||
OpenAI:add_tool_use_message(tool_use, "generated", opts)
|
||||
OpenAI:add_tool_use_message(ctx, tool_use, "generated", opts)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
@@ -215,7 +215,7 @@ function M:finish_pending_messages(ctx, opts)
|
||||
if ctx.content ~= nil and ctx.content ~= "" then self:add_text_message(ctx, "", "generated", opts) end
|
||||
if ctx.tool_use_list then
|
||||
for _, tool_use in ipairs(ctx.tool_use_list) do
|
||||
if tool_use.state == "generating" then self:add_tool_use_message(tool_use, "generated", opts) end
|
||||
if tool_use.state == "generating" then self:add_tool_use_message(ctx, tool_use, "generated", opts) end
|
||||
end
|
||||
end
|
||||
end
|
||||
@@ -318,6 +318,7 @@ function M:add_text_message(ctx, text, state, opts)
|
||||
}, {
|
||||
state = state,
|
||||
uuid = msg_uuid,
|
||||
session_id = ctx.session_id,
|
||||
})
|
||||
msgs[#msgs + 1] = msg_
|
||||
ctx.tool_use_list = ctx.tool_use_list or {}
|
||||
@@ -351,12 +352,13 @@ function M:add_thinking_message(ctx, text, state, opts)
|
||||
}, {
|
||||
state = state,
|
||||
uuid = ctx.reasonging_content_uuid,
|
||||
session_id = ctx.session_id,
|
||||
})
|
||||
ctx.reasonging_content_uuid = msg.uuid
|
||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||
end
|
||||
|
||||
function M:add_tool_use_message(tool_use, state, opts)
|
||||
function M:add_tool_use_message(ctx, tool_use, state, opts)
|
||||
local jsn = JsonParser.parse(tool_use.input_json)
|
||||
local msg = HistoryMessage:new({
|
||||
role = "assistant",
|
||||
@@ -371,6 +373,7 @@ function M:add_tool_use_message(tool_use, state, opts)
|
||||
}, {
|
||||
state = state,
|
||||
uuid = tool_use.uuid,
|
||||
session_id = ctx.session_id,
|
||||
})
|
||||
tool_use.uuid = msg.uuid
|
||||
tool_use.state = state
|
||||
@@ -438,7 +441,7 @@ function M:parse_response(ctx, data_stream, _, opts)
|
||||
if not ctx.tool_use_list[tool_call.index + 1] then
|
||||
if tool_call.index > 0 and ctx.tool_use_list[tool_call.index] then
|
||||
local prev_tool_use = ctx.tool_use_list[tool_call.index]
|
||||
self:add_tool_use_message(prev_tool_use, "generated", opts)
|
||||
self:add_tool_use_message(ctx, prev_tool_use, "generated", opts)
|
||||
end
|
||||
local tool_use = {
|
||||
name = tool_call["function"].name,
|
||||
@@ -446,11 +449,11 @@ function M:parse_response(ctx, data_stream, _, opts)
|
||||
input_json = type(tool_call["function"].arguments) == "string" and tool_call["function"].arguments or "",
|
||||
}
|
||||
ctx.tool_use_list[tool_call.index + 1] = tool_use
|
||||
self:add_tool_use_message(tool_use, "generating", opts)
|
||||
self:add_tool_use_message(ctx, tool_use, "generating", opts)
|
||||
else
|
||||
local tool_use = ctx.tool_use_list[tool_call.index + 1]
|
||||
tool_use.input_json = tool_use.input_json .. tool_call["function"].arguments
|
||||
self:add_tool_use_message(tool_use, "generating", opts)
|
||||
self:add_tool_use_message(ctx, tool_use, "generating", opts)
|
||||
end
|
||||
end
|
||||
elseif delta.content then
|
||||
|
||||
Reference in New Issue
Block a user