feat(ui): expand tool use message (#2684)
This commit is contained in:
@@ -25,6 +25,7 @@ local logo = require("avante.utils.logo")
|
||||
local RESULT_BUF_NAME = "AVANTE_RESULT"
|
||||
local VIEW_BUFFER_UPDATED_PATTERN = "AvanteViewBufferUpdated"
|
||||
local CODEBLOCK_KEYBINDING_NAMESPACE = api.nvim_create_namespace("AVANTE_CODEBLOCK_KEYBINDING")
|
||||
local TOOL_MESSAGE_KEYBINDING_NAMESPACE = api.nvim_create_namespace("AVANTE_TOOL_MESSAGE_KEYBINDING")
|
||||
local USER_REQUEST_BLOCK_KEYBINDING_NAMESPACE = api.nvim_create_namespace("AVANTE_USER_REQUEST_BLOCK_KEYBINDING")
|
||||
local SELECTED_FILES_HINT_NAMESPACE = api.nvim_create_namespace("AVANTE_SELECTED_FILES_HINT")
|
||||
local SELECTED_FILES_ICON_NAMESPACE = api.nvim_create_namespace("AVANTE_SELECTED_FILES_ICON")
|
||||
@@ -78,6 +79,11 @@ Sidebar.__index = Sidebar
|
||||
---@field acp_client avante.acp.ACPClient | nil
|
||||
---@field acp_session_id string | nil
|
||||
---@field post_render? fun(sidebar: avante.Sidebar)
|
||||
---@field message_button_handlers table<string, table<string, fun(arg: any)>>
|
||||
---@field expanded_message_uuids table<string, boolean>
|
||||
---@field tool_message_positions table<string, [integer, integer]>
|
||||
---@field skip_line_count integer | nil
|
||||
---@field current_tool_use_extmark_id integer | nil
|
||||
|
||||
---@param id integer the tabpage id retrieved from api.nvim_get_current_tabpage()
|
||||
function Sidebar:new(id)
|
||||
@@ -110,6 +116,10 @@ function Sidebar:new(id)
|
||||
_cached_history_lines = nil,
|
||||
_history_cache_invalidated = true,
|
||||
post_render = nil,
|
||||
message_handlers = {},
|
||||
tool_message_positions = {},
|
||||
expanded_message_ids = {},
|
||||
current_tool_use_extmark_id = nil,
|
||||
}, Sidebar)
|
||||
end
|
||||
|
||||
@@ -146,6 +156,10 @@ function Sidebar:reset()
|
||||
self.scroll = true
|
||||
self.old_result_lines = {}
|
||||
self.token_count = nil
|
||||
self.message_button_handlers = {}
|
||||
self.tool_message_positions = {}
|
||||
self.expanded_message_uuids = {}
|
||||
self.current_tool_use_extmark_id = nil
|
||||
end
|
||||
|
||||
---@class SidebarOpenOptions: AskOptions
|
||||
@@ -745,6 +759,16 @@ function Sidebar:is_cursor_in_user_request_block()
|
||||
return cursor_line >= block.start_line and cursor_line <= block.end_line
|
||||
end
|
||||
|
||||
function Sidebar:get_current_tool_use_message_uuid()
|
||||
local skip_line_count = self.skip_line_count or 0
|
||||
local cursor_line = api.nvim_win_get_cursor(self.containers.result.winid)[1]
|
||||
for message_uuid, positions in pairs(self.tool_message_positions) do
|
||||
if skip_line_count + positions[1] + 1 <= cursor_line and cursor_line <= skip_line_count + positions[2] then
|
||||
return message_uuid, positions
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
---@class AvanteCodeblock
|
||||
---@field start_line integer 1-indexed
|
||||
---@field end_line integer 1-indexed
|
||||
@@ -840,6 +864,13 @@ function Sidebar:retry_user_request()
|
||||
self.handle_submit(block.content)
|
||||
end
|
||||
|
||||
function Sidebar:handle_expand_message(message_uuid, expanded)
|
||||
Utils.debug("handle_expand_message", message_uuid, expanded)
|
||||
self.expanded_message_uuids[message_uuid] = expanded
|
||||
self._history_cache_invalidated = true
|
||||
self:update_content("")
|
||||
end
|
||||
|
||||
function Sidebar:edit_user_request()
|
||||
local block = self:get_current_user_request_block()
|
||||
if not block then return end
|
||||
@@ -1066,6 +1097,24 @@ function Sidebar:unbind_retry_user_request_key()
|
||||
end
|
||||
end
|
||||
|
||||
function Sidebar:bind_expand_tool_use_key(message_uuid)
|
||||
if self.containers.result then
|
||||
local expanded = self.expanded_message_uuids[message_uuid]
|
||||
vim.keymap.set(
|
||||
"n",
|
||||
Config.mappings.sidebar.expand_tool_use,
|
||||
function() self:handle_expand_message(message_uuid, not expanded) end,
|
||||
{ buffer = self.containers.result.bufnr, noremap = true, silent = true }
|
||||
)
|
||||
end
|
||||
end
|
||||
|
||||
function Sidebar:unbind_expand_tool_use_key()
|
||||
if self.containers.result then
|
||||
pcall(vim.keymap.del, "n", Config.mappings.sidebar.expand_tool_use, { buffer = self.containers.result.bufnr })
|
||||
end
|
||||
end
|
||||
|
||||
function Sidebar:bind_edit_user_request_key()
|
||||
if self.containers.result then
|
||||
vim.keymap.set(
|
||||
@@ -1083,6 +1132,50 @@ function Sidebar:unbind_edit_user_request_key()
|
||||
end
|
||||
end
|
||||
|
||||
function Sidebar:render_tool_use_control_buttons()
|
||||
local function show_current_tool_use_control_buttons()
|
||||
if self.current_tool_use_extmark_id then
|
||||
api.nvim_buf_del_extmark(
|
||||
self.containers.result.bufnr,
|
||||
TOOL_MESSAGE_KEYBINDING_NAMESPACE,
|
||||
self.current_tool_use_extmark_id
|
||||
)
|
||||
end
|
||||
|
||||
local message_uuid, positions = self:get_current_tool_use_message_uuid()
|
||||
if not message_uuid then return end
|
||||
|
||||
local expanded = self.expanded_message_uuids[message_uuid]
|
||||
local skip_line_count = self.skip_line_count or 0
|
||||
|
||||
self.current_tool_use_extmark_id = api.nvim_buf_set_extmark(
|
||||
self.containers.result.bufnr,
|
||||
TOOL_MESSAGE_KEYBINDING_NAMESPACE,
|
||||
skip_line_count + positions[1] + 2,
|
||||
-1,
|
||||
{
|
||||
virt_text = {
|
||||
{
|
||||
string.format(" [%s: %s] ", Config.mappings.sidebar.expand_tool_use, expanded and "Collapse" or "Expand"),
|
||||
"AvanteInlineHint",
|
||||
},
|
||||
},
|
||||
virt_text_pos = "right_align",
|
||||
hl_group = "AvanteInlineHint",
|
||||
priority = PRIORITY,
|
||||
}
|
||||
)
|
||||
end
|
||||
local current_tool_use_message_uuid = self:get_current_tool_use_message_uuid()
|
||||
if current_tool_use_message_uuid then
|
||||
show_current_tool_use_control_buttons()
|
||||
self:bind_expand_tool_use_key(current_tool_use_message_uuid)
|
||||
else
|
||||
api.nvim_buf_clear_namespace(self.containers.result.bufnr, TOOL_MESSAGE_KEYBINDING_NAMESPACE, 0, -1)
|
||||
self:unbind_expand_tool_use_key()
|
||||
end
|
||||
end
|
||||
|
||||
function Sidebar:bind_sidebar_keys(codeblocks)
|
||||
---@param direction "next" | "prev"
|
||||
local function jump_to_codeblock(direction)
|
||||
@@ -1291,6 +1384,8 @@ function Sidebar:on_mount(opts)
|
||||
group = self.augroup,
|
||||
buffer = self.containers.result.bufnr,
|
||||
callback = function(ev)
|
||||
self:render_tool_use_control_buttons()
|
||||
|
||||
local in_codeblock = is_cursor_in_codeblock(codeblocks)
|
||||
|
||||
if in_codeblock then
|
||||
@@ -1628,8 +1723,10 @@ function Sidebar:update_content(content, opts)
|
||||
)
|
||||
|
||||
local history_lines
|
||||
local tool_message_positions
|
||||
if not self._cached_history_lines or self._history_cache_invalidated then
|
||||
history_lines = self.get_history_lines(self.chat_history, self.show_logo)
|
||||
history_lines, tool_message_positions = self:get_history_lines(self.chat_history, self.show_logo)
|
||||
self.tool_message_positions = tool_message_positions
|
||||
self._cached_history_lines = history_lines
|
||||
self._history_cache_invalidated = false
|
||||
else
|
||||
@@ -1648,7 +1745,10 @@ function Sidebar:update_content(content, opts)
|
||||
self:clear_state()
|
||||
|
||||
local skip_line_count = 0
|
||||
if self.show_logo then skip_line_count = self:render_logo() end
|
||||
if self.show_logo then
|
||||
skip_line_count = self:render_logo()
|
||||
self.skip_line_count = skip_line_count
|
||||
end
|
||||
|
||||
local bufnr = self.containers.result.bufnr
|
||||
Utils.unlock_buf(bufnr)
|
||||
@@ -1673,6 +1773,7 @@ function Sidebar:update_content(content, opts)
|
||||
|
||||
vim.schedule(function()
|
||||
self:render_state()
|
||||
self:render_tool_use_control_buttons()
|
||||
vim.defer_fn(function() vim.cmd("redraw") end, 10)
|
||||
end)
|
||||
|
||||
@@ -1743,10 +1844,11 @@ end
|
||||
---@param messages avante.HistoryMessage[]
|
||||
---@param ctx table
|
||||
---@param ignore_record_prefix boolean | nil
|
||||
---@param expanded boolean | nil
|
||||
---@return avante.ui.Line[]
|
||||
local function _get_message_lines(message, messages, ctx, ignore_record_prefix)
|
||||
local function _get_message_lines(message, messages, ctx, ignore_record_prefix, expanded)
|
||||
if message.visible == false then return {} end
|
||||
local lines = Render.message_to_lines(message, messages)
|
||||
local lines = Render.message_to_lines(message, messages, expanded)
|
||||
if message.is_user_submission and not ignore_record_prefix then
|
||||
ctx.selected_filepaths = message.selected_filepaths
|
||||
local text = table.concat(vim.tbl_map(function(line) return tostring(line) end, lines), "\n")
|
||||
@@ -1794,10 +1896,11 @@ local _message_to_lines_lru_cache = LRUCache:new(100)
|
||||
---@param messages avante.HistoryMessage[]
|
||||
---@param ctx table
|
||||
---@param ignore_record_prefix boolean | nil
|
||||
---@param expanded boolean | nil
|
||||
---@return avante.ui.Line[]
|
||||
local function get_message_lines(message, messages, ctx, ignore_record_prefix)
|
||||
local function get_message_lines(message, messages, ctx, ignore_record_prefix, expanded)
|
||||
if message.state == "generating" or message.is_calling then
|
||||
return _get_message_lines(message, messages, ctx, ignore_record_prefix)
|
||||
return _get_message_lines(message, messages, ctx, ignore_record_prefix, expanded)
|
||||
end
|
||||
local text_len = 0
|
||||
local content = message.message.content
|
||||
@@ -1814,52 +1917,51 @@ local function get_message_lines(message, messages, ctx, ignore_record_prefix)
|
||||
elseif type(content) == "string" then
|
||||
text_len = #content
|
||||
end
|
||||
local cache_key = message.uuid .. ":" .. tostring(text_len)
|
||||
local cache_key = message.uuid .. ":" .. tostring(text_len) .. ":" .. tostring(expanded == true)
|
||||
local cached_lines = _message_to_lines_lru_cache:get(cache_key)
|
||||
if cached_lines then return cached_lines end
|
||||
local lines = _get_message_lines(message, messages, ctx, ignore_record_prefix)
|
||||
local lines = _get_message_lines(message, messages, ctx, ignore_record_prefix, expanded)
|
||||
_message_to_lines_lru_cache:set(cache_key, lines)
|
||||
return lines
|
||||
end
|
||||
|
||||
---@param history avante.ChatHistory
|
||||
---@param ignore_record_prefix boolean | nil
|
||||
---@return avante.ui.Line[]
|
||||
function Sidebar.get_history_lines(history, ignore_record_prefix)
|
||||
---@return avante.ui.Line[] history_lines
|
||||
---@return table<string, [integer, integer]> tool_message_positions
|
||||
function Sidebar:get_history_lines(history, ignore_record_prefix)
|
||||
local history_messages = History.get_history_messages(history)
|
||||
local ctx = {}
|
||||
---@type avante.ui.Line[][]
|
||||
local group = {}
|
||||
---@type avante.ui.Line[]
|
||||
local res = {}
|
||||
local tool_message_positions = {}
|
||||
local is_first_user_submission = true
|
||||
for _, message in ipairs(history_messages) do
|
||||
local lines = get_message_lines(message, history_messages, ctx, ignore_record_prefix)
|
||||
local expanded = self.expanded_message_uuids[message.uuid]
|
||||
local lines = get_message_lines(message, history_messages, ctx, ignore_record_prefix, expanded)
|
||||
if #lines == 0 then goto continue end
|
||||
if message.is_user_submission then table.insert(group, {}) end
|
||||
local last_item = group[#group]
|
||||
if last_item == nil then
|
||||
table.insert(group, {})
|
||||
last_item = group[#group]
|
||||
if message.is_user_submission then
|
||||
if not is_first_user_submission then
|
||||
if ignore_record_prefix then
|
||||
res = vim.list_extend(res, { Line:new({ { "" } }), Line:new({ { "" } }) })
|
||||
else
|
||||
res = vim.list_extend(res, { Line:new({ { "" } }), Line:new({ { RESP_SEPARATOR } }), Line:new({ { "" } }) })
|
||||
end
|
||||
end
|
||||
is_first_user_submission = false
|
||||
end
|
||||
if message.message.role == "assistant" and not message.just_for_display and tostring(lines[1]) ~= "" then
|
||||
table.insert(lines, 1, Line:new({ { "" } }))
|
||||
table.insert(lines, 1, Line:new({ { "" } }))
|
||||
end
|
||||
last_item = vim.list_extend(last_item, lines)
|
||||
group[#group] = last_item
|
||||
if History.Helpers.is_tool_use_message(message) then
|
||||
tool_message_positions[message.uuid] = { #res, #res + #lines }
|
||||
end
|
||||
res = vim.list_extend(res, lines)
|
||||
::continue::
|
||||
end
|
||||
local res = {}
|
||||
for idx, item in ipairs(group) do
|
||||
if idx ~= 1 then
|
||||
if ignore_record_prefix then
|
||||
res = vim.list_extend(res, { Line:new({ { "" } }), Line:new({ { "" } }) })
|
||||
else
|
||||
res = vim.list_extend(res, { Line:new({ { "" } }), Line:new({ { RESP_SEPARATOR } }), Line:new({ { "" } }) })
|
||||
end
|
||||
end
|
||||
res = vim.list_extend(res, item)
|
||||
end
|
||||
table.insert(res, Line:new({ { "" } }))
|
||||
return res
|
||||
return res, tool_message_positions
|
||||
end
|
||||
|
||||
---@param message avante.HistoryMessage
|
||||
@@ -2086,6 +2188,10 @@ function Sidebar:new_chat(args, cb)
|
||||
self:reload_chat_history()
|
||||
self.current_state = nil
|
||||
self.acp_session_id = nil
|
||||
self.message_button_handlers = {}
|
||||
self.expanded_message_uuids = {}
|
||||
self.tool_message_positions = {}
|
||||
self.current_tool_use_extmark_id = nil
|
||||
self:update_content("New chat", { focus = false, scroll = false, callback = function() self:focus_input() end })
|
||||
--- goto first line then go to last line
|
||||
vim.schedule(function()
|
||||
|
||||
Reference in New Issue
Block a user