fix: gemini ReAct (#2132)
This commit is contained in:
@@ -282,7 +282,7 @@ M._defaults = {
|
|||||||
use_ReAct_prompt = true,
|
use_ReAct_prompt = true,
|
||||||
extra_request_body = {
|
extra_request_body = {
|
||||||
temperature = 0.75,
|
temperature = 0.75,
|
||||||
max_tokens = 8192,
|
max_tokens = 65536,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
---@type AvanteSupportedProvider
|
---@type AvanteSupportedProvider
|
||||||
|
|||||||
@@ -33,11 +33,11 @@ M.param = {
|
|||||||
description = [[
|
description = [[
|
||||||
One or more SEARCH/REPLACE blocks following this exact format:
|
One or more SEARCH/REPLACE blocks following this exact format:
|
||||||
\`\`\`
|
\`\`\`
|
||||||
<<<<<<< SEARCH
|
------- SEARCH
|
||||||
[exact content to find]
|
[exact content to find]
|
||||||
=======
|
=======
|
||||||
[new content to replace with]
|
[new content to replace with]
|
||||||
>>>>>>> REPLACE
|
+++++++ REPLACE
|
||||||
\`\`\`
|
\`\`\`
|
||||||
Critical rules:
|
Critical rules:
|
||||||
1. SEARCH content must match the associated file section to find EXACTLY:
|
1. SEARCH content must match the associated file section to find EXACTLY:
|
||||||
@@ -85,7 +85,7 @@ M.returns = {
|
|||||||
---@param diff string
|
---@param diff string
|
||||||
---@return string
|
---@return string
|
||||||
local function fix_diff(diff)
|
local function fix_diff(diff)
|
||||||
local has_search_line = diff:match("^%s*<<<<<<<* SEARCH") ~= nil
|
local has_search_line = diff:match("^%s*-------* SEARCH") ~= nil
|
||||||
if has_search_line then return diff end
|
if has_search_line then return diff end
|
||||||
|
|
||||||
local fixed_diff_lines = {}
|
local fixed_diff_lines = {}
|
||||||
@@ -93,10 +93,10 @@ local function fix_diff(diff)
|
|||||||
local first_line = lines[1]
|
local first_line = lines[1]
|
||||||
if first_line and first_line:match("^%s*```") then
|
if first_line and first_line:match("^%s*```") then
|
||||||
table.insert(fixed_diff_lines, first_line)
|
table.insert(fixed_diff_lines, first_line)
|
||||||
table.insert(fixed_diff_lines, "<<<<<<< SEARCH")
|
table.insert(fixed_diff_lines, "------- SEARCH")
|
||||||
fixed_diff_lines = vim.list_extend(fixed_diff_lines, lines, 2)
|
fixed_diff_lines = vim.list_extend(fixed_diff_lines, lines, 2)
|
||||||
else
|
else
|
||||||
table.insert(fixed_diff_lines, "<<<<<<< SEARCH")
|
table.insert(fixed_diff_lines, "------- SEARCH")
|
||||||
fixed_diff_lines = vim.list_extend(fixed_diff_lines, lines, 1)
|
fixed_diff_lines = vim.list_extend(fixed_diff_lines, lines, 1)
|
||||||
end
|
end
|
||||||
return table.concat(fixed_diff_lines, "\n")
|
return table.concat(fixed_diff_lines, "\n")
|
||||||
@@ -125,7 +125,7 @@ function M.func(opts, on_log, on_complete, session_ctx)
|
|||||||
local rough_diff_blocks = {}
|
local rough_diff_blocks = {}
|
||||||
|
|
||||||
for _, line in ipairs(diff_lines) do
|
for _, line in ipairs(diff_lines) do
|
||||||
if line:match("^%s*<<<<<<<* SEARCH") then
|
if line:match("^%s*-------* SEARCH") then
|
||||||
is_searching = true
|
is_searching = true
|
||||||
is_replacing = false
|
is_replacing = false
|
||||||
current_search = {}
|
current_search = {}
|
||||||
@@ -133,7 +133,7 @@ function M.func(opts, on_log, on_complete, session_ctx)
|
|||||||
is_searching = false
|
is_searching = false
|
||||||
is_replacing = true
|
is_replacing = true
|
||||||
current_replace = {}
|
current_replace = {}
|
||||||
elseif line:match("^%s*>>>>>>>* REPLACE") and is_replacing then
|
elseif line:match("^%s*+++++++* REPLACE") and is_replacing then
|
||||||
is_replacing = false
|
is_replacing = false
|
||||||
table.insert(
|
table.insert(
|
||||||
rough_diff_blocks,
|
rough_diff_blocks,
|
||||||
|
|||||||
@@ -57,8 +57,8 @@ M.returns = {
|
|||||||
---@type AvanteLLMToolFunc<{ path: string, old_str: string, new_str: string, streaming?: boolean, tool_use_id?: string }>
|
---@type AvanteLLMToolFunc<{ path: string, old_str: string, new_str: string, streaming?: boolean, tool_use_id?: string }>
|
||||||
function M.func(opts, on_log, on_complete, session_ctx)
|
function M.func(opts, on_log, on_complete, session_ctx)
|
||||||
local replace_in_file = require("avante.llm_tools.replace_in_file")
|
local replace_in_file = require("avante.llm_tools.replace_in_file")
|
||||||
local diff = "<<<<<<< SEARCH\n" .. opts.old_str .. "\n=======\n" .. opts.new_str
|
local diff = "------- SEARCH\n" .. opts.old_str .. "\n=======\n" .. opts.new_str
|
||||||
if not opts.streaming then diff = diff .. "\n>>>>>>> REPLACE" end
|
if not opts.streaming then diff = diff .. "\n+++++++ REPLACE" end
|
||||||
local new_opts = {
|
local new_opts = {
|
||||||
path = opts.path,
|
path = opts.path,
|
||||||
diff = diff,
|
diff = diff,
|
||||||
|
|||||||
@@ -123,6 +123,12 @@ function M:parse_messages(opts)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
if tool_use then
|
if tool_use then
|
||||||
|
table.insert(contents, {
|
||||||
|
role = "model",
|
||||||
|
parts = {
|
||||||
|
{ text = Utils.tool_use_to_xml(tool_use.content[1]) },
|
||||||
|
},
|
||||||
|
})
|
||||||
role = "user"
|
role = "user"
|
||||||
table.insert(parts, {
|
table.insert(parts, {
|
||||||
text = "["
|
text = "["
|
||||||
@@ -187,9 +193,11 @@ function M.prepare_request_body(provider_instance, prompt_opts, provider_conf, r
|
|||||||
request_body.temperature = nil
|
request_body.temperature = nil
|
||||||
request_body.max_tokens = nil
|
request_body.max_tokens = nil
|
||||||
|
|
||||||
|
local use_ReAct_prompt = provider_conf.use_ReAct_prompt == true
|
||||||
|
|
||||||
local disable_tools = provider_conf.disable_tools or false
|
local disable_tools = provider_conf.disable_tools or false
|
||||||
|
|
||||||
if not disable_tools and prompt_opts.tools then
|
if not use_ReAct_prompt and not disable_tools and prompt_opts.tools then
|
||||||
local function_declarations = {}
|
local function_declarations = {}
|
||||||
for _, tool in ipairs(prompt_opts.tools) do
|
for _, tool in ipairs(prompt_opts.tools) do
|
||||||
table.insert(function_declarations, provider_instance:transform_to_function_declaration(tool))
|
table.insert(function_declarations, provider_instance:transform_to_function_declaration(tool))
|
||||||
@@ -229,7 +237,7 @@ function M:parse_response(ctx, data_stream, _, opts)
|
|||||||
if json.candidates and #json.candidates > 0 then
|
if json.candidates and #json.candidates > 0 then
|
||||||
local candidate = json.candidates[1]
|
local candidate = json.candidates[1]
|
||||||
---@type AvanteLLMToolUse[]
|
---@type AvanteLLMToolUse[]
|
||||||
local tool_use_list = {}
|
ctx.tool_use_list = ctx.tool_use_list or {}
|
||||||
|
|
||||||
-- Check if candidate.content and candidate.content.parts exist before iterating
|
-- Check if candidate.content and candidate.content.parts exist before iterating
|
||||||
if candidate.content and candidate.content.parts then
|
if candidate.content and candidate.content.parts then
|
||||||
@@ -245,7 +253,7 @@ function M:parse_response(ctx, data_stream, _, opts)
|
|||||||
name = part.functionCall.name,
|
name = part.functionCall.name,
|
||||||
input_json = vim.json.encode(part.functionCall.args),
|
input_json = vim.json.encode(part.functionCall.args),
|
||||||
}
|
}
|
||||||
table.insert(tool_use_list, tool_use)
|
table.insert(ctx.tool_use_list, tool_use)
|
||||||
OpenAI:add_tool_use_message(tool_use, "generated", opts)
|
OpenAI:add_tool_use_message(tool_use, "generated", opts)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
@@ -262,7 +270,7 @@ function M:parse_response(ctx, data_stream, _, opts)
|
|||||||
-- The tool_use list is added to the table in llm.lua
|
-- The tool_use list is added to the table in llm.lua
|
||||||
opts.on_stop(vim.tbl_deep_extend("force", { reason = "tool_use" }, stop_details))
|
opts.on_stop(vim.tbl_deep_extend("force", { reason = "tool_use" }, stop_details))
|
||||||
elseif reason_str == "STOP" then
|
elseif reason_str == "STOP" then
|
||||||
if #tool_use_list > 0 then
|
if ctx.tool_use_list and #ctx.tool_use_list > 0 then
|
||||||
-- Natural stop, but tools were found in this final chunk.
|
-- Natural stop, but tools were found in this final chunk.
|
||||||
opts.on_stop(vim.tbl_deep_extend("force", { reason = "tool_use" }, stop_details))
|
opts.on_stop(vim.tbl_deep_extend("force", { reason = "tool_use" }, stop_details))
|
||||||
else
|
else
|
||||||
|
|||||||
@@ -226,6 +226,8 @@ function M:add_text_message(ctx, text, state, opts)
|
|||||||
if llm_tool_names == nil then llm_tool_names = LlmTools.get_tool_names() end
|
if llm_tool_names == nil then llm_tool_names = LlmTools.get_tool_names() end
|
||||||
if ctx.content == nil then ctx.content = "" end
|
if ctx.content == nil then ctx.content = "" end
|
||||||
ctx.content = ctx.content .. text
|
ctx.content = ctx.content .. text
|
||||||
|
local content = ctx.content:gsub("<tool_code>", ""):gsub("</tool_code>", "")
|
||||||
|
ctx.content = content
|
||||||
local msg = HistoryMessage:new({
|
local msg = HistoryMessage:new({
|
||||||
role = "assistant",
|
role = "assistant",
|
||||||
content = ctx.content,
|
content = ctx.content,
|
||||||
|
|||||||
@@ -1424,6 +1424,16 @@ function M.get_tool_use_message(message, messages)
|
|||||||
return nil
|
return nil
|
||||||
end
|
end
|
||||||
|
|
||||||
|
---@param tool_use AvanteLLMToolUse
|
||||||
|
function M.tool_use_to_xml(tool_use)
|
||||||
|
local xml = string.format("<%s>\n", tool_use.name)
|
||||||
|
for k, v in pairs(tool_use.input or {}) do
|
||||||
|
xml = xml .. string.format("<%s>%s</%s>\n", k, tostring(v), k)
|
||||||
|
end
|
||||||
|
xml = xml .. "</" .. tool_use.name .. ">"
|
||||||
|
return xml
|
||||||
|
end
|
||||||
|
|
||||||
---@param tool_use AvanteLLMToolUse
|
---@param tool_use AvanteLLMToolUse
|
||||||
function M.is_replace_func_call_tool_use(tool_use)
|
function M.is_replace_func_call_tool_use(tool_use)
|
||||||
local is_replace_func_call = false
|
local is_replace_func_call = false
|
||||||
|
|||||||
@@ -26,11 +26,13 @@ Tool use is formatted using XML-style tags. The tool name is enclosed in opening
|
|||||||
|
|
||||||
For example:
|
For example:
|
||||||
|
|
||||||
<view>
|
<attempt_completion>
|
||||||
<path>src/main.js</path>
|
<result>
|
||||||
</view>
|
I have completed the task...
|
||||||
|
</result>
|
||||||
|
</attempt_completion>
|
||||||
|
|
||||||
Always adhere to this format for the tool use to ensure proper parsing and execution.
|
ALWAYS ADHERE TO this format for the tool use to ensure proper parsing and execution.
|
||||||
|
|
||||||
# Tools
|
# Tools
|
||||||
|
|
||||||
@@ -107,22 +109,22 @@ Parameters:
|
|||||||
<replace_in_file>
|
<replace_in_file>
|
||||||
<path>src/components/App.tsx</path>
|
<path>src/components/App.tsx</path>
|
||||||
<diff>
|
<diff>
|
||||||
<<<<<<< SEARCH
|
------- SEARCH
|
||||||
import React from 'react';
|
import React from 'react';
|
||||||
=======
|
=======
|
||||||
import React, { useState } from 'react';
|
import React, { useState } from 'react';
|
||||||
>>>>>>> REPLACE
|
+++++++ REPLACE
|
||||||
|
|
||||||
<<<<<<< SEARCH
|
------- SEARCH
|
||||||
function handleSubmit() {
|
function handleSubmit() {
|
||||||
saveData();
|
saveData();
|
||||||
setLoading(false);
|
setLoading(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
=======
|
=======
|
||||||
>>>>>>> REPLACE
|
+++++++ REPLACE
|
||||||
|
|
||||||
<<<<<<< SEARCH
|
------- SEARCH
|
||||||
return (
|
return (
|
||||||
<div>
|
<div>
|
||||||
=======
|
=======
|
||||||
@@ -133,7 +135,7 @@ function handleSubmit() {
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<div>
|
<div>
|
||||||
>>>>>>> REPLACE
|
+++++++ REPLACE
|
||||||
</diff>
|
</diff>
|
||||||
</replace_in_file>
|
</replace_in_file>
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user