fix: tool calling (#2297)
This commit is contained in:
@@ -166,7 +166,7 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
||||
content = content_block.text,
|
||||
}, {
|
||||
state = "generating",
|
||||
session_id = ctx.session_id,
|
||||
turn_id = ctx.turn_id,
|
||||
})
|
||||
content_block.uuid = msg.uuid
|
||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||
@@ -185,7 +185,7 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
||||
},
|
||||
}, {
|
||||
state = "generating",
|
||||
session_id = ctx.session_id,
|
||||
turn_id = ctx.turn_id,
|
||||
})
|
||||
content_block.uuid = msg.uuid
|
||||
opts.on_messages_add({ msg })
|
||||
@@ -205,11 +205,11 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
||||
},
|
||||
}, {
|
||||
state = "generating",
|
||||
session_id = ctx.session_id,
|
||||
turn_id = ctx.turn_id,
|
||||
})
|
||||
content_block.uuid = msg.uuid
|
||||
opts.on_messages_add({ msg })
|
||||
opts.on_stop({ reason = "tool_use", streaming_tool_use = true })
|
||||
-- opts.on_stop({ reason = "tool_use", streaming_tool_use = true })
|
||||
end
|
||||
elseif event_state == "content_block_delta" then
|
||||
local ok, jsn = pcall(vim.json.decode, data_stream)
|
||||
@@ -234,7 +234,7 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
||||
}, {
|
||||
state = "generating",
|
||||
uuid = content_block.uuid,
|
||||
session_id = ctx.session_id,
|
||||
turn_id = ctx.turn_id,
|
||||
})
|
||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||
elseif jsn.delta.type == "text_delta" then
|
||||
@@ -246,7 +246,7 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
||||
}, {
|
||||
state = "generating",
|
||||
uuid = content_block.uuid,
|
||||
session_id = ctx.session_id,
|
||||
turn_id = ctx.turn_id,
|
||||
})
|
||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||
elseif jsn.delta.type == "signature_delta" then
|
||||
@@ -265,7 +265,7 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
||||
}, {
|
||||
state = "generated",
|
||||
uuid = content_block.uuid,
|
||||
session_id = ctx.session_id,
|
||||
turn_id = ctx.turn_id,
|
||||
})
|
||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||
end
|
||||
@@ -284,7 +284,7 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
||||
}, {
|
||||
state = "generated",
|
||||
uuid = content_block.uuid,
|
||||
session_id = ctx.session_id,
|
||||
turn_id = ctx.turn_id,
|
||||
})
|
||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||
end
|
||||
@@ -308,7 +308,7 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
||||
}, {
|
||||
state = "generated",
|
||||
uuid = content_block.uuid,
|
||||
session_id = ctx.session_id,
|
||||
turn_id = ctx.turn_id,
|
||||
})
|
||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||
end
|
||||
@@ -317,6 +317,8 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
||||
if not ok then return end
|
||||
if jsn.delta.stop_reason == "end_turn" then
|
||||
opts.on_stop({ reason = "complete", usage = jsn.usage })
|
||||
elseif jsn.delta.stop_reason == "max_tokens" then
|
||||
opts.on_stop({ reason = "max_tokens", usage = jsn.usage })
|
||||
elseif jsn.delta.stop_reason == "tool_use" then
|
||||
opts.on_stop({
|
||||
reason = "tool_use",
|
||||
|
||||
@@ -243,7 +243,7 @@ function M:parse_response(ctx, data_stream, _, opts)
|
||||
if not ctx.function_call_id then ctx.function_call_id = 0 end
|
||||
ctx.function_call_id = ctx.function_call_id + 1
|
||||
local tool_use = {
|
||||
id = ctx.session_id .. "-" .. tostring(ctx.function_call_id),
|
||||
id = ctx.turn_id .. "-" .. tostring(ctx.function_call_id),
|
||||
name = part.functionCall.name,
|
||||
input_json = vim.json.encode(part.functionCall.args),
|
||||
}
|
||||
|
||||
@@ -264,7 +264,6 @@ function M:add_text_message(ctx, text, state, opts)
|
||||
local cleaned_xml_content = table.concat(cleaned_xml_lines, "\n")
|
||||
local stream_parser = XMLParser.createStreamParser()
|
||||
stream_parser:addData(cleaned_xml_content)
|
||||
local has_tool_use = false
|
||||
local xml = stream_parser:getAllElements()
|
||||
if xml then
|
||||
local new_content_list = {}
|
||||
@@ -318,7 +317,7 @@ function M:add_text_message(ctx, text, state, opts)
|
||||
}, {
|
||||
state = state,
|
||||
uuid = msg_uuid,
|
||||
session_id = ctx.session_id,
|
||||
turn_id = ctx.turn_id,
|
||||
})
|
||||
msgs[#msgs + 1] = msg_
|
||||
ctx.tool_use_list = ctx.tool_use_list or {}
|
||||
@@ -327,14 +326,13 @@ function M:add_text_message(ctx, text, state, opts)
|
||||
name = item._name,
|
||||
input_json = input,
|
||||
}
|
||||
has_tool_use = true
|
||||
end
|
||||
if #new_content_list > 0 then msg.message.content = table.concat(new_content_list, "\n") end
|
||||
::continue::
|
||||
end
|
||||
end
|
||||
if opts.on_messages_add then opts.on_messages_add(msgs) end
|
||||
if has_tool_use and state == "generating" then opts.on_stop({ reason = "tool_use", streaming_tool_use = true }) end
|
||||
-- if has_tool_use and state == "generating" then opts.on_stop({ reason = "tool_use", streaming_tool_use = true }) end
|
||||
end
|
||||
|
||||
function M:add_thinking_message(ctx, text, state, opts)
|
||||
@@ -352,7 +350,7 @@ function M:add_thinking_message(ctx, text, state, opts)
|
||||
}, {
|
||||
state = state,
|
||||
uuid = ctx.reasonging_content_uuid,
|
||||
session_id = ctx.session_id,
|
||||
turn_id = ctx.turn_id,
|
||||
})
|
||||
ctx.reasonging_content_uuid = msg.uuid
|
||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||
@@ -373,38 +371,25 @@ function M:add_tool_use_message(ctx, tool_use, state, opts)
|
||||
}, {
|
||||
state = state,
|
||||
uuid = tool_use.uuid,
|
||||
session_id = ctx.session_id,
|
||||
turn_id = ctx.turn_id,
|
||||
})
|
||||
tool_use.uuid = msg.uuid
|
||||
tool_use.state = state
|
||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||
if state == "generating" then opts.on_stop({ reason = "tool_use", streaming_tool_use = true }) end
|
||||
-- if state == "generating" then opts.on_stop({ reason = "tool_use", streaming_tool_use = true }) end
|
||||
end
|
||||
|
||||
function M:parse_response(ctx, data_stream, _, opts)
|
||||
local orig_on_stop = opts.on_stop
|
||||
local stopped = false
|
||||
---@param stop_opts AvanteLLMStopCallbackOptions
|
||||
opts.on_stop = function(stop_opts)
|
||||
if stop_opts and not stop_opts.streaming_tool_use then
|
||||
if stopped then return end
|
||||
stopped = true
|
||||
end
|
||||
return orig_on_stop(stop_opts)
|
||||
end
|
||||
if data_stream:match('"%[DONE%]":') then
|
||||
if data_stream:match('"%[DONE%]":') or data_stream == "[DONE]" then
|
||||
self:finish_pending_messages(ctx, opts)
|
||||
if ctx.tool_use_list and #ctx.tool_use_list > 0 then
|
||||
ctx.tool_use_list = {}
|
||||
opts.on_stop({ reason = "tool_use" })
|
||||
else
|
||||
opts.on_stop({ reason = "complete" })
|
||||
end
|
||||
return
|
||||
end
|
||||
if data_stream == "[DONE]" then
|
||||
opts.on_stop({ reason = "complete" })
|
||||
return
|
||||
end
|
||||
local jsn = vim.json.decode(data_stream)
|
||||
---@cast jsn AvanteOpenAIChatResponse
|
||||
if not jsn.choices then return end
|
||||
@@ -453,7 +438,7 @@ function M:parse_response(ctx, data_stream, _, opts)
|
||||
else
|
||||
local tool_use = ctx.tool_use_list[tool_call.index + 1]
|
||||
tool_use.input_json = tool_use.input_json .. tool_call["function"].arguments
|
||||
self:add_tool_use_message(ctx, tool_use, "generating", opts)
|
||||
-- self:add_tool_use_message(ctx, tool_use, "generating", opts)
|
||||
end
|
||||
end
|
||||
elseif delta.content then
|
||||
|
||||
Reference in New Issue
Block a user