feat: ReAct tool calling (#2104)
This commit is contained in:
@@ -1,10 +1,15 @@
|
||||
local Utils = require("avante.utils")
|
||||
local P = require("avante.providers")
|
||||
local Providers = require("avante.providers")
|
||||
local Config = require("avante.config")
|
||||
local Clipboard = require("avante.clipboard")
|
||||
local HistoryMessage = require("avante.history_message")
|
||||
local Prompts = require("avante.utils.prompts")
|
||||
|
||||
---@class AvanteProviderFunctor
|
||||
local M = {}
|
||||
|
||||
setmetatable(M, { __index = Providers.openai })
|
||||
|
||||
M.api_key_name = "" -- Ollama typically doesn't require API keys for local use
|
||||
|
||||
M.role_map = {
|
||||
@@ -12,35 +17,182 @@ M.role_map = {
|
||||
assistant = "assistant",
|
||||
}
|
||||
|
||||
M.parse_messages = P.openai.parse_messages
|
||||
M.is_reasoning_model = P.openai.is_reasoning_model
|
||||
function M:parse_messages(opts)
|
||||
local messages = {}
|
||||
local provider_conf, _ = Providers.parse_config(self)
|
||||
|
||||
local system_prompt = Prompts.get_ReAct_system_prompt(provider_conf, opts)
|
||||
|
||||
if self.is_reasoning_model(provider_conf.model) then
|
||||
table.insert(messages, { role = "developer", content = system_prompt })
|
||||
else
|
||||
table.insert(messages, { role = "system", content = system_prompt })
|
||||
end
|
||||
|
||||
vim.iter(opts.messages):each(function(msg)
|
||||
if type(msg.content) == "string" then
|
||||
table.insert(messages, { role = self.role_map[msg.role], content = msg.content })
|
||||
elseif type(msg.content) == "table" then
|
||||
local content = {}
|
||||
for _, item in ipairs(msg.content) do
|
||||
if type(item) == "string" then
|
||||
table.insert(content, { type = "text", text = item })
|
||||
elseif item.type == "text" then
|
||||
table.insert(content, { type = "text", text = item.text })
|
||||
elseif item.type == "image" then
|
||||
table.insert(content, {
|
||||
type = "image_url",
|
||||
image_url = {
|
||||
url = "data:" .. item.source.media_type .. ";" .. item.source.type .. "," .. item.source.data,
|
||||
},
|
||||
})
|
||||
end
|
||||
end
|
||||
if not provider_conf.disable_tools then
|
||||
if msg.content[1].type == "tool_result" then
|
||||
local tool_use = nil
|
||||
for _, msg_ in ipairs(opts.messages) do
|
||||
if type(msg_.content) == "table" and #msg_.content > 0 then
|
||||
if msg_.content[1].type == "tool_use" and msg_.content[1].id == msg.content[1].tool_use_id then
|
||||
tool_use = msg_
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
if tool_use then
|
||||
msg.role = "user"
|
||||
table.insert(content, {
|
||||
type = "text",
|
||||
text = "["
|
||||
.. tool_use.content[1].name
|
||||
.. " for '"
|
||||
.. (tool_use.content[1].input.path or tool_use.content[1].input.rel_path or "")
|
||||
.. "'] Result:",
|
||||
})
|
||||
table.insert(content, {
|
||||
type = "text",
|
||||
text = msg.content[1].content,
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
if #content > 0 then
|
||||
local text_content = {}
|
||||
for _, item in ipairs(content) do
|
||||
if type(item) == "table" and item.type == "text" then table.insert(text_content, item.text) end
|
||||
end
|
||||
table.insert(messages, { role = self.role_map[msg.role], content = table.concat(text_content, "\n\n") })
|
||||
end
|
||||
end
|
||||
end)
|
||||
|
||||
if Config.behaviour.support_paste_from_clipboard and opts.image_paths and #opts.image_paths > 0 then
|
||||
local message_content = messages[#messages].content
|
||||
if type(message_content) ~= "table" or message_content[1] == nil then
|
||||
message_content = { { type = "text", text = message_content } }
|
||||
end
|
||||
for _, image_path in ipairs(opts.image_paths) do
|
||||
table.insert(message_content, {
|
||||
type = "image_url",
|
||||
image_url = {
|
||||
url = "data:image/png;base64," .. Clipboard.get_base64_content(image_path),
|
||||
},
|
||||
})
|
||||
end
|
||||
messages[#messages].content = message_content
|
||||
end
|
||||
|
||||
local final_messages = {}
|
||||
local prev_role = nil
|
||||
|
||||
vim.iter(messages):each(function(message)
|
||||
local role = message.role
|
||||
if role == prev_role and role ~= "tool" then
|
||||
if role == self.role_map["assistant"] then
|
||||
table.insert(final_messages, { role = self.role_map["user"], content = "Ok" })
|
||||
else
|
||||
table.insert(final_messages, { role = self.role_map["assistant"], content = "Ok, I understand." })
|
||||
end
|
||||
end
|
||||
prev_role = role
|
||||
table.insert(final_messages, message)
|
||||
end)
|
||||
|
||||
return final_messages
|
||||
end
|
||||
|
||||
function M:is_disable_stream() return false end
|
||||
|
||||
---@class avante.OllamaFunction
|
||||
---@field name string
|
||||
---@field arguments table
|
||||
|
||||
---@class avante.OllamaToolCall
|
||||
---@field function avante.OllamaFunction
|
||||
|
||||
---@param tool_calls avante.OllamaToolCall[]
|
||||
---@param opts AvanteLLMStreamOptions
|
||||
function M:add_tool_use_messages(tool_calls, opts)
|
||||
local msgs = {}
|
||||
for _, tool_call in ipairs(tool_calls) do
|
||||
local id = Utils.uuid()
|
||||
local func = tool_call["function"]
|
||||
local msg = HistoryMessage:new({
|
||||
role = "assistant",
|
||||
content = {
|
||||
{
|
||||
type = "tool_use",
|
||||
name = func.name,
|
||||
id = id,
|
||||
input = func.arguments,
|
||||
},
|
||||
},
|
||||
}, {
|
||||
state = "generated",
|
||||
uuid = id,
|
||||
})
|
||||
table.insert(msgs, msg)
|
||||
end
|
||||
if opts.on_messages_add then opts.on_messages_add(msgs) end
|
||||
end
|
||||
|
||||
function M:parse_stream_data(ctx, data, opts)
|
||||
local ok, json_data = pcall(vim.json.decode, data)
|
||||
if not ok or not json_data then
|
||||
local ok, jsn = pcall(vim.json.decode, data)
|
||||
if not ok or not jsn then
|
||||
-- Add debug logging
|
||||
Utils.debug("Failed to parse JSON", data)
|
||||
return
|
||||
end
|
||||
|
||||
if json_data.message and json_data.message.content then
|
||||
local content = json_data.message.content
|
||||
P.openai:add_text_message(ctx, content, "generating", opts)
|
||||
if content and content ~= "" and opts.on_chunk then opts.on_chunk(content) end
|
||||
if jsn.message then
|
||||
if jsn.message.content then
|
||||
local content = jsn.message.content
|
||||
if content and content ~= "" then
|
||||
Providers.openai:add_text_message(ctx, content, "generating", opts)
|
||||
if opts.on_chunk then opts.on_chunk(content) end
|
||||
end
|
||||
end
|
||||
if jsn.message.tool_calls then
|
||||
ctx.has_tool_use = true
|
||||
local tool_calls = jsn.message.tool_calls
|
||||
self:add_tool_use_messages(tool_calls, opts)
|
||||
end
|
||||
end
|
||||
|
||||
if json_data.done then
|
||||
P.openai:finish_pending_messages(ctx, opts)
|
||||
opts.on_stop({ reason = "complete" })
|
||||
if jsn.done then
|
||||
Providers.openai:finish_pending_messages(ctx, opts)
|
||||
if ctx.has_tool_use or (ctx.tool_use_list and #ctx.tool_use_list > 0) then
|
||||
opts.on_stop({ reason = "tool_use" })
|
||||
else
|
||||
opts.on_stop({ reason = "complete" })
|
||||
end
|
||||
return
|
||||
end
|
||||
end
|
||||
|
||||
---@param prompt_opts AvantePromptOptions
|
||||
function M:parse_curl_args(prompt_opts)
|
||||
local provider_conf, request_body = P.parse_config(self)
|
||||
local provider_conf, request_body = Providers.parse_config(self)
|
||||
local keep_alive = provider_conf.keep_alive or "5m"
|
||||
|
||||
if not provider_conf.model or provider_conf.model == "" then error("Ollama model must be specified in config") end
|
||||
@@ -50,7 +202,7 @@ function M:parse_curl_args(prompt_opts)
|
||||
["Accept"] = "application/json",
|
||||
}
|
||||
|
||||
if P.env.require_api_key(provider_conf) then
|
||||
if Providers.env.require_api_key(provider_conf) then
|
||||
local api_key = self.parse_api_key()
|
||||
if api_key and api_key ~= "" then
|
||||
headers["Authorization"] = "Bearer " .. api_key
|
||||
@@ -66,7 +218,6 @@ function M:parse_curl_args(prompt_opts)
|
||||
model = provider_conf.model,
|
||||
messages = self:parse_messages(prompt_opts),
|
||||
stream = true,
|
||||
system = prompt_opts.system_prompt,
|
||||
keep_alive = keep_alive,
|
||||
}, request_body),
|
||||
}
|
||||
|
||||
@@ -3,6 +3,9 @@ local Config = require("avante.config")
|
||||
local Clipboard = require("avante.clipboard")
|
||||
local Providers = require("avante.providers")
|
||||
local HistoryMessage = require("avante.history_message")
|
||||
local XMLParser = require("avante.libs.xmlparser")
|
||||
local Prompts = require("avante.utils.prompts")
|
||||
local LlmTools = require("avante.llm_tools")
|
||||
|
||||
---@class AvanteProviderFunctor
|
||||
local M = {}
|
||||
@@ -76,10 +79,15 @@ function M:parse_messages(opts)
|
||||
local messages = {}
|
||||
local provider_conf, _ = Providers.parse_config(self)
|
||||
|
||||
local use_ReAct_prompt = provider_conf.use_ReAct_prompt == true
|
||||
local system_prompt = opts.system_prompt
|
||||
|
||||
if use_ReAct_prompt then system_prompt = Prompts.get_ReAct_system_prompt(provider_conf, opts) end
|
||||
|
||||
if self.is_reasoning_model(provider_conf.model) then
|
||||
table.insert(messages, { role = "developer", content = opts.system_prompt })
|
||||
table.insert(messages, { role = "developer", content = system_prompt })
|
||||
else
|
||||
table.insert(messages, { role = "system", content = opts.system_prompt })
|
||||
table.insert(messages, { role = "system", content = system_prompt })
|
||||
end
|
||||
|
||||
local has_tool_use = false
|
||||
@@ -103,22 +111,50 @@ function M:parse_messages(opts)
|
||||
url = "data:" .. item.source.media_type .. ";" .. item.source.type .. "," .. item.source.data,
|
||||
},
|
||||
})
|
||||
elseif item.type == "tool_use" then
|
||||
elseif item.type == "tool_use" and not use_ReAct_prompt then
|
||||
has_tool_use = true
|
||||
table.insert(tool_calls, {
|
||||
id = item.id,
|
||||
type = "function",
|
||||
["function"] = { name = item.name, arguments = vim.json.encode(item.input) },
|
||||
})
|
||||
elseif item.type == "tool_result" and has_tool_use then
|
||||
elseif item.type == "tool_result" and has_tool_use and not use_ReAct_prompt then
|
||||
table.insert(
|
||||
tool_results,
|
||||
{ tool_call_id = item.tool_use_id, content = item.is_error and "Error: " .. item.content or item.content }
|
||||
)
|
||||
end
|
||||
end
|
||||
if not provider_conf.disable_tools and use_ReAct_prompt then
|
||||
if msg.content[1].type == "tool_result" then
|
||||
local tool_use = nil
|
||||
for _, msg_ in ipairs(opts.messages) do
|
||||
if type(msg_.content) == "table" and #msg_.content > 0 then
|
||||
if msg_.content[1].type == "tool_use" and msg_.content[1].id == msg.content[1].tool_use_id then
|
||||
tool_use = msg_
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
if tool_use then
|
||||
msg.role = "user"
|
||||
table.insert(content, {
|
||||
type = "text",
|
||||
text = "["
|
||||
.. tool_use.content[1].name
|
||||
.. " for '"
|
||||
.. (tool_use.content[1].input.path or tool_use.content[1].input.rel_path or "")
|
||||
.. "'] Result:",
|
||||
})
|
||||
table.insert(content, {
|
||||
type = "text",
|
||||
text = msg.content[1].content,
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
if #content > 0 then table.insert(messages, { role = self.role_map[msg.role], content = content }) end
|
||||
if not provider_conf.disable_tools then
|
||||
if not provider_conf.disable_tools and not use_ReAct_prompt then
|
||||
if #tool_calls > 0 then
|
||||
local last_message = messages[#messages]
|
||||
if last_message and last_message.role == self.role_map["assistant"] and last_message.tool_calls then
|
||||
@@ -183,7 +219,10 @@ function M:finish_pending_messages(ctx, opts)
|
||||
end
|
||||
end
|
||||
|
||||
local llm_tool_names = nil
|
||||
|
||||
function M:add_text_message(ctx, text, state, opts)
|
||||
if llm_tool_names == nil then llm_tool_names = LlmTools.get_tool_names() end
|
||||
if ctx.content == nil then ctx.content = "" end
|
||||
ctx.content = ctx.content .. text
|
||||
local msg = HistoryMessage:new({
|
||||
@@ -194,7 +233,75 @@ function M:add_text_message(ctx, text, state, opts)
|
||||
uuid = ctx.content_uuid,
|
||||
})
|
||||
ctx.content_uuid = msg.uuid
|
||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||
local msgs = { msg }
|
||||
local stream_parser = XMLParser.createStreamParser()
|
||||
stream_parser:addData(ctx.content)
|
||||
local xml = stream_parser:getAllElements()
|
||||
if xml then
|
||||
local new_content_list = {}
|
||||
local xml_md_openned = false
|
||||
for idx, item in ipairs(xml) do
|
||||
if item._name == "_text" then
|
||||
local cleaned_lines = {}
|
||||
local lines = vim.split(item._text, "\n")
|
||||
for _, line in ipairs(lines) do
|
||||
if line:match("^```xml") or line:match("^```tool_code") or line:match("^```tool_use") then
|
||||
xml_md_openned = true
|
||||
elseif line:match("^```$") then
|
||||
if xml_md_openned then
|
||||
xml_md_openned = false
|
||||
else
|
||||
table.insert(cleaned_lines, line)
|
||||
end
|
||||
else
|
||||
table.insert(cleaned_lines, line)
|
||||
end
|
||||
end
|
||||
table.insert(new_content_list, table.concat(cleaned_lines, "\n"))
|
||||
goto continue
|
||||
end
|
||||
if not vim.tbl_contains(llm_tool_names, item._name) then goto continue end
|
||||
local ok, input = pcall(vim.json.decode, item._text)
|
||||
if not ok and item.children and #item.children > 0 then
|
||||
input = {}
|
||||
for _, item_ in ipairs(item.children) do
|
||||
local ok_, input_ = pcall(vim.json.decode, item_._text)
|
||||
if ok_ and input_ then
|
||||
input[item_._name] = input_
|
||||
else
|
||||
input[item_._name] = item_._text
|
||||
end
|
||||
end
|
||||
end
|
||||
if input then
|
||||
local tool_use_id = Utils.uuid()
|
||||
local msg_ = HistoryMessage:new({
|
||||
role = "assistant",
|
||||
content = {
|
||||
{
|
||||
type = "tool_use",
|
||||
name = item._name,
|
||||
id = tool_use_id,
|
||||
input = input,
|
||||
},
|
||||
},
|
||||
}, {
|
||||
state = state,
|
||||
uuid = ctx.content_uuid .. "-" .. idx,
|
||||
})
|
||||
msgs[#msgs + 1] = msg_
|
||||
ctx.tool_use_list = ctx.tool_use_list or {}
|
||||
ctx.tool_use_list[#ctx.tool_use_list + 1] = {
|
||||
id = tool_use_id,
|
||||
name = item._name,
|
||||
input_json = input,
|
||||
}
|
||||
end
|
||||
if #new_content_list > 0 then msg.message.content = table.concat(new_content_list, "\n") end
|
||||
::continue::
|
||||
end
|
||||
end
|
||||
if opts.on_messages_add then opts.on_messages_add(msgs) end
|
||||
end
|
||||
|
||||
function M:add_thinking_message(ctx, text, state, opts)
|
||||
@@ -242,7 +349,11 @@ end
|
||||
function M:parse_response(ctx, data_stream, _, opts)
|
||||
if data_stream:match('"%[DONE%]":') then
|
||||
self:finish_pending_messages(ctx, opts)
|
||||
opts.on_stop({ reason = "complete" })
|
||||
if ctx.tool_use_list and #ctx.tool_use_list > 0 then
|
||||
opts.on_stop({ reason = "tool_use" })
|
||||
else
|
||||
opts.on_stop({ reason = "complete" })
|
||||
end
|
||||
return
|
||||
end
|
||||
if data_stream == "[DONE]" then return end
|
||||
@@ -316,9 +427,13 @@ function M:parse_response(ctx, data_stream, _, opts)
|
||||
self:add_text_message(ctx, delta.content, "generating", opts)
|
||||
end
|
||||
end
|
||||
if choice.finish_reason == "stop" or choice.finish_reason == "eos_token" then
|
||||
if choice.finish_reason == "stop" or choice.finish_reason == "eos_token" or choice.finish_reason == "length" then
|
||||
self:finish_pending_messages(ctx, opts)
|
||||
opts.on_stop({ reason = "complete" })
|
||||
if ctx.tool_use_list and #ctx.tool_use_list > 0 then
|
||||
opts.on_stop({ reason = "tool_use", usage = jsn.usage })
|
||||
else
|
||||
opts.on_stop({ reason = "complete", usage = jsn.usage })
|
||||
end
|
||||
end
|
||||
if choice.finish_reason == "tool_calls" then
|
||||
self:finish_pending_messages(ctx, opts)
|
||||
@@ -372,8 +487,10 @@ function M:parse_curl_args(prompt_opts)
|
||||
|
||||
self.set_allowed_params(provider_conf, request_body)
|
||||
|
||||
local use_ReAct_prompt = provider_conf.use_ReAct_prompt == true
|
||||
|
||||
local tools = nil
|
||||
if not disable_tools and prompt_opts.tools then
|
||||
if not disable_tools and prompt_opts.tools and not use_ReAct_prompt then
|
||||
tools = {}
|
||||
for _, tool in ipairs(prompt_opts.tools) do
|
||||
table.insert(tools, self:transform_tool(tool))
|
||||
|
||||
Reference in New Issue
Block a user