fix: gemini ReAct (#2132)

This commit is contained in:
yetone
2025-06-04 01:14:06 +08:00
committed by GitHub
parent fe8e34ce86
commit fc0e78e88f
7 changed files with 46 additions and 24 deletions

View File

@@ -282,7 +282,7 @@ M._defaults = {
use_ReAct_prompt = true, use_ReAct_prompt = true,
extra_request_body = { extra_request_body = {
temperature = 0.75, temperature = 0.75,
max_tokens = 8192, max_tokens = 65536,
}, },
}, },
---@type AvanteSupportedProvider ---@type AvanteSupportedProvider

View File

@@ -33,11 +33,11 @@ M.param = {
description = [[ description = [[
One or more SEARCH/REPLACE blocks following this exact format: One or more SEARCH/REPLACE blocks following this exact format:
\`\`\` \`\`\`
<<<<<<< SEARCH ------- SEARCH
[exact content to find] [exact content to find]
======= =======
[new content to replace with] [new content to replace with]
>>>>>>> REPLACE +++++++ REPLACE
\`\`\` \`\`\`
Critical rules: Critical rules:
1. SEARCH content must match the associated file section to find EXACTLY: 1. SEARCH content must match the associated file section to find EXACTLY:
@@ -85,7 +85,7 @@ M.returns = {
---@param diff string ---@param diff string
---@return string ---@return string
local function fix_diff(diff) local function fix_diff(diff)
local has_search_line = diff:match("^%s*<<<<<<<* SEARCH") ~= nil local has_search_line = diff:match("^%s*-------* SEARCH") ~= nil
if has_search_line then return diff end if has_search_line then return diff end
local fixed_diff_lines = {} local fixed_diff_lines = {}
@@ -93,10 +93,10 @@ local function fix_diff(diff)
local first_line = lines[1] local first_line = lines[1]
if first_line and first_line:match("^%s*```") then if first_line and first_line:match("^%s*```") then
table.insert(fixed_diff_lines, first_line) table.insert(fixed_diff_lines, first_line)
table.insert(fixed_diff_lines, "<<<<<<< SEARCH") table.insert(fixed_diff_lines, "------- SEARCH")
fixed_diff_lines = vim.list_extend(fixed_diff_lines, lines, 2) fixed_diff_lines = vim.list_extend(fixed_diff_lines, lines, 2)
else else
table.insert(fixed_diff_lines, "<<<<<<< SEARCH") table.insert(fixed_diff_lines, "------- SEARCH")
fixed_diff_lines = vim.list_extend(fixed_diff_lines, lines, 1) fixed_diff_lines = vim.list_extend(fixed_diff_lines, lines, 1)
end end
return table.concat(fixed_diff_lines, "\n") return table.concat(fixed_diff_lines, "\n")
@@ -125,7 +125,7 @@ function M.func(opts, on_log, on_complete, session_ctx)
local rough_diff_blocks = {} local rough_diff_blocks = {}
for _, line in ipairs(diff_lines) do for _, line in ipairs(diff_lines) do
if line:match("^%s*<<<<<<<* SEARCH") then if line:match("^%s*-------* SEARCH") then
is_searching = true is_searching = true
is_replacing = false is_replacing = false
current_search = {} current_search = {}
@@ -133,7 +133,7 @@ function M.func(opts, on_log, on_complete, session_ctx)
is_searching = false is_searching = false
is_replacing = true is_replacing = true
current_replace = {} current_replace = {}
elseif line:match("^%s*>>>>>>>* REPLACE") and is_replacing then elseif line:match("^%s*+++++++* REPLACE") and is_replacing then
is_replacing = false is_replacing = false
table.insert( table.insert(
rough_diff_blocks, rough_diff_blocks,

View File

@@ -57,8 +57,8 @@ M.returns = {
---@type AvanteLLMToolFunc<{ path: string, old_str: string, new_str: string, streaming?: boolean, tool_use_id?: string }> ---@type AvanteLLMToolFunc<{ path: string, old_str: string, new_str: string, streaming?: boolean, tool_use_id?: string }>
function M.func(opts, on_log, on_complete, session_ctx) function M.func(opts, on_log, on_complete, session_ctx)
local replace_in_file = require("avante.llm_tools.replace_in_file") local replace_in_file = require("avante.llm_tools.replace_in_file")
local diff = "<<<<<<< SEARCH\n" .. opts.old_str .. "\n=======\n" .. opts.new_str local diff = "------- SEARCH\n" .. opts.old_str .. "\n=======\n" .. opts.new_str
if not opts.streaming then diff = diff .. "\n>>>>>>> REPLACE" end if not opts.streaming then diff = diff .. "\n+++++++ REPLACE" end
local new_opts = { local new_opts = {
path = opts.path, path = opts.path,
diff = diff, diff = diff,

View File

@@ -123,6 +123,12 @@ function M:parse_messages(opts)
end end
end end
if tool_use then if tool_use then
table.insert(contents, {
role = "model",
parts = {
{ text = Utils.tool_use_to_xml(tool_use.content[1]) },
},
})
role = "user" role = "user"
table.insert(parts, { table.insert(parts, {
text = "[" text = "["
@@ -187,9 +193,11 @@ function M.prepare_request_body(provider_instance, prompt_opts, provider_conf, r
request_body.temperature = nil request_body.temperature = nil
request_body.max_tokens = nil request_body.max_tokens = nil
local use_ReAct_prompt = provider_conf.use_ReAct_prompt == true
local disable_tools = provider_conf.disable_tools or false local disable_tools = provider_conf.disable_tools or false
if not disable_tools and prompt_opts.tools then if not use_ReAct_prompt and not disable_tools and prompt_opts.tools then
local function_declarations = {} local function_declarations = {}
for _, tool in ipairs(prompt_opts.tools) do for _, tool in ipairs(prompt_opts.tools) do
table.insert(function_declarations, provider_instance:transform_to_function_declaration(tool)) table.insert(function_declarations, provider_instance:transform_to_function_declaration(tool))
@@ -229,7 +237,7 @@ function M:parse_response(ctx, data_stream, _, opts)
if json.candidates and #json.candidates > 0 then if json.candidates and #json.candidates > 0 then
local candidate = json.candidates[1] local candidate = json.candidates[1]
---@type AvanteLLMToolUse[] ---@type AvanteLLMToolUse[]
local tool_use_list = {} ctx.tool_use_list = ctx.tool_use_list or {}
-- Check if candidate.content and candidate.content.parts exist before iterating -- Check if candidate.content and candidate.content.parts exist before iterating
if candidate.content and candidate.content.parts then if candidate.content and candidate.content.parts then
@@ -245,7 +253,7 @@ function M:parse_response(ctx, data_stream, _, opts)
name = part.functionCall.name, name = part.functionCall.name,
input_json = vim.json.encode(part.functionCall.args), input_json = vim.json.encode(part.functionCall.args),
} }
table.insert(tool_use_list, tool_use) table.insert(ctx.tool_use_list, tool_use)
OpenAI:add_tool_use_message(tool_use, "generated", opts) OpenAI:add_tool_use_message(tool_use, "generated", opts)
end end
end end
@@ -262,7 +270,7 @@ function M:parse_response(ctx, data_stream, _, opts)
-- The tool_use list is added to the table in llm.lua -- The tool_use list is added to the table in llm.lua
opts.on_stop(vim.tbl_deep_extend("force", { reason = "tool_use" }, stop_details)) opts.on_stop(vim.tbl_deep_extend("force", { reason = "tool_use" }, stop_details))
elseif reason_str == "STOP" then elseif reason_str == "STOP" then
if #tool_use_list > 0 then if ctx.tool_use_list and #ctx.tool_use_list > 0 then
-- Natural stop, but tools were found in this final chunk. -- Natural stop, but tools were found in this final chunk.
opts.on_stop(vim.tbl_deep_extend("force", { reason = "tool_use" }, stop_details)) opts.on_stop(vim.tbl_deep_extend("force", { reason = "tool_use" }, stop_details))
else else

View File

@@ -226,6 +226,8 @@ function M:add_text_message(ctx, text, state, opts)
if llm_tool_names == nil then llm_tool_names = LlmTools.get_tool_names() end if llm_tool_names == nil then llm_tool_names = LlmTools.get_tool_names() end
if ctx.content == nil then ctx.content = "" end if ctx.content == nil then ctx.content = "" end
ctx.content = ctx.content .. text ctx.content = ctx.content .. text
local content = ctx.content:gsub("<tool_code>", ""):gsub("</tool_code>", "")
ctx.content = content
local msg = HistoryMessage:new({ local msg = HistoryMessage:new({
role = "assistant", role = "assistant",
content = ctx.content, content = ctx.content,

View File

@@ -1424,6 +1424,16 @@ function M.get_tool_use_message(message, messages)
return nil return nil
end end
---@param tool_use AvanteLLMToolUse
function M.tool_use_to_xml(tool_use)
local xml = string.format("<%s>\n", tool_use.name)
for k, v in pairs(tool_use.input or {}) do
xml = xml .. string.format("<%s>%s</%s>\n", k, tostring(v), k)
end
xml = xml .. "</" .. tool_use.name .. ">"
return xml
end
---@param tool_use AvanteLLMToolUse ---@param tool_use AvanteLLMToolUse
function M.is_replace_func_call_tool_use(tool_use) function M.is_replace_func_call_tool_use(tool_use)
local is_replace_func_call = false local is_replace_func_call = false

View File

@@ -26,11 +26,13 @@ Tool use is formatted using XML-style tags. The tool name is enclosed in opening
For example: For example:
<view> <attempt_completion>
<path>src/main.js</path> <result>
</view> I have completed the task...
</result>
</attempt_completion>
Always adhere to this format for the tool use to ensure proper parsing and execution. ALWAYS ADHERE TO this format for the tool use to ensure proper parsing and execution.
# Tools # Tools
@@ -107,22 +109,22 @@ Parameters:
<replace_in_file> <replace_in_file>
<path>src/components/App.tsx</path> <path>src/components/App.tsx</path>
<diff> <diff>
<<<<<<< SEARCH ------- SEARCH
import React from 'react'; import React from 'react';
======= =======
import React, { useState } from 'react'; import React, { useState } from 'react';
>>>>>>> REPLACE +++++++ REPLACE
<<<<<<< SEARCH ------- SEARCH
function handleSubmit() { function handleSubmit() {
saveData(); saveData();
setLoading(false); setLoading(false);
} }
======= =======
>>>>>>> REPLACE +++++++ REPLACE
<<<<<<< SEARCH ------- SEARCH
return ( return (
<div> <div>
======= =======
@@ -133,7 +135,7 @@ function handleSubmit() {
return ( return (
<div> <div>
>>>>>>> REPLACE +++++++ REPLACE
</diff> </diff>
</replace_in_file> </replace_in_file>