feat: ReAct tool calling (#2104)

This commit is contained in:
yetone
2025-05-31 08:53:34 +08:00
committed by GitHub
parent 22418bff8b
commit bc403ddcbf
25 changed files with 1358 additions and 188 deletions

View File

@@ -3,6 +3,9 @@ local Config = require("avante.config")
local Clipboard = require("avante.clipboard")
local Providers = require("avante.providers")
local HistoryMessage = require("avante.history_message")
local XMLParser = require("avante.libs.xmlparser")
local Prompts = require("avante.utils.prompts")
local LlmTools = require("avante.llm_tools")
---@class AvanteProviderFunctor
local M = {}
@@ -76,10 +79,15 @@ function M:parse_messages(opts)
local messages = {}
local provider_conf, _ = Providers.parse_config(self)
local use_ReAct_prompt = provider_conf.use_ReAct_prompt == true
local system_prompt = opts.system_prompt
if use_ReAct_prompt then system_prompt = Prompts.get_ReAct_system_prompt(provider_conf, opts) end
if self.is_reasoning_model(provider_conf.model) then
table.insert(messages, { role = "developer", content = opts.system_prompt })
table.insert(messages, { role = "developer", content = system_prompt })
else
table.insert(messages, { role = "system", content = opts.system_prompt })
table.insert(messages, { role = "system", content = system_prompt })
end
local has_tool_use = false
@@ -103,22 +111,50 @@ function M:parse_messages(opts)
url = "data:" .. item.source.media_type .. ";" .. item.source.type .. "," .. item.source.data,
},
})
elseif item.type == "tool_use" then
elseif item.type == "tool_use" and not use_ReAct_prompt then
has_tool_use = true
table.insert(tool_calls, {
id = item.id,
type = "function",
["function"] = { name = item.name, arguments = vim.json.encode(item.input) },
})
elseif item.type == "tool_result" and has_tool_use then
elseif item.type == "tool_result" and has_tool_use and not use_ReAct_prompt then
table.insert(
tool_results,
{ tool_call_id = item.tool_use_id, content = item.is_error and "Error: " .. item.content or item.content }
)
end
end
if not provider_conf.disable_tools and use_ReAct_prompt then
if msg.content[1].type == "tool_result" then
local tool_use = nil
for _, msg_ in ipairs(opts.messages) do
if type(msg_.content) == "table" and #msg_.content > 0 then
if msg_.content[1].type == "tool_use" and msg_.content[1].id == msg.content[1].tool_use_id then
tool_use = msg_
break
end
end
end
if tool_use then
msg.role = "user"
table.insert(content, {
type = "text",
text = "["
.. tool_use.content[1].name
.. " for '"
.. (tool_use.content[1].input.path or tool_use.content[1].input.rel_path or "")
.. "'] Result:",
})
table.insert(content, {
type = "text",
text = msg.content[1].content,
})
end
end
end
if #content > 0 then table.insert(messages, { role = self.role_map[msg.role], content = content }) end
if not provider_conf.disable_tools then
if not provider_conf.disable_tools and not use_ReAct_prompt then
if #tool_calls > 0 then
local last_message = messages[#messages]
if last_message and last_message.role == self.role_map["assistant"] and last_message.tool_calls then
@@ -183,7 +219,10 @@ function M:finish_pending_messages(ctx, opts)
end
end
local llm_tool_names = nil
function M:add_text_message(ctx, text, state, opts)
if llm_tool_names == nil then llm_tool_names = LlmTools.get_tool_names() end
if ctx.content == nil then ctx.content = "" end
ctx.content = ctx.content .. text
local msg = HistoryMessage:new({
@@ -194,7 +233,75 @@ function M:add_text_message(ctx, text, state, opts)
uuid = ctx.content_uuid,
})
ctx.content_uuid = msg.uuid
if opts.on_messages_add then opts.on_messages_add({ msg }) end
local msgs = { msg }
local stream_parser = XMLParser.createStreamParser()
stream_parser:addData(ctx.content)
local xml = stream_parser:getAllElements()
if xml then
local new_content_list = {}
local xml_md_openned = false
for idx, item in ipairs(xml) do
if item._name == "_text" then
local cleaned_lines = {}
local lines = vim.split(item._text, "\n")
for _, line in ipairs(lines) do
if line:match("^```xml") or line:match("^```tool_code") or line:match("^```tool_use") then
xml_md_openned = true
elseif line:match("^```$") then
if xml_md_openned then
xml_md_openned = false
else
table.insert(cleaned_lines, line)
end
else
table.insert(cleaned_lines, line)
end
end
table.insert(new_content_list, table.concat(cleaned_lines, "\n"))
goto continue
end
if not vim.tbl_contains(llm_tool_names, item._name) then goto continue end
local ok, input = pcall(vim.json.decode, item._text)
if not ok and item.children and #item.children > 0 then
input = {}
for _, item_ in ipairs(item.children) do
local ok_, input_ = pcall(vim.json.decode, item_._text)
if ok_ and input_ then
input[item_._name] = input_
else
input[item_._name] = item_._text
end
end
end
if input then
local tool_use_id = Utils.uuid()
local msg_ = HistoryMessage:new({
role = "assistant",
content = {
{
type = "tool_use",
name = item._name,
id = tool_use_id,
input = input,
},
},
}, {
state = state,
uuid = ctx.content_uuid .. "-" .. idx,
})
msgs[#msgs + 1] = msg_
ctx.tool_use_list = ctx.tool_use_list or {}
ctx.tool_use_list[#ctx.tool_use_list + 1] = {
id = tool_use_id,
name = item._name,
input_json = input,
}
end
if #new_content_list > 0 then msg.message.content = table.concat(new_content_list, "\n") end
::continue::
end
end
if opts.on_messages_add then opts.on_messages_add(msgs) end
end
function M:add_thinking_message(ctx, text, state, opts)
@@ -242,7 +349,11 @@ end
function M:parse_response(ctx, data_stream, _, opts)
if data_stream:match('"%[DONE%]":') then
self:finish_pending_messages(ctx, opts)
opts.on_stop({ reason = "complete" })
if ctx.tool_use_list and #ctx.tool_use_list > 0 then
opts.on_stop({ reason = "tool_use" })
else
opts.on_stop({ reason = "complete" })
end
return
end
if data_stream == "[DONE]" then return end
@@ -316,9 +427,13 @@ function M:parse_response(ctx, data_stream, _, opts)
self:add_text_message(ctx, delta.content, "generating", opts)
end
end
if choice.finish_reason == "stop" or choice.finish_reason == "eos_token" then
if choice.finish_reason == "stop" or choice.finish_reason == "eos_token" or choice.finish_reason == "length" then
self:finish_pending_messages(ctx, opts)
opts.on_stop({ reason = "complete" })
if ctx.tool_use_list and #ctx.tool_use_list > 0 then
opts.on_stop({ reason = "tool_use", usage = jsn.usage })
else
opts.on_stop({ reason = "complete", usage = jsn.usage })
end
end
if choice.finish_reason == "tool_calls" then
self:finish_pending_messages(ctx, opts)
@@ -372,8 +487,10 @@ function M:parse_curl_args(prompt_opts)
self.set_allowed_params(provider_conf, request_body)
local use_ReAct_prompt = provider_conf.use_ReAct_prompt == true
local tools = nil
if not disable_tools and prompt_opts.tools then
if not disable_tools and prompt_opts.tools and not use_ReAct_prompt then
tools = {}
for _, tool in ipairs(prompt_opts.tools) do
table.insert(tools, self:transform_tool(tool))