fix: react prompts (#2537)
This commit is contained in:
@@ -113,29 +113,25 @@ function M:parse_messages(opts)
|
||||
end
|
||||
if not provider_conf.disable_tools and use_ReAct_prompt then
|
||||
if content_items[1].type == "tool_result" then
|
||||
local tool_use = nil
|
||||
local tool_use_msg = nil
|
||||
for _, msg_ in ipairs(opts.messages) do
|
||||
if type(msg_.content) == "table" and #msg_.content > 0 then
|
||||
if msg_.content[1].type == "tool_use" and msg_.content[1].id == content_items[1].tool_use_id then
|
||||
tool_use = msg_
|
||||
tool_use_msg = msg_
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
if tool_use then
|
||||
if tool_use_msg then
|
||||
table.insert(contents, {
|
||||
role = "model",
|
||||
parts = {
|
||||
{ text = Utils.tool_use_to_xml(tool_use.content[1]) },
|
||||
{ text = Utils.tool_use_to_xml(tool_use_msg.content[1]) },
|
||||
},
|
||||
})
|
||||
role = "user"
|
||||
table.insert(parts, {
|
||||
text = "["
|
||||
.. tool_use.content[1].name
|
||||
.. " for '"
|
||||
.. (tool_use.content[1].input.path or tool_use.content[1].input.rel_path or "")
|
||||
.. "'] Result:",
|
||||
text = "The result of tool use " .. Utils.tool_use_to_xml(tool_use_msg.content[1]) .. " is:\n",
|
||||
})
|
||||
table.insert(parts, {
|
||||
text = content_items[1].content,
|
||||
@@ -189,6 +185,8 @@ function M.prepare_request_body(provider_instance, prompt_opts, provider_conf, r
|
||||
|
||||
local use_ReAct_prompt = provider_conf.use_ReAct_prompt == true
|
||||
|
||||
if use_ReAct_prompt then request_body.generationConfig.stopSequences = { "</tool_use>" } end
|
||||
|
||||
local disable_tools = provider_conf.disable_tools or false
|
||||
|
||||
if not use_ReAct_prompt and not disable_tools and prompt_opts.tools then
|
||||
|
||||
@@ -3,7 +3,7 @@ local Config = require("avante.config")
|
||||
local Clipboard = require("avante.clipboard")
|
||||
local Providers = require("avante.providers")
|
||||
local HistoryMessage = require("avante.history.message")
|
||||
local ReActParser = require("avante.libs.ReAct_parser")
|
||||
local ReActParser = require("avante.libs.ReAct_parser2")
|
||||
local JsonParser = require("avante.libs.jsonparser")
|
||||
local Prompts = require("avante.utils.prompts")
|
||||
local LlmTools = require("avante.llm_tools")
|
||||
@@ -130,24 +130,20 @@ function M:parse_messages(opts)
|
||||
end
|
||||
if not provider_conf.disable_tools and use_ReAct_prompt then
|
||||
if msg.content[1].type == "tool_result" then
|
||||
local tool_use = nil
|
||||
local tool_use_msg = nil
|
||||
for _, msg_ in ipairs(opts.messages) do
|
||||
if type(msg_.content) == "table" and #msg_.content > 0 then
|
||||
if msg_.content[1].type == "tool_use" and msg_.content[1].id == msg.content[1].tool_use_id then
|
||||
tool_use = msg_
|
||||
tool_use_msg = msg_
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
if tool_use then
|
||||
if tool_use_msg then
|
||||
msg.role = "user"
|
||||
table.insert(content, {
|
||||
type = "text",
|
||||
text = "["
|
||||
.. tool_use.content[1].name
|
||||
.. " for '"
|
||||
.. (tool_use.content[1].input.path or tool_use.content[1].input.rel_path or "")
|
||||
.. "'] Result:",
|
||||
text = "The result of tool use " .. Utils.tool_use_to_xml(tool_use_msg.content[1]) .. " is:\n",
|
||||
})
|
||||
table.insert(content, {
|
||||
type = "text",
|
||||
@@ -258,7 +254,6 @@ function M:add_text_message(ctx, text, state, opts)
|
||||
end
|
||||
local cleaned_xml_content = table.concat(cleaned_xml_lines, "\n")
|
||||
local xml = ReActParser.parse(cleaned_xml_content)
|
||||
local has_tool_use = false
|
||||
if xml and #xml > 0 then
|
||||
local new_content_list = {}
|
||||
local xml_md_openned = false
|
||||
@@ -293,42 +288,45 @@ function M:add_text_message(ctx, text, state, opts)
|
||||
end
|
||||
end
|
||||
if next(input) ~= nil then
|
||||
has_tool_use = true
|
||||
local msg_uuid = ctx.content_uuid .. "-" .. idx
|
||||
local tool_use_id = msg_uuid
|
||||
local tool_message_state = item.partial and "generating" or "generated"
|
||||
local msg_ = HistoryMessage:new("assistant", {
|
||||
type = "tool_use",
|
||||
name = item.tool_name,
|
||||
id = tool_use_id,
|
||||
input = input,
|
||||
}, {
|
||||
state = state,
|
||||
state = tool_message_state,
|
||||
uuid = msg_uuid,
|
||||
turn_id = ctx.turn_id,
|
||||
})
|
||||
msgs[#msgs + 1] = msg_
|
||||
ctx.tool_use_list = ctx.tool_use_list or {}
|
||||
local input_json = type(input) == "string" and input or vim.json.encode(input)
|
||||
local exists = false
|
||||
for _, tool_use in ipairs(ctx.tool_use_list) do
|
||||
if tool_use.id == tool_use_id then
|
||||
tool_use.input_json = input
|
||||
tool_use.input_json = input_json
|
||||
exists = true
|
||||
end
|
||||
end
|
||||
if not exists then
|
||||
ctx.tool_use_list[#ctx.tool_use_list + 1] = {
|
||||
uuid = tool_use_id,
|
||||
id = tool_use_id,
|
||||
name = item.tool_name,
|
||||
input_json = input,
|
||||
input_json = input_json,
|
||||
state = "generating",
|
||||
}
|
||||
end
|
||||
opts.on_stop({ reason = "tool_use", streaming_tool_use = item.partial })
|
||||
end
|
||||
::continue::
|
||||
end
|
||||
msg.message.content = table.concat(new_content_list, "\n")
|
||||
msg.message.content = table.concat(new_content_list, "\n"):gsub("\n+$", "\n")
|
||||
end
|
||||
if opts.on_messages_add then opts.on_messages_add(msgs) end
|
||||
if has_tool_use and state == "generating" then opts.on_stop({ reason = "tool_use", streaming_tool_use = true }) end
|
||||
end
|
||||
|
||||
function M:add_thinking_message(ctx, text, state, opts)
|
||||
@@ -537,6 +535,9 @@ function M:parse_curl_args(prompt_opts)
|
||||
Utils.debug("endpoint", provider_conf.endpoint)
|
||||
Utils.debug("model", provider_conf.model)
|
||||
|
||||
local stop = nil
|
||||
if use_ReAct_prompt then stop = { "</tool_use>" } end
|
||||
|
||||
return {
|
||||
url = Utils.url_join(provider_conf.endpoint, "/chat/completions"),
|
||||
proxy = provider_conf.proxy,
|
||||
@@ -545,6 +546,7 @@ function M:parse_curl_args(prompt_opts)
|
||||
body = vim.tbl_deep_extend("force", {
|
||||
model = provider_conf.model,
|
||||
messages = self:parse_messages(prompt_opts),
|
||||
stop = stop,
|
||||
stream = true,
|
||||
stream_options = not M.is_mistral(provider_conf.endpoint) and {
|
||||
include_usage = true,
|
||||
|
||||
Reference in New Issue
Block a user