fix: limit tool use count (#2294)
This commit is contained in:
@@ -5,7 +5,7 @@ local M = {}
|
|||||||
M.__index = M
|
M.__index = M
|
||||||
|
|
||||||
---@param message AvanteLLMMessage
|
---@param message AvanteLLMMessage
|
||||||
---@param opts? {is_user_submission?: boolean, visible?: boolean, displayed_content?: string, state?: avante.HistoryMessageState, uuid?: string, selected_filepaths?: string[], selected_code?: AvanteSelectedCode, just_for_display?: boolean, is_dummy?: boolean}
|
---@param opts? {is_user_submission?: boolean, visible?: boolean, displayed_content?: string, state?: avante.HistoryMessageState, uuid?: string, selected_filepaths?: string[], selected_code?: AvanteSelectedCode, just_for_display?: boolean, is_dummy?: boolean, session_id?: string}
|
||||||
---@return avante.HistoryMessage
|
---@return avante.HistoryMessage
|
||||||
function M:new(message, opts)
|
function M:new(message, opts)
|
||||||
opts = opts or {}
|
opts = opts or {}
|
||||||
@@ -23,6 +23,7 @@ function M:new(message, opts)
|
|||||||
if opts.selected_code ~= nil then obj.selected_code = opts.selected_code end
|
if opts.selected_code ~= nil then obj.selected_code = opts.selected_code end
|
||||||
if opts.just_for_display ~= nil then obj.just_for_display = opts.just_for_display end
|
if opts.just_for_display ~= nil then obj.just_for_display = opts.just_for_display end
|
||||||
if opts.is_dummy ~= nil then obj.is_dummy = opts.is_dummy end
|
if opts.is_dummy ~= nil then obj.is_dummy = opts.is_dummy end
|
||||||
|
obj.session_id = opts.session_id
|
||||||
return obj
|
return obj
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|||||||
@@ -166,6 +166,7 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
|||||||
content = content_block.text,
|
content = content_block.text,
|
||||||
}, {
|
}, {
|
||||||
state = "generating",
|
state = "generating",
|
||||||
|
session_id = ctx.session_id,
|
||||||
})
|
})
|
||||||
content_block.uuid = msg.uuid
|
content_block.uuid = msg.uuid
|
||||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||||
@@ -184,6 +185,7 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
|||||||
},
|
},
|
||||||
}, {
|
}, {
|
||||||
state = "generating",
|
state = "generating",
|
||||||
|
session_id = ctx.session_id,
|
||||||
})
|
})
|
||||||
content_block.uuid = msg.uuid
|
content_block.uuid = msg.uuid
|
||||||
opts.on_messages_add({ msg })
|
opts.on_messages_add({ msg })
|
||||||
@@ -203,6 +205,7 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
|||||||
},
|
},
|
||||||
}, {
|
}, {
|
||||||
state = "generating",
|
state = "generating",
|
||||||
|
session_id = ctx.session_id,
|
||||||
})
|
})
|
||||||
content_block.uuid = msg.uuid
|
content_block.uuid = msg.uuid
|
||||||
opts.on_messages_add({ msg })
|
opts.on_messages_add({ msg })
|
||||||
@@ -231,6 +234,7 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
|||||||
}, {
|
}, {
|
||||||
state = "generating",
|
state = "generating",
|
||||||
uuid = content_block.uuid,
|
uuid = content_block.uuid,
|
||||||
|
session_id = ctx.session_id,
|
||||||
})
|
})
|
||||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||||
elseif jsn.delta.type == "text_delta" then
|
elseif jsn.delta.type == "text_delta" then
|
||||||
@@ -242,6 +246,7 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
|||||||
}, {
|
}, {
|
||||||
state = "generating",
|
state = "generating",
|
||||||
uuid = content_block.uuid,
|
uuid = content_block.uuid,
|
||||||
|
session_id = ctx.session_id,
|
||||||
})
|
})
|
||||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||||
elseif jsn.delta.type == "signature_delta" then
|
elseif jsn.delta.type == "signature_delta" then
|
||||||
@@ -260,6 +265,7 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
|||||||
}, {
|
}, {
|
||||||
state = "generated",
|
state = "generated",
|
||||||
uuid = content_block.uuid,
|
uuid = content_block.uuid,
|
||||||
|
session_id = ctx.session_id,
|
||||||
})
|
})
|
||||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||||
end
|
end
|
||||||
@@ -278,6 +284,7 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
|||||||
}, {
|
}, {
|
||||||
state = "generated",
|
state = "generated",
|
||||||
uuid = content_block.uuid,
|
uuid = content_block.uuid,
|
||||||
|
session_id = ctx.session_id,
|
||||||
})
|
})
|
||||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||||
end
|
end
|
||||||
@@ -301,6 +308,7 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
|||||||
}, {
|
}, {
|
||||||
state = "generated",
|
state = "generated",
|
||||||
uuid = content_block.uuid,
|
uuid = content_block.uuid,
|
||||||
|
session_id = ctx.session_id,
|
||||||
})
|
})
|
||||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||||
end
|
end
|
||||||
|
|||||||
@@ -248,7 +248,7 @@ function M:parse_response(ctx, data_stream, _, opts)
|
|||||||
input_json = vim.json.encode(part.functionCall.args),
|
input_json = vim.json.encode(part.functionCall.args),
|
||||||
}
|
}
|
||||||
table.insert(ctx.tool_use_list, tool_use)
|
table.insert(ctx.tool_use_list, tool_use)
|
||||||
OpenAI:add_tool_use_message(tool_use, "generated", opts)
|
OpenAI:add_tool_use_message(ctx, tool_use, "generated", opts)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|||||||
@@ -215,7 +215,7 @@ function M:finish_pending_messages(ctx, opts)
|
|||||||
if ctx.content ~= nil and ctx.content ~= "" then self:add_text_message(ctx, "", "generated", opts) end
|
if ctx.content ~= nil and ctx.content ~= "" then self:add_text_message(ctx, "", "generated", opts) end
|
||||||
if ctx.tool_use_list then
|
if ctx.tool_use_list then
|
||||||
for _, tool_use in ipairs(ctx.tool_use_list) do
|
for _, tool_use in ipairs(ctx.tool_use_list) do
|
||||||
if tool_use.state == "generating" then self:add_tool_use_message(tool_use, "generated", opts) end
|
if tool_use.state == "generating" then self:add_tool_use_message(ctx, tool_use, "generated", opts) end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
@@ -318,6 +318,7 @@ function M:add_text_message(ctx, text, state, opts)
|
|||||||
}, {
|
}, {
|
||||||
state = state,
|
state = state,
|
||||||
uuid = msg_uuid,
|
uuid = msg_uuid,
|
||||||
|
session_id = ctx.session_id,
|
||||||
})
|
})
|
||||||
msgs[#msgs + 1] = msg_
|
msgs[#msgs + 1] = msg_
|
||||||
ctx.tool_use_list = ctx.tool_use_list or {}
|
ctx.tool_use_list = ctx.tool_use_list or {}
|
||||||
@@ -351,12 +352,13 @@ function M:add_thinking_message(ctx, text, state, opts)
|
|||||||
}, {
|
}, {
|
||||||
state = state,
|
state = state,
|
||||||
uuid = ctx.reasonging_content_uuid,
|
uuid = ctx.reasonging_content_uuid,
|
||||||
|
session_id = ctx.session_id,
|
||||||
})
|
})
|
||||||
ctx.reasonging_content_uuid = msg.uuid
|
ctx.reasonging_content_uuid = msg.uuid
|
||||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||||
end
|
end
|
||||||
|
|
||||||
function M:add_tool_use_message(tool_use, state, opts)
|
function M:add_tool_use_message(ctx, tool_use, state, opts)
|
||||||
local jsn = JsonParser.parse(tool_use.input_json)
|
local jsn = JsonParser.parse(tool_use.input_json)
|
||||||
local msg = HistoryMessage:new({
|
local msg = HistoryMessage:new({
|
||||||
role = "assistant",
|
role = "assistant",
|
||||||
@@ -371,6 +373,7 @@ function M:add_tool_use_message(tool_use, state, opts)
|
|||||||
}, {
|
}, {
|
||||||
state = state,
|
state = state,
|
||||||
uuid = tool_use.uuid,
|
uuid = tool_use.uuid,
|
||||||
|
session_id = ctx.session_id,
|
||||||
})
|
})
|
||||||
tool_use.uuid = msg.uuid
|
tool_use.uuid = msg.uuid
|
||||||
tool_use.state = state
|
tool_use.state = state
|
||||||
@@ -438,7 +441,7 @@ function M:parse_response(ctx, data_stream, _, opts)
|
|||||||
if not ctx.tool_use_list[tool_call.index + 1] then
|
if not ctx.tool_use_list[tool_call.index + 1] then
|
||||||
if tool_call.index > 0 and ctx.tool_use_list[tool_call.index] then
|
if tool_call.index > 0 and ctx.tool_use_list[tool_call.index] then
|
||||||
local prev_tool_use = ctx.tool_use_list[tool_call.index]
|
local prev_tool_use = ctx.tool_use_list[tool_call.index]
|
||||||
self:add_tool_use_message(prev_tool_use, "generated", opts)
|
self:add_tool_use_message(ctx, prev_tool_use, "generated", opts)
|
||||||
end
|
end
|
||||||
local tool_use = {
|
local tool_use = {
|
||||||
name = tool_call["function"].name,
|
name = tool_call["function"].name,
|
||||||
@@ -446,11 +449,11 @@ function M:parse_response(ctx, data_stream, _, opts)
|
|||||||
input_json = type(tool_call["function"].arguments) == "string" and tool_call["function"].arguments or "",
|
input_json = type(tool_call["function"].arguments) == "string" and tool_call["function"].arguments or "",
|
||||||
}
|
}
|
||||||
ctx.tool_use_list[tool_call.index + 1] = tool_use
|
ctx.tool_use_list[tool_call.index + 1] = tool_use
|
||||||
self:add_tool_use_message(tool_use, "generating", opts)
|
self:add_tool_use_message(ctx, tool_use, "generating", opts)
|
||||||
else
|
else
|
||||||
local tool_use = ctx.tool_use_list[tool_call.index + 1]
|
local tool_use = ctx.tool_use_list[tool_call.index + 1]
|
||||||
tool_use.input_json = tool_use.input_json .. tool_call["function"].arguments
|
tool_use.input_json = tool_use.input_json .. tool_call["function"].arguments
|
||||||
self:add_tool_use_message(tool_use, "generating", opts)
|
self:add_tool_use_message(ctx, tool_use, "generating", opts)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
elseif delta.content then
|
elseif delta.content then
|
||||||
|
|||||||
@@ -2176,6 +2176,11 @@ function Sidebar:get_history_messages_for_api(opts)
|
|||||||
|
|
||||||
if opts.all then return history_messages0 end
|
if opts.all then return history_messages0 end
|
||||||
|
|
||||||
|
history_messages0 = vim
|
||||||
|
.iter(history_messages0)
|
||||||
|
:filter(function(message) return message.state ~= "generating" end)
|
||||||
|
:totable()
|
||||||
|
|
||||||
if self.chat_history and self.chat_history.memory then
|
if self.chat_history and self.chat_history.memory then
|
||||||
local picked_messages = {}
|
local picked_messages = {}
|
||||||
for idx = #history_messages0, 1, -1 do
|
for idx = #history_messages0, 1, -1 do
|
||||||
@@ -2186,6 +2191,44 @@ function Sidebar:get_history_messages_for_api(opts)
|
|||||||
history_messages0 = picked_messages
|
history_messages0 = picked_messages
|
||||||
end
|
end
|
||||||
|
|
||||||
|
local picked_messages = {}
|
||||||
|
local max_tool_use_count = 15
|
||||||
|
local tool_use_count = 0
|
||||||
|
for idx = #history_messages0, 1, -1 do
|
||||||
|
local msg = history_messages0[idx]
|
||||||
|
if tool_use_count > max_tool_use_count then
|
||||||
|
if Utils.is_tool_result_message(msg) then
|
||||||
|
local tool_use_message = Utils.get_tool_use_message(msg, history_messages0)
|
||||||
|
if tool_use_message then
|
||||||
|
local msg_content = {}
|
||||||
|
table.insert(
|
||||||
|
msg_content,
|
||||||
|
string.format(
|
||||||
|
"Tool use %s(%s)",
|
||||||
|
tool_use_message.message.content[1].name,
|
||||||
|
vim.json.encode(tool_use_message.message.content[1].input)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
table.insert(msg_content, string.format("Result: %s", msg.message.content[1].content))
|
||||||
|
table.insert(
|
||||||
|
picked_messages,
|
||||||
|
1,
|
||||||
|
HistoryMessage:new({ role = "user", content = msg_content }, { is_dummy = true })
|
||||||
|
)
|
||||||
|
end
|
||||||
|
elseif Utils.is_tool_use_message(history_messages0[idx]) then
|
||||||
|
tool_use_count = tool_use_count + 1
|
||||||
|
goto continue
|
||||||
|
else
|
||||||
|
table.insert(picked_messages, 1, msg)
|
||||||
|
end
|
||||||
|
else
|
||||||
|
if Utils.is_tool_use_message(history_messages0[idx]) then tool_use_count = tool_use_count + 1 end
|
||||||
|
table.insert(picked_messages, 1, msg)
|
||||||
|
end
|
||||||
|
::continue::
|
||||||
|
end
|
||||||
|
|
||||||
local tool_id_to_tool_name = {}
|
local tool_id_to_tool_name = {}
|
||||||
local tool_id_to_path = {}
|
local tool_id_to_path = {}
|
||||||
local tool_id_to_start_line = {}
|
local tool_id_to_start_line = {}
|
||||||
|
|||||||
@@ -110,6 +110,7 @@ vim.g.avante_login = vim.g.avante_login
|
|||||||
---@field is_dummy boolean | nil
|
---@field is_dummy boolean | nil
|
||||||
---@field is_compacted boolean | nil
|
---@field is_compacted boolean | nil
|
||||||
---@field is_deleted boolean | nil
|
---@field is_deleted boolean | nil
|
||||||
|
---@field session_id string | nil
|
||||||
---
|
---
|
||||||
---@class AvanteLLMToolResult
|
---@class AvanteLLMToolResult
|
||||||
---@field tool_name string
|
---@field tool_name string
|
||||||
|
|||||||
Reference in New Issue
Block a user