feat(ui): expand tool use message (#2684)

This commit is contained in:
yetone
2025-09-04 00:17:32 +08:00
committed by GitHub
parent 20c4f44195
commit 4ac4c8ed3f
6 changed files with 243 additions and 84 deletions

View File

@@ -569,6 +569,7 @@ M._defaults = {
repomap = "<leader>aR",
},
sidebar = {
expand_tool_use = "<S-Tab>",
next_prompt = "]p",
prev_prompt = "[p",
apply_all = "A",

View File

@@ -7,11 +7,11 @@ local rshift, band = bit.rshift, bit.band
local Highlights = {
TITLE = { name = "AvanteTitle", fg = "#1e222a", bg = "#98c379" },
REVERSED_TITLE = { name = "AvanteReversedTitle", fg = "#98c379" },
REVERSED_TITLE = { name = "AvanteReversedTitle", fg = "#98c379", bg_link = "NormalFloat" },
SUBTITLE = { name = "AvanteSubtitle", fg = "#1e222a", bg = "#56b6c2" },
REVERSED_SUBTITLE = { name = "AvanteReversedSubtitle", fg = "#56b6c2" },
REVERSED_SUBTITLE = { name = "AvanteReversedSubtitle", fg = "#56b6c2", bg_link = "NormalFloat" },
THIRD_TITLE = { name = "AvanteThirdTitle", fg = "#ABB2BF", bg = "#353B45" },
REVERSED_THIRD_TITLE = { name = "AvanteReversedThirdTitle", fg = "#353B45" },
REVERSED_THIRD_TITLE = { name = "AvanteReversedThirdTitle", fg = "#353B45", bg_link = "NormalFloat" },
SUGGESTION = { name = "AvanteSuggestion", link = "Comment" },
ANNOTATION = { name = "AvanteAnnotation", link = "Comment" },
POPUP_HINT = { name = "AvantePopupHint", link = "NormalFloat" },

View File

@@ -137,39 +137,45 @@ function M.get_diff_lines(old_str, new_str, decoration, truncate)
ctxlen = vim.o.scrolloff,
})
local prev_start_a = 0
for idx, hunk in ipairs(patch) do
if truncate and line_count > 10 then
table.insert(
lines,
Line:new({
{ decoration },
{
string.format("... (Result truncated, remaining %d hunks not shown)", #patch - idx + 1),
Highlights.AVANTE_COMMENT_FG,
},
})
)
break
end
local truncated_lines = 0
for _, hunk in ipairs(patch) do
local start_a, count_a, start_b, count_b = unpack(hunk)
local no_change_lines = vim.list_slice(old_lines, prev_start_a, start_a - 1)
local last_tree_no_change_lines = vim.list_slice(no_change_lines, #no_change_lines - 3)
if #no_change_lines > 3 then table.insert(lines, Line:new({ { decoration }, { "..." } })) end
for _, line in ipairs(last_tree_no_change_lines) do
if truncate then
local last_three_no_change_lines = vim.list_slice(no_change_lines, #no_change_lines - 3)
truncated_lines = truncated_lines + #no_change_lines - #last_three_no_change_lines
if #no_change_lines > 4 then
table.insert(lines, Line:new({ { decoration }, { "...", Highlights.AVANTE_COMMENT_FG } }))
end
no_change_lines = last_three_no_change_lines
end
for idx, line in ipairs(no_change_lines) do
if truncate and line_count > 10 then
truncated_lines = truncated_lines + #no_change_lines - idx
break
end
line_count = line_count + 1
table.insert(lines, Line:new({ { decoration }, { line } }))
end
prev_start_a = start_a + count_a
if count_a > 0 then
local delete_lines = vim.list_slice(old_lines, start_a, start_a + count_a - 1)
for _, line in ipairs(delete_lines) do
for idx, line in ipairs(delete_lines) do
if truncate and line_count > 10 then
truncated_lines = truncated_lines + #delete_lines - idx
break
end
line_count = line_count + 1
table.insert(lines, Line:new({ { decoration }, { line, Highlights.TO_BE_DELETED_WITHOUT_STRIKETHROUGH } }))
end
end
if count_b > 0 then
local create_lines = vim.list_slice(new_lines, start_b, start_b + count_b - 1)
for _, line in ipairs(create_lines) do
for idx, line in ipairs(create_lines) do
if truncate and line_count > 10 then
truncated_lines = truncated_lines + #create_lines - idx
break
end
line_count = line_count + 1
table.insert(lines, Line:new({ { decoration }, { line, Highlights.INCOMING } }))
end
@@ -178,12 +184,27 @@ function M.get_diff_lines(old_str, new_str, decoration, truncate)
if prev_start_a < #old_lines then
-- Append remaining old_lines
local no_change_lines = vim.list_slice(old_lines, prev_start_a, #old_lines)
local first_tree_no_change_lines = vim.list_slice(no_change_lines, 1, 3)
for _, line in ipairs(first_tree_no_change_lines) do
local first_three_no_change_lines = vim.list_slice(no_change_lines, 1, 3)
for idx, line in ipairs(first_three_no_change_lines) do
if truncate and line_count > 10 then
truncated_lines = truncated_lines + #first_three_no_change_lines - idx
break
end
line_count = line_count + 1
table.insert(lines, Line:new({ { decoration }, { line } }))
end
if #no_change_lines > 3 then table.insert(lines, Line:new({ { decoration }, { "..." } })) end
end
if truncate and truncated_lines > 0 then
table.insert(
lines,
Line:new({
{ decoration },
{
string.format("... (Result truncated, remaining %d lines not shown)", truncated_lines),
Highlights.AVANTE_COMMENT_FG,
},
})
)
end
return lines
end
@@ -369,8 +390,9 @@ end
---@param item AvanteLLMMessageContentItem
---@param message avante.HistoryMessage
---@param messages avante.HistoryMessage[]
---@param expanded boolean | nil
---@return avante.ui.Line[]
local function tool_to_lines(item, message, messages)
local function tool_to_lines(item, message, messages, expanded)
-- local logs = message.tool_use_logs
local lines = {}
@@ -405,7 +427,7 @@ local function tool_to_lines(item, message, messages)
local lines_ = text_to_lines(table.concat(rest_input_text_lines, "\n"), decoration)
local line_count = 0
for idx, line in ipairs(lines_) do
if line_count > 3 then
if not expanded and line_count > 3 then
table.insert(
lines,
Line:new({
@@ -425,21 +447,21 @@ local function tool_to_lines(item, message, messages)
end
if item.input and type(item.input) == "table" then
if type(item.input.old_str) == "string" and type(item.input.new_str) == "string" then
local diff_lines = M.get_diff_lines(item.input.old_str, item.input.new_str, decoration, true)
local diff_lines = M.get_diff_lines(item.input.old_str, item.input.new_str, decoration, not expanded)
vim.list_extend(lines, diff_lines)
end
end
if message.acp_tool_call and message.acp_tool_call.content then
local content = message.acp_tool_call.content
if content then
local content_lines = M.get_content_lines(content, decoration, true)
local content_lines = M.get_content_lines(content, decoration, not expanded)
vim.list_extend(lines, content_lines)
end
else
if result and result.content then
local result_content = result.content
if result_content then
local content_lines = M.get_content_lines(result_content, decoration, true)
local content_lines = M.get_content_lines(result_content, decoration, not expanded)
vim.list_extend(lines, content_lines)
end
end
@@ -454,8 +476,9 @@ end
---@param item AvanteLLMMessageContentItem
---@param message avante.HistoryMessage
---@param messages avante.HistoryMessage[]
---@param expanded boolean | nil
---@return avante.ui.Line[]
local function message_content_item_to_lines(item, message, messages)
local function message_content_item_to_lines(item, message, messages, expanded)
if type(item) == "string" then
return text_to_lines(item)
elseif type(item) == "table" then
@@ -480,7 +503,7 @@ local function message_content_item_to_lines(item, message, messages)
end
end
local lines = tool_to_lines(item, message, messages)
local lines = tool_to_lines(item, message, messages, expanded)
if message.tool_use_log_lines then lines = vim.list_extend(lines, message.tool_use_log_lines) end
return lines
end
@@ -491,15 +514,16 @@ end
---Converts a message into representation suitable for UI
---@param message avante.HistoryMessage
---@param messages avante.HistoryMessage[]
---@param expanded boolean | nil
---@return avante.ui.Line[]
function M.message_to_lines(message, messages)
function M.message_to_lines(message, messages, expanded)
if message.displayed_content then return text_to_lines(message.displayed_content) end
local content = message.message.content
if type(content) == "string" then return text_to_lines(content) end
if islist(content) then
local lines = {}
for _, item in ipairs(content) do
local item_lines = message_content_item_to_lines(item, message, messages)
local item_lines = message_content_item_to_lines(item, message, messages, expanded)
lines = vim.list_extend(lines, item_lines)
end
return lines

View File

@@ -25,6 +25,7 @@ local logo = require("avante.utils.logo")
local RESULT_BUF_NAME = "AVANTE_RESULT"
local VIEW_BUFFER_UPDATED_PATTERN = "AvanteViewBufferUpdated"
local CODEBLOCK_KEYBINDING_NAMESPACE = api.nvim_create_namespace("AVANTE_CODEBLOCK_KEYBINDING")
local TOOL_MESSAGE_KEYBINDING_NAMESPACE = api.nvim_create_namespace("AVANTE_TOOL_MESSAGE_KEYBINDING")
local USER_REQUEST_BLOCK_KEYBINDING_NAMESPACE = api.nvim_create_namespace("AVANTE_USER_REQUEST_BLOCK_KEYBINDING")
local SELECTED_FILES_HINT_NAMESPACE = api.nvim_create_namespace("AVANTE_SELECTED_FILES_HINT")
local SELECTED_FILES_ICON_NAMESPACE = api.nvim_create_namespace("AVANTE_SELECTED_FILES_ICON")
@@ -78,6 +79,11 @@ Sidebar.__index = Sidebar
---@field acp_client avante.acp.ACPClient | nil
---@field acp_session_id string | nil
---@field post_render? fun(sidebar: avante.Sidebar)
---@field message_button_handlers table<string, table<string, fun(arg: any)>>
---@field expanded_message_uuids table<string, boolean>
---@field tool_message_positions table<string, [integer, integer]>
---@field skip_line_count integer | nil
---@field current_tool_use_extmark_id integer | nil
---@param id integer the tabpage id retrieved from api.nvim_get_current_tabpage()
function Sidebar:new(id)
@@ -110,6 +116,10 @@ function Sidebar:new(id)
_cached_history_lines = nil,
_history_cache_invalidated = true,
post_render = nil,
message_handlers = {},
tool_message_positions = {},
expanded_message_ids = {},
current_tool_use_extmark_id = nil,
}, Sidebar)
end
@@ -146,6 +156,10 @@ function Sidebar:reset()
self.scroll = true
self.old_result_lines = {}
self.token_count = nil
self.message_button_handlers = {}
self.tool_message_positions = {}
self.expanded_message_uuids = {}
self.current_tool_use_extmark_id = nil
end
---@class SidebarOpenOptions: AskOptions
@@ -745,6 +759,16 @@ function Sidebar:is_cursor_in_user_request_block()
return cursor_line >= block.start_line and cursor_line <= block.end_line
end
function Sidebar:get_current_tool_use_message_uuid()
local skip_line_count = self.skip_line_count or 0
local cursor_line = api.nvim_win_get_cursor(self.containers.result.winid)[1]
for message_uuid, positions in pairs(self.tool_message_positions) do
if skip_line_count + positions[1] + 1 <= cursor_line and cursor_line <= skip_line_count + positions[2] then
return message_uuid, positions
end
end
end
---@class AvanteCodeblock
---@field start_line integer 1-indexed
---@field end_line integer 1-indexed
@@ -840,6 +864,13 @@ function Sidebar:retry_user_request()
self.handle_submit(block.content)
end
function Sidebar:handle_expand_message(message_uuid, expanded)
Utils.debug("handle_expand_message", message_uuid, expanded)
self.expanded_message_uuids[message_uuid] = expanded
self._history_cache_invalidated = true
self:update_content("")
end
function Sidebar:edit_user_request()
local block = self:get_current_user_request_block()
if not block then return end
@@ -1066,6 +1097,24 @@ function Sidebar:unbind_retry_user_request_key()
end
end
function Sidebar:bind_expand_tool_use_key(message_uuid)
if self.containers.result then
local expanded = self.expanded_message_uuids[message_uuid]
vim.keymap.set(
"n",
Config.mappings.sidebar.expand_tool_use,
function() self:handle_expand_message(message_uuid, not expanded) end,
{ buffer = self.containers.result.bufnr, noremap = true, silent = true }
)
end
end
function Sidebar:unbind_expand_tool_use_key()
if self.containers.result then
pcall(vim.keymap.del, "n", Config.mappings.sidebar.expand_tool_use, { buffer = self.containers.result.bufnr })
end
end
function Sidebar:bind_edit_user_request_key()
if self.containers.result then
vim.keymap.set(
@@ -1083,6 +1132,50 @@ function Sidebar:unbind_edit_user_request_key()
end
end
function Sidebar:render_tool_use_control_buttons()
local function show_current_tool_use_control_buttons()
if self.current_tool_use_extmark_id then
api.nvim_buf_del_extmark(
self.containers.result.bufnr,
TOOL_MESSAGE_KEYBINDING_NAMESPACE,
self.current_tool_use_extmark_id
)
end
local message_uuid, positions = self:get_current_tool_use_message_uuid()
if not message_uuid then return end
local expanded = self.expanded_message_uuids[message_uuid]
local skip_line_count = self.skip_line_count or 0
self.current_tool_use_extmark_id = api.nvim_buf_set_extmark(
self.containers.result.bufnr,
TOOL_MESSAGE_KEYBINDING_NAMESPACE,
skip_line_count + positions[1] + 2,
-1,
{
virt_text = {
{
string.format(" [%s: %s] ", Config.mappings.sidebar.expand_tool_use, expanded and "Collapse" or "Expand"),
"AvanteInlineHint",
},
},
virt_text_pos = "right_align",
hl_group = "AvanteInlineHint",
priority = PRIORITY,
}
)
end
local current_tool_use_message_uuid = self:get_current_tool_use_message_uuid()
if current_tool_use_message_uuid then
show_current_tool_use_control_buttons()
self:bind_expand_tool_use_key(current_tool_use_message_uuid)
else
api.nvim_buf_clear_namespace(self.containers.result.bufnr, TOOL_MESSAGE_KEYBINDING_NAMESPACE, 0, -1)
self:unbind_expand_tool_use_key()
end
end
function Sidebar:bind_sidebar_keys(codeblocks)
---@param direction "next" | "prev"
local function jump_to_codeblock(direction)
@@ -1291,6 +1384,8 @@ function Sidebar:on_mount(opts)
group = self.augroup,
buffer = self.containers.result.bufnr,
callback = function(ev)
self:render_tool_use_control_buttons()
local in_codeblock = is_cursor_in_codeblock(codeblocks)
if in_codeblock then
@@ -1628,8 +1723,10 @@ function Sidebar:update_content(content, opts)
)
local history_lines
local tool_message_positions
if not self._cached_history_lines or self._history_cache_invalidated then
history_lines = self.get_history_lines(self.chat_history, self.show_logo)
history_lines, tool_message_positions = self:get_history_lines(self.chat_history, self.show_logo)
self.tool_message_positions = tool_message_positions
self._cached_history_lines = history_lines
self._history_cache_invalidated = false
else
@@ -1648,7 +1745,10 @@ function Sidebar:update_content(content, opts)
self:clear_state()
local skip_line_count = 0
if self.show_logo then skip_line_count = self:render_logo() end
if self.show_logo then
skip_line_count = self:render_logo()
self.skip_line_count = skip_line_count
end
local bufnr = self.containers.result.bufnr
Utils.unlock_buf(bufnr)
@@ -1673,6 +1773,7 @@ function Sidebar:update_content(content, opts)
vim.schedule(function()
self:render_state()
self:render_tool_use_control_buttons()
vim.defer_fn(function() vim.cmd("redraw") end, 10)
end)
@@ -1743,10 +1844,11 @@ end
---@param messages avante.HistoryMessage[]
---@param ctx table
---@param ignore_record_prefix boolean | nil
---@param expanded boolean | nil
---@return avante.ui.Line[]
local function _get_message_lines(message, messages, ctx, ignore_record_prefix)
local function _get_message_lines(message, messages, ctx, ignore_record_prefix, expanded)
if message.visible == false then return {} end
local lines = Render.message_to_lines(message, messages)
local lines = Render.message_to_lines(message, messages, expanded)
if message.is_user_submission and not ignore_record_prefix then
ctx.selected_filepaths = message.selected_filepaths
local text = table.concat(vim.tbl_map(function(line) return tostring(line) end, lines), "\n")
@@ -1794,10 +1896,11 @@ local _message_to_lines_lru_cache = LRUCache:new(100)
---@param messages avante.HistoryMessage[]
---@param ctx table
---@param ignore_record_prefix boolean | nil
---@param expanded boolean | nil
---@return avante.ui.Line[]
local function get_message_lines(message, messages, ctx, ignore_record_prefix)
local function get_message_lines(message, messages, ctx, ignore_record_prefix, expanded)
if message.state == "generating" or message.is_calling then
return _get_message_lines(message, messages, ctx, ignore_record_prefix)
return _get_message_lines(message, messages, ctx, ignore_record_prefix, expanded)
end
local text_len = 0
local content = message.message.content
@@ -1814,52 +1917,51 @@ local function get_message_lines(message, messages, ctx, ignore_record_prefix)
elseif type(content) == "string" then
text_len = #content
end
local cache_key = message.uuid .. ":" .. tostring(text_len)
local cache_key = message.uuid .. ":" .. tostring(text_len) .. ":" .. tostring(expanded == true)
local cached_lines = _message_to_lines_lru_cache:get(cache_key)
if cached_lines then return cached_lines end
local lines = _get_message_lines(message, messages, ctx, ignore_record_prefix)
local lines = _get_message_lines(message, messages, ctx, ignore_record_prefix, expanded)
_message_to_lines_lru_cache:set(cache_key, lines)
return lines
end
---@param history avante.ChatHistory
---@param ignore_record_prefix boolean | nil
---@return avante.ui.Line[]
function Sidebar.get_history_lines(history, ignore_record_prefix)
---@return avante.ui.Line[] history_lines
---@return table<string, [integer, integer]> tool_message_positions
function Sidebar:get_history_lines(history, ignore_record_prefix)
local history_messages = History.get_history_messages(history)
local ctx = {}
---@type avante.ui.Line[][]
local group = {}
---@type avante.ui.Line[]
local res = {}
local tool_message_positions = {}
local is_first_user_submission = true
for _, message in ipairs(history_messages) do
local lines = get_message_lines(message, history_messages, ctx, ignore_record_prefix)
local expanded = self.expanded_message_uuids[message.uuid]
local lines = get_message_lines(message, history_messages, ctx, ignore_record_prefix, expanded)
if #lines == 0 then goto continue end
if message.is_user_submission then table.insert(group, {}) end
local last_item = group[#group]
if last_item == nil then
table.insert(group, {})
last_item = group[#group]
if message.is_user_submission then
if not is_first_user_submission then
if ignore_record_prefix then
res = vim.list_extend(res, { Line:new({ { "" } }), Line:new({ { "" } }) })
else
res = vim.list_extend(res, { Line:new({ { "" } }), Line:new({ { RESP_SEPARATOR } }), Line:new({ { "" } }) })
end
end
is_first_user_submission = false
end
if message.message.role == "assistant" and not message.just_for_display and tostring(lines[1]) ~= "" then
table.insert(lines, 1, Line:new({ { "" } }))
table.insert(lines, 1, Line:new({ { "" } }))
end
last_item = vim.list_extend(last_item, lines)
group[#group] = last_item
if History.Helpers.is_tool_use_message(message) then
tool_message_positions[message.uuid] = { #res, #res + #lines }
end
res = vim.list_extend(res, lines)
::continue::
end
local res = {}
for idx, item in ipairs(group) do
if idx ~= 1 then
if ignore_record_prefix then
res = vim.list_extend(res, { Line:new({ { "" } }), Line:new({ { "" } }) })
else
res = vim.list_extend(res, { Line:new({ { "" } }), Line:new({ { RESP_SEPARATOR } }), Line:new({ { "" } }) })
end
end
res = vim.list_extend(res, item)
end
table.insert(res, Line:new({ { "" } }))
return res
return res, tool_message_positions
end
---@param message avante.HistoryMessage
@@ -2086,6 +2188,10 @@ function Sidebar:new_chat(args, cb)
self:reload_chat_history()
self.current_state = nil
self.acp_session_id = nil
self.message_button_handlers = {}
self.expanded_message_uuids = {}
self.tool_message_positions = {}
self.current_tool_use_extmark_id = nil
self:update_content("New chat", { focus = false, scroll = false, callback = function() self:focus_input() end })
--- goto first line then go to last line
vim.schedule(function()

View File

@@ -118,14 +118,16 @@ vim.g.avante_login = vim.g.avante_login
---@field is_calling boolean | nil
---@field original_content AvanteLLMMessageContent | nil
---@field acp_tool_call? avante.acp.ToolCall
---
---@field permission_options? avante.acp.PermissionOption[]
---@field is_permission_confirming? boolean
---@class AvanteLLMToolResult
---@field tool_name string
---@field tool_use_id string
---@field content string
---@field is_error? boolean
---@field is_user_declined? boolean
---
---@class AvantePromptOptions: table<[string], string>
---@field system_prompt string
---@field messages AvanteLLMMessage[]

View File

@@ -1300,29 +1300,55 @@ end
---@param skip_line_count? integer
function M.update_buffer_lines(ns_id, bufnr, old_lines, new_lines, skip_line_count)
skip_line_count = skip_line_count or 0
local diffs = get_lines_diff(old_lines, new_lines)
if #diffs == 0 then return end
for _, diff in ipairs(diffs) do
local lines = diff.content
local text_lines = vim.tbl_map(function(line) return tostring(line) end, lines)
--- rmeove newlines from text_lines
local cleaned_lines = {}
for _, line in ipairs(text_lines) do
local lines_ = vim.split(line, "\n")
cleaned_lines = vim.list_extend(cleaned_lines, lines_)
local diff_start_idx = 0
for i, line in ipairs(new_lines) do
local old_line = old_lines[i]
if not old_line or old_line ~= line then
diff_start_idx = i
break
end
end
if diff_start_idx > 0 then
local changed_lines = vim.list_slice(new_lines, diff_start_idx)
local text_lines = vim.tbl_map(function(line) return tostring(line) end, changed_lines)
vim.api.nvim_buf_set_lines(
bufnr,
skip_line_count + diff.start_line - 1,
skip_line_count + diff.end_line - 1,
skip_line_count + diff_start_idx - 1,
skip_line_count + diff_start_idx + #changed_lines,
false,
cleaned_lines
text_lines
)
for i, line in ipairs(lines) do
line:set_highlights(ns_id, bufnr, skip_line_count + diff.start_line + i - 2)
for i, line in ipairs(changed_lines) do
line:set_highlights(ns_id, bufnr, skip_line_count + diff_start_idx + i - 2)
end
vim.cmd("redraw")
end
if #old_lines > #new_lines then
vim.api.nvim_buf_set_lines(bufnr, skip_line_count + #new_lines, skip_line_count + #old_lines, false, {})
end
vim.cmd("redraw")
-- local diffs = get_lines_diff(old_lines, new_lines)
-- if #diffs == 0 then return end
-- for _, diff in ipairs(diffs) do
-- local lines = diff.content
-- local text_lines = vim.tbl_map(function(line) return tostring(line) end, lines)
-- --- remove newlines from text_lines
-- local cleaned_lines = {}
-- for _, line in ipairs(text_lines) do
-- local lines_ = vim.split(line, "\n")
-- cleaned_lines = vim.list_extend(cleaned_lines, lines_)
-- end
-- vim.api.nvim_buf_set_lines(
-- bufnr,
-- skip_line_count + diff.start_line - 1,
-- skip_line_count + diff.end_line - 1,
-- false,
-- cleaned_lines
-- )
-- for i, line in ipairs(lines) do
-- line:set_highlights(ns_id, bufnr, skip_line_count + diff.start_line + i - 2)
-- end
-- vim.cmd("redraw")
-- end
end
function M.uniform_path(path)