From 15d19518b1e4a517aa120a29d398b63bf7a13dbd Mon Sep 17 00:00:00 2001 From: yetone Date: Mon, 1 Sep 2025 22:31:38 +0800 Subject: [PATCH] feat: better tool ui (#2668) --- lua/avante/history/render.lua | 83 +++++++++++++++++++++++++--------- lua/avante/libs/acp_client.lua | 12 +++-- lua/avante/llm.lua | 36 +++++---------- lua/avante/types.lua | 1 + 4 files changed, 84 insertions(+), 48 deletions(-) diff --git a/lua/avante/history/render.lua b/lua/avante/history/render.lua index f4dd758..e381ce8 100644 --- a/lua/avante/history/render.lua +++ b/lua/avante/history/render.lua @@ -224,7 +224,8 @@ function M.get_content_lines(content, decoration, truncate) end end elseif content_item.type == "diff" then - table.insert(lines, Line:new({ { decoration }, { "Path: " .. content_item.path } })) + local relative_path = Utils.relative_path(content_item.path) + table.insert(lines, Line:new({ { decoration }, { "Path: " .. relative_path } })) local lines_ = M.get_diff_lines(content_item.oldText, content_item.newText, decoration, truncate) lines = vim.list_extend(lines, lines_) end @@ -233,6 +234,61 @@ function M.get_content_lines(content, decoration, truncate) return lines end +---@param message avante.HistoryMessage +---@return string tool_name +---@return string | nil error +function M.get_tool_display_name(message) + local content = message.message.content + if type(content) ~= "table" then return "", "expected message content to be a table" end + + ---@cast content AvanteLLMMessageContentItem[] + + if not islist(content) then return "", "expected message content to be a list" end + + local item = message.message.content[1] + + local tool_name = item.name + if message.displayed_tool_name then + tool_name = message.displayed_tool_name + else + local param + if item.input and type(item.input) == "table" then + local path + if type(item.input.path) == "string" then path = item.input.path end + if type(item.input.rel_path) == "string" then path = item.input.rel_path end + if type(item.input.filepath) == "string" then path = item.input.filepath end + if type(item.input.file_path) == "string" then path = item.input.file_path end + if type(item.input.query) == "string" then param = item.input.query end + if type(item.input.pattern) == "string" then param = item.input.pattern end + if type(item.input.command) == "string" then + param = item.input.command + local pieces = vim.split(param, "\n") + if #pieces > 1 then param = pieces[1] .. "..." end + end + if path then + local relative_path = Utils.relative_path(path) + param = relative_path + end + end + if not param and message.acp_tool_call then + if message.acp_tool_call.locations then + for _, location in ipairs(message.acp_tool_call.locations) do + if location.path then + local relative_path = Utils.relative_path(location.path) + param = relative_path + break + end + end + end + end + if param then tool_name = item.name .. "(" .. param .. ")" end + end + + ---@cast tool_name string + + return tool_name, nil +end + ---Converts a tool invocation into format suitable for UI ---@param item AvanteLLMMessageContentItem ---@param message avante.HistoryMessage @@ -242,29 +298,14 @@ local function tool_to_lines(item, message, messages) -- local logs = message.tool_use_logs local lines = {} - local tool_name = item.name + local tool_name, error = M.get_tool_display_name(message) + if error then + table.insert(lines, Line:new({ { "❌ " }, { error } })) + return lines + end local rest_input_text_lines = {} - if message.displayed_tool_name then - tool_name = message.displayed_tool_name - else - if item.input and type(item.input) == "table" then - local param - if type(item.input.path) == "string" then param = item.input.path end - if type(item.input.rel_path) == "string" then param = item.input.rel_path end - if type(item.input.filepath) == "string" then param = item.input.filepath end - if type(item.input.query) == "string" then param = item.input.query end - if type(item.input.pattern) == "string" then param = item.input.pattern end - if type(item.input.command) == "string" then - param = item.input.command - local pieces = vim.split(param, "\n") - if #pieces > 1 then param = pieces[1] .. "..." end - end - if param then tool_name = item.name .. "(" .. param .. ")" end - end - end - local result = Helpers.get_tool_result(item.id, messages) local state if not result then diff --git a/lua/avante/libs/acp_client.lua b/lua/avante/libs/acp_client.lua index 19c0074..c736c87 100644 --- a/lua/avante/libs/acp_client.lua +++ b/lua/avante/libs/acp_client.lua @@ -97,17 +97,17 @@ ---@class avante.acp.BaseToolCallContent ---@field type "content" | "diff" ----@class avante.acp.ToolCallContentBlock : avante.acp.BaseToolCallContent +---@class avante.acp.ToolCallRegularContent : avante.acp.BaseToolCallContent ---@field type "content" ---@field content ACPContent ----@class avante.acp.ToolCallDiff : avante.acp.BaseToolCallContent +---@class avante.acp.ToolCallDiffContent : avante.acp.BaseToolCallContent ---@field type "diff" ---@field path string ---@field oldText string|nil ---@field newText string ----@alias ACPToolCallContent avante.acp.ToolCallContentBlock | avante.acp.ToolCallDiff +---@alias ACPToolCallContent avante.acp.ToolCallRegularContent | avante.acp.ToolCallDiffContent ---@class avante.acp.ToolCallLocation ---@field path string @@ -530,6 +530,12 @@ end ---@param method string ---@param params table function ACPClient:_handle_notification(message_id, method, params) + -- local f = io.open("/tmp/session.txt", "a") + -- if f then + -- f:write("method: " .. method .. "\n") + -- f:write(vim.inspect(params) .. "\n" .. string.rep("=", 100) .. "\n") + -- f:close() + -- end if method == "session/update" then self:_handle_session_update(params) elseif method == "session/request_permission" then diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index b75f9cd..9f94320 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -800,6 +800,7 @@ end ---@param opts AvanteLLMStreamOptions function M._stream_acp(opts) + local Render = require("avante.history.render") ---@type table local tool_call_messages = {} local acp_provider = Config.acp_providers[Config.provider] @@ -811,20 +812,13 @@ function M._stream_acp(opts) local message = History.Message:new("assistant", { type = "tool_use", id = update.toolCallId, - name = update.kind .. "(" .. update.title .. ")", + name = update.kind, + input = update.rawInput or {}, }) + message.acp_tool_call = update if update.status == "pending" or update.status == "in_progress" then message.is_calling = true end tool_call_messages[update.toolCallId] = message if update.rawInput then - local path = update.rawInput.path or update.rawInput.file_path - if path then - local relative_path = Utils.relative_path(path) - message.displayed_tool_name = update.title .. "(" .. relative_path .. ")" - end - local pattern = update.rawInput.pattern or update.rawInput.search - if pattern then message.displayed_tool_name = update.title .. "(" .. pattern .. ")" end - local command = update.rawInput.command or update.rawInput.command_line - if command then message.displayed_tool_name = update.title .. "(" .. command .. ")" end local description = update.rawInput.description if description then message.tool_use_logs = message.tool_use_logs or {} @@ -909,19 +903,7 @@ function M._stream_acp(opts) id = update.toolCallId, name = "", }) - local update_content = update.content - if type(update_content) == "table" then - for _, item in ipairs(update_content) do - if item.path then - local relative_path = Utils.relative_path(item.path) - tool_call_message.displayed_tool_name = "Edit(" .. relative_path .. ")" - break - end - end - end - if not tool_call_message.displayed_tool_name then - tool_call_message.displayed_tool_name = update.toolCallId - end + tool_call_message.acp_tool_call = update end tool_call_message.tool_use_logs = tool_call_message.tool_use_logs or {} tool_call_message.tool_use_log_lines = tool_call_message.tool_use_log_lines or {} @@ -968,8 +950,14 @@ function M._stream_acp(opts) callback(choice.id) end + local tool_name, error = Render.get_tool_display_name(message) + if error then + Utils.error(error) + tool_name = message.message.content[1].name + end + local selector = Selector:new({ - title = message.displayed_tool_name or message.message.content[1].name, + title = tool_name, items = items, default_item_id = default_item and default_item.name or nil, provider = Config.selector.provider, diff --git a/lua/avante/types.lua b/lua/avante/types.lua index c1915d4..05e87f3 100644 --- a/lua/avante/types.lua +++ b/lua/avante/types.lua @@ -117,6 +117,7 @@ vim.g.avante_login = vim.g.avante_login ---@field turn_id string | nil ---@field is_calling boolean | nil ---@field original_content AvanteLLMMessageContent | nil +---@field acp_tool_call? avante.acp.ToolCall --- ---@class AvanteLLMToolResult ---@field tool_name string