feat: delete tool use messages tool (#2158)

This commit is contained in:
yetone
2025-06-05 02:51:46 +08:00
committed by GitHub
parent af8d373f22
commit 801adc4692
8 changed files with 313 additions and 222 deletions

View File

@@ -147,170 +147,6 @@ function M.generate_prompts(opts)
local project_root = Utils.root.get()
Path.prompts.initialize(Path.prompts.get_templates_dir(project_root), project_root)
local tool_id_to_tool_name = {}
local tool_id_to_path = {}
local viewed_files = {}
local last_modified_files = {}
local history_messages = {}
if opts.history_messages then
for idx, message in ipairs(opts.history_messages) do
if Utils.is_tool_result_message(message) then
local tool_use_message = Utils.get_tool_use_message(message, opts.history_messages)
local is_replace_func_call, _, _, path = Utils.is_replace_func_call_message(tool_use_message)
if is_replace_func_call and path and not message.message.content[1].is_error then
local uniformed_path = Utils.uniform_path(path)
last_modified_files[uniformed_path] = idx
end
end
end
for idx, message in ipairs(opts.history_messages) do
table.insert(history_messages, message)
if Utils.is_tool_result_message(message) then
local tool_use_message = Utils.get_tool_use_message(message, opts.history_messages)
local is_replace_func_call, is_str_replace_editor_func_call, is_str_replace_based_edit_tool_func_call, path =
Utils.is_replace_func_call_message(tool_use_message)
--- For models like gpt-4o, the input parameter of replace_in_file is treated as the latest file content, so here we need to insert a fake view tool call to ensure it uses the latest file content
if is_replace_func_call and path and not message.message.content[1].is_error then
local uniformed_path = Utils.uniform_path(path)
local view_result, view_error = require("avante.llm_tools.view").func({ path = path }, nil, nil, nil)
if view_error then view_result = "Error: " .. view_error end
local get_diagnostics_tool_use_id = Utils.uuid()
local view_tool_use_id = Utils.uuid()
local view_tool_name = "view"
local view_tool_input = { path = path }
if is_str_replace_editor_func_call then
view_tool_name = "str_replace_editor"
view_tool_input = { command = "view", path = path }
end
if is_str_replace_based_edit_tool_func_call then
view_tool_name = "str_replace_based_edit_tool"
view_tool_input = { command = "view", path = path }
end
history_messages = vim.list_extend(history_messages, {
HistoryMessage:new({
role = "assistant",
content = string.format("Viewing file %s to get the latest content", path),
}, {
is_dummy = true,
}),
HistoryMessage:new({
role = "assistant",
content = {
{
type = "tool_use",
id = view_tool_use_id,
name = view_tool_name,
input = view_tool_input,
},
},
}, {
is_dummy = true,
}),
HistoryMessage:new({
role = "user",
content = {
{
type = "tool_result",
tool_use_id = view_tool_use_id,
content = view_result,
is_error = view_error ~= nil,
},
},
}, {
is_dummy = true,
}),
})
if last_modified_files[uniformed_path] == idx then
local diagnostics = Utils.lsp.get_diagnostics_from_filepath(path)
history_messages = vim.list_extend(history_messages, {
HistoryMessage:new({
role = "assistant",
content = string.format(
"The file %s has been modified, let me check if there are any errors in the changes.",
path
),
}, {
is_dummy = true,
}),
HistoryMessage:new({
role = "assistant",
content = {
{
type = "tool_use",
id = get_diagnostics_tool_use_id,
name = "get_diagnostics",
input = { path = path },
},
},
}, {
is_dummy = true,
}),
HistoryMessage:new({
role = "user",
content = {
{
type = "tool_result",
tool_use_id = get_diagnostics_tool_use_id,
content = vim.json.encode(diagnostics),
is_error = false,
},
},
}, {
is_dummy = true,
}),
})
end
end
end
end
for _, message in ipairs(history_messages) do
local content = message.message.content
if type(content) ~= "table" then goto continue end
for _, item in ipairs(content) do
if type(item) ~= "table" then goto continue1 end
if item.type ~= "tool_use" then goto continue1 end
local tool_name = item.name
if tool_name ~= "view" then goto continue1 end
local path = item.input.path
tool_id_to_tool_name[item.id] = tool_name
if path then
local uniform_path = Utils.uniform_path(path)
tool_id_to_path[item.id] = uniform_path
viewed_files[uniform_path] = item.id
end
::continue1::
end
::continue::
end
for _, message in ipairs(history_messages) do
local content = message.message.content
if type(content) == "table" then
for _, item in ipairs(content) do
if type(item) ~= "table" then goto continue end
if item.type ~= "tool_result" then goto continue end
local tool_name = tool_id_to_tool_name[item.tool_use_id]
if tool_name ~= "view" then goto continue end
if item.is_error then goto continue end
local path = tool_id_to_path[item.tool_use_id]
local latest_tool_id = viewed_files[path]
if not latest_tool_id then goto continue end
if latest_tool_id ~= item.tool_use_id then
item.content =
string.format("The file %s has been updated. Please use the latest `view` tool result!", path)
else
local view_result, view_error = require("avante.llm_tools.view").func({ path = path }, nil, nil, nil)
if view_error then view_result = "Error: " .. view_error end
item.content = view_result
item.is_error = view_error ~= nil
end
::continue::
end
end
end
end
local system_info = Utils.get_system_info()
local selected_files = opts.selected_files or {}
@@ -329,6 +165,28 @@ function M.generate_prompts(opts)
end
end
local viewed_files = {}
if opts.history_messages then
for _, message in ipairs(opts.history_messages) do
local content = message.message.content
if type(content) ~= "table" then goto continue end
for _, item in ipairs(content) do
if type(item) ~= "table" then goto continue1 end
if item.type ~= "tool_use" then goto continue1 end
local tool_name = item.name
if tool_name ~= "view" then goto continue1 end
local path = item.input.path
if path then
local uniform_path = Utils.uniform_path(path)
viewed_files[uniform_path] = item.id
end
::continue1::
end
::continue::
end
end
selected_files = vim.iter(selected_files):filter(function(file) return viewed_files[file.path] == nil end):totable()
local template_opts = {
@@ -405,26 +263,9 @@ function M.generate_prompts(opts)
vim.list_extend(pending_compaction_history_messages, opts.prompt_opts.pending_compaction_history_messages)
end
local cleaned_history_messages = history_messages
local final_history_messages = {}
if cleaned_history_messages then
for _, msg in ipairs(cleaned_history_messages) do
local tool_result_message
if Utils.is_tool_use_message(msg) then
tool_result_message = Utils.get_tool_result_message(msg, cleaned_history_messages)
if not tool_result_message then goto continue end
end
if Utils.is_tool_result_message(msg) then goto continue end
table.insert(final_history_messages, msg)
if tool_result_message then table.insert(final_history_messages, tool_result_message) end
::continue::
end
end
---@type AvanteLLMMessage[]
local messages = vim.deepcopy(context_messages)
for _, msg in ipairs(final_history_messages) do
for _, msg in ipairs(opts.history_messages or {}) do
local message = msg.message
if msg.is_user_submission then
message = vim.deepcopy(message)
@@ -824,13 +665,9 @@ function M._stream(opts)
table.insert(tool_results, tool_result)
return handle_next_tool_use(partial_tool_use_list, tool_use_index + 1, tool_results)
end
local is_replace_tool_use = Utils.is_replace_func_call_tool_use(partial_tool_use)
local is_edit_tool_use = Utils.is_edit_func_call_tool_use(partial_tool_use)
local is_attempt_completion_tool_use = partial_tool_use.name == "attempt_completion"
if
partial_tool_use.state == "generating"
and not is_replace_tool_use
and not is_attempt_completion_tool_use
then
if partial_tool_use.state == "generating" and not is_edit_tool_use and not is_attempt_completion_tool_use then
return
end
if type(partial_tool_use.input) == "table" then partial_tool_use.input.tool_use_id = partial_tool_use.id end
@@ -874,7 +711,7 @@ function M._stream(opts)
end
local partial_tool_use_list = {} ---@type AvantePartialLLMToolUse[]
local tool_result_seen = {}
local history_messages = opts.get_history_messages and opts.get_history_messages() or {}
local history_messages = opts.get_history_messages and opts.get_history_messages({ all = true }) or {}
for idx = #history_messages, 1, -1 do
local message = history_messages[idx]
local content = message.message.content

View File

@@ -58,6 +58,7 @@ M.returns = {
function M.on_render(opts)
local lines = {}
table.insert(lines, Line:new({ { "✓ Task Completed", Highlights.AVANTE_TASK_COMPLETED } }))
table.insert(lines, Line:new({ { "" } }))
local result = opts.result or ""
local text_lines = vim.split(result, "\n")
for _, text_line in ipairs(text_lines) do

View File

@@ -0,0 +1,62 @@
local Base = require("avante.llm_tools.base")
local Utils = require("avante.utils")
---@class AvanteLLMTool
local M = setmetatable({}, Base)
M.name = "delete_tool_use_messages"
M.description =
"Since many tool use messages are useless for completing subsequent tasks and may cause excessive token consumption or even prevent task completion, you need to decide whether to invoke this tool to delete the useless tool use messages."
---@type AvanteLLMToolParam
M.param = {
type = "table",
fields = {
{
name = "tool_use_id",
description = "The tool use id",
type = "string",
},
},
usage = {
tool_use_id = "The tool use id",
},
}
---@type AvanteLLMToolReturn[]
M.returns = {
{
name = "success",
description = "True if the deletion was successful, false otherwise",
type = "boolean",
},
{
name = "error",
description = "Error message",
type = "string",
optional = true,
},
}
---@type AvanteLLMToolFunc<{ tool_use_id: string }>
function M.func(opts)
local sidebar = require("avante").get()
if not sidebar then return false, "Avante sidebar not found" end
local history_messages = Utils.get_history_messages(sidebar.chat_history)
local the_deleted_message_uuids = {}
for _, msg in ipairs(history_messages) do
if Utils.is_tool_use_message(msg) then
local content = msg.message.content
if type(content) == "table" then
for _, item in ipairs(content) do
if item.id == opts.tool_use_id then table.insert(the_deleted_message_uuids, msg.uuid) end
end
end
end
end
sidebar:delete_history_messages(the_deleted_message_uuids)
return true, nil
end
return M

View File

@@ -107,7 +107,6 @@ When you're done, provide a clear and concise summary of what you found.]]):gsub
local stream_options = {
ask = true,
disable_compact_history_messages = true,
memory = memory_content,
code_lang = "unknown",
provider = Providers[Config.provider],

View File

@@ -761,6 +761,7 @@ M._tools = {
},
require("avante.llm_tools.ls"),
require("avante.llm_tools.grep"),
require("avante.llm_tools.delete_tool_use_messages"),
{
name = "read_file_toplevel_symbols",
description = "Read the top-level symbols of a file in current project scope",

View File

@@ -1955,6 +1955,15 @@ local _save_history = Utils.debounce(function(self) Path.history.save(self.code.
local save_history = vim.schedule_wrap(_save_history)
---@param uuids string[]
function Sidebar:delete_history_messages(uuids)
local history_messages = Utils.get_history_messages(self.chat_history)
for _, msg in ipairs(history_messages) do
if vim.list_contains(uuids, msg.uuid) then msg.is_deleted = true end
end
Path.history.save(self.code.bufnr, self.chat_history)
end
---@param messages avante.HistoryMessage | avante.HistoryMessage[]
function Sidebar:add_history_messages(messages)
local history_messages = Utils.get_history_messages(self.chat_history)
@@ -1978,7 +1987,6 @@ function Sidebar:add_history_messages(messages)
end
end
self.chat_history.messages = history_messages
-- 历史消息变更时,标记缓存失效
self._history_cache_invalidated = true
save_history(self)
if
@@ -1987,15 +1995,6 @@ function Sidebar:add_history_messages(messages)
and messages[1].just_for_display ~= true
and messages[1].state == "generated"
then
-- self.chat_history.title = "generating..."
-- Llm.summarize_chat_thread_title(messages[1].message.content, function(title)
-- if title then
-- self.chat_history.title = title
-- else
-- self.chat_history.title = "untitled"
-- end
-- save_history(self)
-- end)
local first_msg_text = Utils.message_to_text(messages[1], messages)
local lines_ = vim.split(first_msg_text, "\n")
if #lines_ > 0 then
@@ -2190,24 +2189,219 @@ end
function Sidebar:reload_chat_history()
if not self.code.bufnr or not api.nvim_buf_is_valid(self.code.bufnr) then return end
self.chat_history = Path.history.load(self.code.bufnr)
-- 重新加载历史时,标记缓存失效
self._history_cache_invalidated = true
end
---@param opts? {all?: boolean}
---@return avante.HistoryMessage[]
function Sidebar:get_history_messages_for_api()
local history_messages = Utils.get_history_messages(self.chat_history)
self.chat_history.messages = history_messages
function Sidebar:get_history_messages_for_api(opts)
opts = opts or {}
local history_messages0 = Utils.get_history_messages(self.chat_history)
self.chat_history.messages = history_messages0
if self.chat_history.memory then
history_messages = {}
for i = #self.chat_history.messages, 1, -1 do
local message = self.chat_history.messages[i]
if message.uuid == self.chat_history.memory.last_message_uuid then break end
table.insert(history_messages, 1, message)
history_messages0 = vim
.iter(history_messages0)
:filter(function(message) return not message.just_for_display and not message.is_compacted end)
:totable()
if opts.all then return history_messages0 end
local tool_id_to_tool_name = {}
local tool_id_to_path = {}
local tool_id_to_start_line = {}
local tool_id_to_end_line = {}
local viewed_files = {}
local last_modified_files = {}
local history_messages = {}
local failed_edit_tool_ids = {}
for idx, message in ipairs(history_messages0) do
if Utils.is_tool_result_message(message) then
local tool_use_message = Utils.get_tool_use_message(message, history_messages0)
local is_edit_func_call, _, _, path = Utils.is_edit_func_call_message(tool_use_message)
if is_edit_func_call and message.message.content[1].is_error then
failed_edit_tool_ids[message.message.content[1].tool_use_id] = true
end
if is_edit_func_call and path and not message.message.content[1].is_error then
local uniformed_path = Utils.uniform_path(path)
last_modified_files[uniformed_path] = idx
end
end
end
return vim.iter(history_messages):filter(function(message) return not message.just_for_display end):totable()
for idx, message in ipairs(history_messages0) do
if Utils.is_tool_use_message(message) and failed_edit_tool_ids[message.message.content[1].id] then
goto continue
end
table.insert(history_messages, message)
if Utils.is_tool_result_message(message) then
local tool_use_message = Utils.get_tool_use_message(message, history_messages0)
local is_edit_func_call, is_str_replace_editor_func_call, is_str_replace_based_edit_tool_func_call, path =
Utils.is_edit_func_call_message(tool_use_message)
--- For models like gpt-4o, the input parameter of replace_in_file is treated as the latest file content, so here we need to insert a fake view tool call to ensure it uses the latest file content
if is_edit_func_call and path and not message.message.content[1].is_error then
local uniformed_path = Utils.uniform_path(path)
local view_result, view_error = require("avante.llm_tools.view").func({ path = path }, nil, nil, nil)
if view_error then view_result = "Error: " .. view_error end
local get_diagnostics_tool_use_id = Utils.uuid()
local view_tool_use_id = Utils.uuid()
local view_tool_name = "view"
local view_tool_input = { path = path }
if is_str_replace_editor_func_call then
view_tool_name = "str_replace_editor"
view_tool_input = { command = "view", path = path }
end
if is_str_replace_based_edit_tool_func_call then
view_tool_name = "str_replace_based_edit_tool"
view_tool_input = { command = "view", path = path }
end
history_messages = vim.list_extend(history_messages, {
HistoryMessage:new({
role = "assistant",
content = string.format("Viewing file %s to get the latest content", path),
}, {
is_dummy = true,
}),
HistoryMessage:new({
role = "assistant",
content = {
{
type = "tool_use",
id = view_tool_use_id,
name = view_tool_name,
input = view_tool_input,
},
},
}, {
is_dummy = true,
}),
HistoryMessage:new({
role = "user",
content = {
{
type = "tool_result",
tool_use_id = view_tool_use_id,
content = view_result,
is_error = view_error ~= nil,
},
},
}, {
is_dummy = true,
}),
})
if last_modified_files[uniformed_path] == idx then
local diagnostics = Utils.lsp.get_diagnostics_from_filepath(path)
history_messages = vim.list_extend(history_messages, {
HistoryMessage:new({
role = "assistant",
content = string.format(
"The file %s has been modified, let me check if there are any errors in the changes.",
path
),
}, {
is_dummy = true,
}),
HistoryMessage:new({
role = "assistant",
content = {
{
type = "tool_use",
id = get_diagnostics_tool_use_id,
name = "get_diagnostics",
input = { path = path },
},
},
}, {
is_dummy = true,
}),
HistoryMessage:new({
role = "user",
content = {
{
type = "tool_result",
tool_use_id = get_diagnostics_tool_use_id,
content = vim.json.encode(diagnostics),
is_error = false,
},
},
}, {
is_dummy = true,
}),
})
end
end
end
::continue::
end
for _, message in ipairs(history_messages) do
local content = message.message.content
if type(content) ~= "table" then goto continue end
for _, item in ipairs(content) do
if type(item) ~= "table" then goto continue1 end
if item.type ~= "tool_use" then goto continue1 end
local tool_name = item.name
if tool_name ~= "view" then goto continue1 end
local path = item.input.path
tool_id_to_tool_name[item.id] = tool_name
if path then
local uniform_path = Utils.uniform_path(path)
tool_id_to_path[item.id] = uniform_path
tool_id_to_start_line[item.id] = item.input.start_line
tool_id_to_end_line[item.id] = item.input.end_line
viewed_files[uniform_path] = item.id
end
::continue1::
end
::continue::
end
for _, message in ipairs(history_messages) do
local content = message.message.content
if type(content) == "table" then
for _, item in ipairs(content) do
if type(item) ~= "table" then goto continue end
if item.type ~= "tool_result" then goto continue end
local tool_name = tool_id_to_tool_name[item.tool_use_id]
if tool_name ~= "view" then goto continue end
if item.is_error then goto continue end
local path = tool_id_to_path[item.tool_use_id]
local latest_tool_id = viewed_files[path]
if not latest_tool_id then goto continue end
if latest_tool_id ~= item.tool_use_id then
item.content = string.format("The file %s has been updated. Please use the latest `view` tool result!", path)
else
local start_line = tool_id_to_start_line[item.tool_use_id]
local end_line = tool_id_to_end_line[item.tool_use_id]
local view_result, view_error = require("avante.llm_tools.view").func(
{ path = path, start_line = start_line, end_line = end_line },
nil,
nil,
nil
)
if view_error then view_result = "Error: " .. view_error end
item.content = view_result
item.is_error = view_error ~= nil
end
::continue::
end
end
end
local final_history_messages = {}
for _, msg in ipairs(history_messages) do
local tool_result_message
if Utils.is_tool_use_message(msg) then
tool_result_message = Utils.get_tool_result_message(msg, history_messages)
if not tool_result_message then goto continue end
end
if Utils.is_tool_result_message(msg) then goto continue end
table.insert(final_history_messages, msg)
if tool_result_message then table.insert(final_history_messages, tool_result_message) end
::continue::
end
return final_history_messages
end
---@param request string
@@ -2292,8 +2486,6 @@ function Sidebar:get_generate_prompts_options(request, cb)
history_messages = history_messages,
code_lang = filetype,
selected_code = selected_code,
disable_compact_history_messages = true,
-- instructions = request,
tools = tools,
}
@@ -2409,8 +2601,6 @@ function Sidebar:create_input_container()
self:render_state()
end
local save_history = Utils.debounce(function() Path.history.save(self.code.bufnr, self.chat_history) end, 3000)
---@param tool_id string
---@param tool_name string
---@param log string
@@ -2434,7 +2624,7 @@ function Sidebar:create_input_container()
local content = string.format("[%s]: %s", tool_name, log)
table.insert(tool_use_logs, content)
tool_use_message.tool_use_logs = tool_use_logs
save_history()
save_history(self)
self:update_content("")
end
@@ -2502,7 +2692,7 @@ function Sidebar:create_input_container()
on_tool_log = on_tool_log,
on_messages_add = on_messages_add,
on_state_change = on_state_change,
get_history_messages = function() return self:get_history_messages_for_api() end,
get_history_messages = function(opts) return self:get_history_messages_for_api(opts) end,
session_ctx = {},
})

View File

@@ -102,6 +102,8 @@ vim.g.avante_login = vim.g.avante_login
---@field tool_use_logs string[] | nil
---@field just_for_display boolean | nil
---@field is_dummy boolean | nil
---@field is_compacted boolean | nil
---@field is_deleted boolean | nil
---
---@class AvanteLLMToolResult
---@field tool_name string
@@ -351,7 +353,6 @@ vim.g.avante_login = vim.g.avante_login
---@field update_snippets? string[]
---@field prompt_opts? AvantePromptOptions
---@field session_ctx? table
---@field disable_compact_history_messages? boolean
---
---@class AvanteLLMToolHistory
---@field tool_result? AvanteLLMToolResult
@@ -368,7 +369,7 @@ vim.g.avante_login = vim.g.avante_login
---@field on_stop AvanteLLMStopCallback
---@field on_memory_summarize? AvanteLLMMemorySummarizeCallback
---@field on_tool_log? fun(tool_id: string, tool_name: string, log: string, state: AvanteLLMToolUseState): nil
---@field get_history_messages? fun(): avante.HistoryMessage[]
---@field get_history_messages? fun(opts?: { all?: boolean }): avante.HistoryMessage[]
---@field on_messages_add? fun(messages: avante.HistoryMessage[]): nil
---@field on_state_change? fun(state: avante.GenerateState): nil
---

View File

@@ -1435,7 +1435,7 @@ function M.tool_use_to_xml(tool_use)
end
---@param tool_use AvanteLLMToolUse
function M.is_replace_func_call_tool_use(tool_use)
function M.is_edit_func_call_tool_use(tool_use)
local is_replace_func_call = false
local is_str_replace_editor_func_call = false
local is_str_replace_based_edit_tool_func_call = false
@@ -1466,7 +1466,7 @@ function M.is_replace_func_call_tool_use(tool_use)
end
---@param tool_use_message avante.HistoryMessage | nil
function M.is_replace_func_call_message(tool_use_message)
function M.is_edit_func_call_message(tool_use_message)
local is_replace_func_call = false
local is_str_replace_editor_func_call = false
local is_str_replace_based_edit_tool_func_call = false
@@ -1474,7 +1474,7 @@ function M.is_replace_func_call_message(tool_use_message)
if tool_use_message and M.is_tool_use_message(tool_use_message) then
local tool_use = tool_use_message.message.content[1]
---@cast tool_use AvanteLLMToolUse
return M.is_replace_func_call_tool_use(tool_use)
return M.is_edit_func_call_tool_use(tool_use)
end
return is_replace_func_call, is_str_replace_editor_func_call, is_str_replace_based_edit_tool_func_call, path
end