feat: better tool ui (#2668)

This commit is contained in:
yetone
2025-09-01 22:31:38 +08:00
committed by GitHub
parent b3b9327fec
commit 15d19518b1
4 changed files with 84 additions and 48 deletions

View File

@@ -224,7 +224,8 @@ function M.get_content_lines(content, decoration, truncate)
end
end
elseif content_item.type == "diff" then
table.insert(lines, Line:new({ { decoration }, { "Path: " .. content_item.path } }))
local relative_path = Utils.relative_path(content_item.path)
table.insert(lines, Line:new({ { decoration }, { "Path: " .. relative_path } }))
local lines_ = M.get_diff_lines(content_item.oldText, content_item.newText, decoration, truncate)
lines = vim.list_extend(lines, lines_)
end
@@ -233,6 +234,61 @@ function M.get_content_lines(content, decoration, truncate)
return lines
end
---@param message avante.HistoryMessage
---@return string tool_name
---@return string | nil error
function M.get_tool_display_name(message)
local content = message.message.content
if type(content) ~= "table" then return "", "expected message content to be a table" end
---@cast content AvanteLLMMessageContentItem[]
if not islist(content) then return "", "expected message content to be a list" end
local item = message.message.content[1]
local tool_name = item.name
if message.displayed_tool_name then
tool_name = message.displayed_tool_name
else
local param
if item.input and type(item.input) == "table" then
local path
if type(item.input.path) == "string" then path = item.input.path end
if type(item.input.rel_path) == "string" then path = item.input.rel_path end
if type(item.input.filepath) == "string" then path = item.input.filepath end
if type(item.input.file_path) == "string" then path = item.input.file_path end
if type(item.input.query) == "string" then param = item.input.query end
if type(item.input.pattern) == "string" then param = item.input.pattern end
if type(item.input.command) == "string" then
param = item.input.command
local pieces = vim.split(param, "\n")
if #pieces > 1 then param = pieces[1] .. "..." end
end
if path then
local relative_path = Utils.relative_path(path)
param = relative_path
end
end
if not param and message.acp_tool_call then
if message.acp_tool_call.locations then
for _, location in ipairs(message.acp_tool_call.locations) do
if location.path then
local relative_path = Utils.relative_path(location.path)
param = relative_path
break
end
end
end
end
if param then tool_name = item.name .. "(" .. param .. ")" end
end
---@cast tool_name string
return tool_name, nil
end
---Converts a tool invocation into format suitable for UI
---@param item AvanteLLMMessageContentItem
---@param message avante.HistoryMessage
@@ -242,29 +298,14 @@ local function tool_to_lines(item, message, messages)
-- local logs = message.tool_use_logs
local lines = {}
local tool_name = item.name
local tool_name, error = M.get_tool_display_name(message)
if error then
table.insert(lines, Line:new({ { "" }, { error } }))
return lines
end
local rest_input_text_lines = {}
if message.displayed_tool_name then
tool_name = message.displayed_tool_name
else
if item.input and type(item.input) == "table" then
local param
if type(item.input.path) == "string" then param = item.input.path end
if type(item.input.rel_path) == "string" then param = item.input.rel_path end
if type(item.input.filepath) == "string" then param = item.input.filepath end
if type(item.input.query) == "string" then param = item.input.query end
if type(item.input.pattern) == "string" then param = item.input.pattern end
if type(item.input.command) == "string" then
param = item.input.command
local pieces = vim.split(param, "\n")
if #pieces > 1 then param = pieces[1] .. "..." end
end
if param then tool_name = item.name .. "(" .. param .. ")" end
end
end
local result = Helpers.get_tool_result(item.id, messages)
local state
if not result then

View File

@@ -97,17 +97,17 @@
---@class avante.acp.BaseToolCallContent
---@field type "content" | "diff"
---@class avante.acp.ToolCallContentBlock : avante.acp.BaseToolCallContent
---@class avante.acp.ToolCallRegularContent : avante.acp.BaseToolCallContent
---@field type "content"
---@field content ACPContent
---@class avante.acp.ToolCallDiff : avante.acp.BaseToolCallContent
---@class avante.acp.ToolCallDiffContent : avante.acp.BaseToolCallContent
---@field type "diff"
---@field path string
---@field oldText string|nil
---@field newText string
---@alias ACPToolCallContent avante.acp.ToolCallContentBlock | avante.acp.ToolCallDiff
---@alias ACPToolCallContent avante.acp.ToolCallRegularContent | avante.acp.ToolCallDiffContent
---@class avante.acp.ToolCallLocation
---@field path string
@@ -530,6 +530,12 @@ end
---@param method string
---@param params table
function ACPClient:_handle_notification(message_id, method, params)
-- local f = io.open("/tmp/session.txt", "a")
-- if f then
-- f:write("method: " .. method .. "\n")
-- f:write(vim.inspect(params) .. "\n" .. string.rep("=", 100) .. "\n")
-- f:close()
-- end
if method == "session/update" then
self:_handle_session_update(params)
elseif method == "session/request_permission" then

View File

@@ -800,6 +800,7 @@ end
---@param opts AvanteLLMStreamOptions
function M._stream_acp(opts)
local Render = require("avante.history.render")
---@type table<string, avante.HistoryMessage>
local tool_call_messages = {}
local acp_provider = Config.acp_providers[Config.provider]
@@ -811,20 +812,13 @@ function M._stream_acp(opts)
local message = History.Message:new("assistant", {
type = "tool_use",
id = update.toolCallId,
name = update.kind .. "(" .. update.title .. ")",
name = update.kind,
input = update.rawInput or {},
})
message.acp_tool_call = update
if update.status == "pending" or update.status == "in_progress" then message.is_calling = true end
tool_call_messages[update.toolCallId] = message
if update.rawInput then
local path = update.rawInput.path or update.rawInput.file_path
if path then
local relative_path = Utils.relative_path(path)
message.displayed_tool_name = update.title .. "(" .. relative_path .. ")"
end
local pattern = update.rawInput.pattern or update.rawInput.search
if pattern then message.displayed_tool_name = update.title .. "(" .. pattern .. ")" end
local command = update.rawInput.command or update.rawInput.command_line
if command then message.displayed_tool_name = update.title .. "(" .. command .. ")" end
local description = update.rawInput.description
if description then
message.tool_use_logs = message.tool_use_logs or {}
@@ -909,19 +903,7 @@ function M._stream_acp(opts)
id = update.toolCallId,
name = "",
})
local update_content = update.content
if type(update_content) == "table" then
for _, item in ipairs(update_content) do
if item.path then
local relative_path = Utils.relative_path(item.path)
tool_call_message.displayed_tool_name = "Edit(" .. relative_path .. ")"
break
end
end
end
if not tool_call_message.displayed_tool_name then
tool_call_message.displayed_tool_name = update.toolCallId
end
tool_call_message.acp_tool_call = update
end
tool_call_message.tool_use_logs = tool_call_message.tool_use_logs or {}
tool_call_message.tool_use_log_lines = tool_call_message.tool_use_log_lines or {}
@@ -968,8 +950,14 @@ function M._stream_acp(opts)
callback(choice.id)
end
local tool_name, error = Render.get_tool_display_name(message)
if error then
Utils.error(error)
tool_name = message.message.content[1].name
end
local selector = Selector:new({
title = message.displayed_tool_name or message.message.content[1].name,
title = tool_name,
items = items,
default_item_id = default_item and default_item.name or nil,
provider = Config.selector.provider,

View File

@@ -117,6 +117,7 @@ vim.g.avante_login = vim.g.avante_login
---@field turn_id string | nil
---@field is_calling boolean | nil
---@field original_content AvanteLLMMessageContent | nil
---@field acp_tool_call? avante.acp.ToolCall
---
---@class AvanteLLMToolResult
---@field tool_name string