feat: gemini support ReAct (#2125)
This commit is contained in:
@@ -279,6 +279,7 @@ M._defaults = {
|
|||||||
endpoint = "https://generativelanguage.googleapis.com/v1beta/models",
|
endpoint = "https://generativelanguage.googleapis.com/v1beta/models",
|
||||||
model = "gemini-2.0-flash",
|
model = "gemini-2.0-flash",
|
||||||
timeout = 30000, -- Timeout in milliseconds
|
timeout = 30000, -- Timeout in milliseconds
|
||||||
|
use_ReAct_prompt = true,
|
||||||
extra_request_body = {
|
extra_request_body = {
|
||||||
temperature = 0.75,
|
temperature = 0.75,
|
||||||
max_tokens = 8192,
|
max_tokens = 8192,
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ local Utils = require("avante.utils")
|
|||||||
local Providers = require("avante.providers")
|
local Providers = require("avante.providers")
|
||||||
local Clipboard = require("avante.clipboard")
|
local Clipboard = require("avante.clipboard")
|
||||||
local OpenAI = require("avante.providers").openai
|
local OpenAI = require("avante.providers").openai
|
||||||
|
local Prompts = require("avante.utils.prompts")
|
||||||
|
|
||||||
---@class AvanteProviderFunctor
|
---@class AvanteProviderFunctor
|
||||||
local M = {}
|
local M = {}
|
||||||
@@ -33,6 +34,9 @@ function M:transform_to_function_declaration(tool)
|
|||||||
end
|
end
|
||||||
|
|
||||||
function M:parse_messages(opts)
|
function M:parse_messages(opts)
|
||||||
|
local provider_conf, _ = Providers.parse_config(self)
|
||||||
|
local use_ReAct_prompt = provider_conf.use_ReAct_prompt == true
|
||||||
|
|
||||||
local contents = {}
|
local contents = {}
|
||||||
local prev_role = nil
|
local prev_role = nil
|
||||||
|
|
||||||
@@ -72,7 +76,7 @@ function M:parse_messages(opts)
|
|||||||
data = item.source.data,
|
data = item.source.data,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
elseif type(item) == "table" and item.type == "tool_use" then
|
elseif type(item) == "table" and item.type == "tool_use" and not use_ReAct_prompt then
|
||||||
tool_id_to_name[item.id] = item.name
|
tool_id_to_name[item.id] = item.name
|
||||||
role = "model"
|
role = "model"
|
||||||
table.insert(parts, {
|
table.insert(parts, {
|
||||||
@@ -81,7 +85,7 @@ function M:parse_messages(opts)
|
|||||||
args = item.input,
|
args = item.input,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
elseif type(item) == "table" and item.type == "tool_result" then
|
elseif type(item) == "table" and item.type == "tool_result" and not use_ReAct_prompt then
|
||||||
role = "function"
|
role = "function"
|
||||||
local ok, content = pcall(vim.json.decode, item.content)
|
local ok, content = pcall(vim.json.decode, item.content)
|
||||||
if not ok then content = item.content end
|
if not ok then content = item.content end
|
||||||
@@ -107,8 +111,34 @@ function M:parse_messages(opts)
|
|||||||
table.insert(parts, { text = item.data })
|
table.insert(parts, { text = item.data })
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
if not provider_conf.disable_tools and use_ReAct_prompt then
|
||||||
|
if content_items[1].type == "tool_result" then
|
||||||
|
local tool_use = nil
|
||||||
|
for _, msg_ in ipairs(opts.messages) do
|
||||||
|
if type(msg_.content) == "table" and #msg_.content > 0 then
|
||||||
|
if msg_.content[1].type == "tool_use" and msg_.content[1].id == content_items[1].tool_use_id then
|
||||||
|
tool_use = msg_
|
||||||
|
break
|
||||||
end
|
end
|
||||||
table.insert(contents, { role = M.role_map[role] or role, parts = parts })
|
end
|
||||||
|
end
|
||||||
|
if tool_use then
|
||||||
|
role = "user"
|
||||||
|
table.insert(parts, {
|
||||||
|
text = "["
|
||||||
|
.. tool_use.content[1].name
|
||||||
|
.. " for '"
|
||||||
|
.. (tool_use.content[1].input.path or tool_use.content[1].input.rel_path or "")
|
||||||
|
.. "'] Result:",
|
||||||
|
})
|
||||||
|
table.insert(parts, {
|
||||||
|
text = content_items[1].content,
|
||||||
|
})
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
if #parts > 0 then table.insert(contents, { role = M.role_map[role] or role, parts = parts }) end
|
||||||
end)
|
end)
|
||||||
|
|
||||||
if Clipboard.support_paste_image() and opts.image_paths then
|
if Clipboard.support_paste_image() and opts.image_paths then
|
||||||
@@ -124,12 +154,16 @@ function M:parse_messages(opts)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
local system_prompt = opts.system_prompt
|
||||||
|
|
||||||
|
if use_ReAct_prompt then system_prompt = Prompts.get_ReAct_system_prompt(provider_conf, opts) end
|
||||||
|
|
||||||
return {
|
return {
|
||||||
systemInstruction = {
|
systemInstruction = {
|
||||||
role = "user",
|
role = "user",
|
||||||
parts = {
|
parts = {
|
||||||
{
|
{
|
||||||
text = opts.system_prompt,
|
text = system_prompt,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -136,6 +136,18 @@ return (
|
|||||||
>>>>>>> REPLACE
|
>>>>>>> REPLACE
|
||||||
</diff>
|
</diff>
|
||||||
</replace_in_file>
|
</replace_in_file>
|
||||||
|
|
||||||
|
## Example 4: Complete current task
|
||||||
|
|
||||||
|
<attempt_completion>
|
||||||
|
<result>
|
||||||
|
I've successfully created the requested React component with the following features:
|
||||||
|
- Responsive layout
|
||||||
|
- Dark/light mode toggle
|
||||||
|
- Form validation
|
||||||
|
- API integration
|
||||||
|
</result>
|
||||||
|
</attempt_completion>
|
||||||
]]
|
]]
|
||||||
end
|
end
|
||||||
return system_prompt
|
return system_prompt
|
||||||
|
|||||||
Reference in New Issue
Block a user