From 3a43621e17bf782b8c725f44979ca320ffbe661e Mon Sep 17 00:00:00 2001 From: yetone Date: Thu, 1 May 2025 19:16:15 +0800 Subject: [PATCH] refactor: tool calling ui (#1959) --- lua/avante/llm.lua | 2 +- lua/avante/sidebar.lua | 130 +++++++++++++++++++++++---- lua/avante/ui/line.lua | 5 ++ lua/avante/utils/init.lua | 185 ++++++++++++++++++++++++++++++++++---- 4 files changed, 289 insertions(+), 33 deletions(-) diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index fa2795d..615794c 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -81,7 +81,7 @@ function M.summarize_memory(prev_memory, history_messages, cb) if type(content) == "table" and content[1].type == "tool_use" then return false end return true end) - :map(function(msg) return msg.message.role .. ": " .. Utils.message_to_text(msg) end) + :map(function(msg) return msg.message.role .. ": " .. Utils.message_to_text(msg, history_messages) end) :totable() local conversation_text = table.concat(conversation_items, "\n") local user_prompt = "Here is the conversation so far:\n" diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua index 24c3369..65ff4f0 100644 --- a/lua/avante/sidebar.lua +++ b/lua/avante/sidebar.lua @@ -16,6 +16,7 @@ local RepoMap = require("avante.repo_map") local FileSelector = require("avante.file_selector") local LLMTools = require("avante.llm_tools") local HistoryMessage = require("avante.history_message") +local Line = require("avante.ui.line") local RESULT_BUF_NAME = "AVANTE_RESULT" local VIEW_BUFFER_UPDATED_PATTERN = "AvanteViewBufferUpdated" @@ -25,6 +26,7 @@ local SELECTED_FILES_HINT_NAMESPACE = api.nvim_create_namespace("AVANTE_SELECTED local SELECTED_FILES_ICON_NAMESPACE = api.nvim_create_namespace("AVANTE_SELECTED_FILES_ICON") local INPUT_HINT_NAMESPACE = api.nvim_create_namespace("AVANTE_INPUT_HINT") local STATE_NAMESPACE = api.nvim_create_namespace("AVANTE_STATE") +local RESULT_BUF_HL_NAMESPACE = api.nvim_create_namespace("AVANTE_RESULT_BUF_HL") local PRIORITY = (vim.hl or vim.highlight).priorities.user @@ -59,6 +61,7 @@ Sidebar.__index = Sidebar ---@field scroll boolean ---@field input_hint_window integer | nil ---@field ask_opts AskOptions +---@field old_result_lines avante.ui.Line[] ---@param id integer the tabpage id retrieved from api.nvim_get_current_tabpage() function Sidebar:new(id) @@ -86,6 +89,7 @@ function Sidebar:new(id) scroll = true, input_hint_window = nil, ask_opts = {}, + old_result_lines = {}, }, Sidebar) end @@ -121,6 +125,7 @@ function Sidebar:reset() self.selected_files_container = nil self.input_container = nil self.scroll = true + self.old_result_lines = {} end ---@class SidebarOpenOptions: AskOptions @@ -1510,18 +1515,24 @@ end function Sidebar:update_content(content, opts) if not self.result_container or not self.result_container.bufnr then return end opts = vim.tbl_deep_extend("force", { focus = false, scroll = self.scroll, callback = nil }, opts or {}) - local history_content = self.render_history_content(self.chat_history) - local contents = { history_content, content } - contents = vim.iter(contents):filter(function(item) return item ~= nil and item ~= "" end):totable() - content = table.concat(contents, "\n\n") + local history_lines = self.get_history_lines(self.chat_history) + if content ~= nil and content ~= "" then + table.insert(history_lines, Line:new({ { "" } })) + table.insert(history_lines, Line:new({ { content } })) + end vim.defer_fn(function() self:clear_state() local f = function() if not Utils.is_valid_container(self.result_container) then return end - local lines = vim.split(content, "\n") Utils.unlock_buf(self.result_container.bufnr) - Utils.update_buffer_content(self.result_container.bufnr, lines) + Utils.update_buffer_lines( + RESULT_BUF_HL_NAMESPACE, + self.result_container.bufnr, + self.old_result_lines, + history_lines + ) Utils.lock_buf(self.result_container.bufnr) + self.old_result_lines = history_lines api.nvim_set_option_value("filetype", "Avante", { buf = self.result_container.bufnr }) vim.schedule(function() vim.cmd("redraw") end) if opts.focus and not self:is_focused_on_result() then @@ -1593,11 +1604,96 @@ function Sidebar:get_layout() end ---@param message avante.HistoryMessage +---@param messages avante.HistoryMessage[] +---@param ctx table +---@return avante.ui.Line[] +local function get_message_lines(message, messages, ctx) + if message.visible == false then return {} end + local lines = Utils.message_to_lines(message, messages) + if message.is_user_submission then + ctx.selected_filepaths = message.selected_filepaths + local text = table.concat(vim.tbl_map(function(line) return tostring(line) end, lines), "\n") + local prefix = render_chat_record_prefix( + message.timestamp, + message.provider, + message.model, + text, + message.selected_filepaths, + message.selected_code + ) + local res = {} + for _, line_ in ipairs(vim.split(prefix, "\n")) do + table.insert(res, Line:new({ { line_ } })) + end + return res + end + if message.message.role == "user" then + local res = {} + for _, line_ in ipairs(lines) do + local sections = { { "> " } } + sections = vim.list_extend(sections, line_.sections) + table.insert(res, Line:new(sections)) + end + return res + end + if message.message.role == "assistant" then + local content = message.message.content + if type(content) == "table" and content[1].type == "tool_use" then return lines end + local text = table.concat(vim.tbl_map(function(line) return tostring(line) end, lines), "\n") + local transformed = transform_result_content(text, ctx.prev_filepath) + ctx.prev_filepath = transformed.current_filepath + local displayed_content = generate_display_content(transformed) + local res = {} + for _, line_ in ipairs(vim.split(displayed_content, "\n")) do + table.insert(res, Line:new({ { line_ } })) + end + return res + end + return lines +end + +---@param history avante.ChatHistory +---@return avante.ui.Line[] +function Sidebar.get_history_lines(history) + local history_messages = Utils.get_history_messages(history) + local ctx = {} + ---@type avante.ui.Line[][] + local group = {} + for _, message in ipairs(history_messages) do + local lines = get_message_lines(message, history_messages, ctx) + if #lines == 0 then goto continue end + if message.is_user_submission then table.insert(group, {}) end + local last_item = group[#group] + if last_item == nil then + table.insert(group, {}) + last_item = group[#group] + end + if message.message.role == "assistant" and not message.just_for_display and tostring(lines[1]) ~= "" then + table.insert(lines, 1, Line:new({ { "" } })) + table.insert(lines, 1, Line:new({ { "" } })) + end + last_item = vim.list_extend(last_item, lines) + group[#group] = last_item + ::continue:: + end + local res = {} + for idx, item in ipairs(group) do + if idx ~= 1 and idx ~= #group then + res = vim.list_extend(res, { Line:new({ { "" } }), Line:new({ { RESP_SEPARATOR } }), Line:new({ { "" } }) }) + end + res = vim.list_extend(res, item) + end + table.insert(res, Line:new({ { "" } })) + return res +end + +---@param message avante.HistoryMessage +---@param messages avante.HistoryMessage[] ---@param ctx table ---@return string | nil -local function render_message(message, ctx) +local function render_message(message, messages, ctx) if message.visible == false then return nil end - local text = Utils.message_to_text(message) + local text = Utils.message_to_text(message, messages) if text == "" then return nil end if message.is_user_submission then ctx.selected_filepaths = message.selected_filepaths @@ -1633,7 +1729,7 @@ function Sidebar.render_history_content(history) local ctx = {} local group = {} for _, message in ipairs(history_messages) do - local text = render_message(message, ctx) + local text = render_message(message, history_messages, ctx) if text == nil then goto continue end if message.is_user_submission then table.insert(group, {}) end local last_item = group[#group] @@ -1695,6 +1791,7 @@ function Sidebar:get_content_between_separators() end function Sidebar:clear_history(args, cb) + self.current_state = nil local chat_history = Path.history.load(self.code.bufnr) if next(chat_history) ~= nil then chat_history.messages = {} @@ -1763,6 +1860,7 @@ function Sidebar:new_chat(args, cb) local history = Path.history.new(self.code.bufnr) Path.history.save(self.code.bufnr, history) self:reload_chat_history() + self.current_state = nil self:update_content("New chat", { focus = false, scroll = false, callback = function() self:focus_input() end }) if cb then cb(args) end end @@ -2144,10 +2242,10 @@ function Sidebar:create_input_container() end end - local model = Config.has_provider(Config.provider) and Config.get_provider_config(Config.provider).model - or "default" - - local timestamp = Utils.get_timestamp() + -- local model = Config.has_provider(Config.provider) and Config.get_provider_config(Config.provider).model + -- or "default" + -- + -- local timestamp = Utils.get_timestamp() local selected_filepaths = self.file_selector:get_selected_filepaths() @@ -2161,14 +2259,14 @@ function Sidebar:create_input_container() } end - local content_prefix = - render_chat_record_prefix(timestamp, Config.provider, model, request, selected_filepaths, selected_code) + -- local content_prefix = + -- render_chat_record_prefix(timestamp, Config.provider, model, request, selected_filepaths, selected_code) --- HACK: we need to set focus to true and scroll to false to --- prevent the cursor from jumping to the bottom of the --- buffer at the beginning self:update_content("", { focus = true, scroll = false }) - self:update_content(content_prefix) + -- self:update_content(content_prefix) ---stop scroll when user presses j/k keys local function on_j() diff --git a/lua/avante/ui/line.lua b/lua/avante/ui/line.lua index 33bef06..4a06c26 100644 --- a/lua/avante/ui/line.lua +++ b/lua/avante/ui/line.lua @@ -54,4 +54,9 @@ function M:__tostring() return table.concat(content, "") end +function M:__eq(other) + if not other or type(other) ~= "table" or not other.sections then return false end + return vim.deep_equal(self.sections, other.sections) +end + return M diff --git a/lua/avante/utils/init.lua b/lua/avante/utils/init.lua index 4c08c49..7df1882 100644 --- a/lua/avante/utils/init.lua +++ b/lua/avante/utils/init.lua @@ -966,6 +966,36 @@ function M.get_or_create_buffer_with_filepath(filepath) return buf end +---@param old_lines avante.ui.Line[] +---@param new_lines avante.ui.Line[] +---@return { start_line: integer, end_line: integer, content: avante.ui.Line[] }[] +local function get_lines_diff(old_lines, new_lines) + local diffs = {} + local prev_diff_idx = nil + for i, line in ipairs(new_lines) do + if line ~= old_lines[i] then + if prev_diff_idx == nil then prev_diff_idx = i end + else + if prev_diff_idx ~= nil then + local content = vim.list_slice(new_lines, prev_diff_idx, i - 1) + table.insert(diffs, { start_line = prev_diff_idx, end_line = i, content = content }) + prev_diff_idx = nil + end + end + end + if prev_diff_idx ~= nil then + table.insert( + diffs, + { start_line = prev_diff_idx, end_line = #new_lines + 1, content = vim.list_slice(new_lines, prev_diff_idx) } + ) + end + if #new_lines < #old_lines then + table.insert(diffs, { start_line = #new_lines + 1, end_line = #old_lines + 1, content = {} }) + end + table.sort(diffs, function(a, b) return a.start_line > b.start_line end) + return diffs +end + ---@param bufnr integer ---@param new_lines string[] ---@return { start_line: integer, end_line: integer, content: string[] }[] @@ -1008,6 +1038,24 @@ function M.update_buffer_content(bufnr, new_lines) end end +---@param ns_id number +---@param bufnr integer +---@param old_lines avante.ui.Line[] +---@param new_lines avante.ui.Line[] +function M.update_buffer_lines(ns_id, bufnr, old_lines, new_lines) + local diffs = get_lines_diff(old_lines, new_lines) + if #diffs == 0 then return end + for _, diff in ipairs(diffs) do + local lines = diff.content + -- M.debug("lines", lines) + local text_lines = vim.tbl_map(function(line) return tostring(line) end, lines) + vim.api.nvim_buf_set_lines(bufnr, diff.start_line - 1, diff.end_line - 1, false, text_lines) + for i, line in ipairs(lines) do + line:set_highlights(ns_id, bufnr, diff.start_line + i - 2) + end + end +end + local severity = { [1] = "ERROR", [2] = "WARNING", @@ -1365,35 +1413,140 @@ function M.uuid() end) end ----@param item AvanteLLMMessageContentItem ---@param message avante.HistoryMessage ----@return string -function M.message_content_item_to_text(item, message) - if type(item) == "string" then return item end - if type(item) == "table" then - if item.type == "text" then return item.text end - if item.type == "image" then return "![image](" .. item.source.media_type .. ": " .. item.source.data .. ")" end - if item.type == "tool_use" then - local pieces = {} - table.insert(pieces, string.format("[%s]: calling", item.name)) - for _, log in ipairs(message.tool_use_logs or {}) do - table.insert(pieces, log) +---@param messages avante.HistoryMessage[] +---@return avante.HistoryMessage | nil +function M.get_tool_result_message(message, messages) + local content = message.message.content + if type(content) == "string" then return nil end + if vim.islist(content) then + local tool_id = nil + for _, item in ipairs(content) do + if item.type == "tool_use" then + tool_id = item.id + break + end + end + if not tool_id then return nil end + for _, message_ in ipairs(messages) do + local content_ = message_.message.content + if type(content_) == "table" and content_[1].type == "tool_result" and content_[1].tool_use_id == tool_id then + return message_ end - return table.concat(pieces, "\n") end end - return "" + return nil +end + +---@param text string +---@param hl string | nil +---@return avante.ui.Line[] +function M.text_to_lines(text, hl) + local Line = require("avante.ui.line") + local text_lines = vim.split(text, "\n") + local lines = {} + for _, text_line in ipairs(text_lines) do + local piece = { text_line } + if hl then table.insert(piece, hl) end + table.insert(lines, Line:new({ piece })) + end + return lines +end + +---@param item AvanteLLMMessageContentItem +---@param message avante.HistoryMessage +---@param messages avante.HistoryMessage[] +---@return avante.ui.Line[] +function M.message_content_item_to_lines(item, message, messages) + local Line = require("avante.ui.line") + if type(item) == "string" then return M.text_to_lines(item) end + if type(item) == "table" then + if item.type == "text" then return M.text_to_lines(item.text) end + if item.type == "image" then + return { Line:new({ { "![image](" .. item.source.media_type .. ": " .. item.source.data .. ")" } }) } + end + if item.type == "tool_use" then + local lines = {} + local state = "generating" + local hl = "AvanteStateSpinnerToolCalling" + if message.state == "generated" then + local tool_result_message = M.get_tool_result_message(message, messages) + if tool_result_message then + local tool_result = tool_result_message.message.content[1] + if tool_result.is_error then + state = "failed" + hl = "AvanteStateSpinnerFailed" + else + state = "succeeded" + hl = "AvanteStateSpinnerSucceeded" + end + end + end + table.insert( + lines, + Line:new({ { "╭─" }, { " " }, { string.format(" %s ", item.name), hl }, { string.format(" %s", state) } }) + ) + for idx, log in ipairs(message.tool_use_logs or {}) do + local log_ = M.trim(log, { prefix = string.format("[%s]: ", item.name) }) + local lines_ = vim.split(log_, "\n") + if idx ~= #(message.tool_use_logs or {}) then + for _, line_ in ipairs(lines_) do + table.insert(lines, Line:new({ { "│" }, { string.format(" %s", line_) } })) + end + else + for idx_, line_ in ipairs(lines_) do + if idx_ ~= #lines_ then + table.insert(lines, Line:new({ { "│" }, { string.format(" %s", line_) } })) + else + table.insert(lines, Line:new({ { "╰─" }, { string.format(" %s", line_) } })) + end + end + end + end + return lines + end + end + return {} end ---@param message avante.HistoryMessage +---@param messages avante.HistoryMessage[] +---@return avante.ui.Line[] +function M.message_to_lines(message, messages) + local Line = require("avante.ui.line") + local content = message.message.content + if type(content) == "string" then return { Line:new({ { content } }) } end + if vim.islist(content) then + local lines = {} + for _, item in ipairs(content) do + local lines_ = M.message_content_item_to_lines(item, message, messages) + lines = vim.list_extend(lines, lines_) + end + return lines + end + return {} +end + +---@param item AvanteLLMMessageContentItem +---@param message avante.HistoryMessage +---@param messages avante.HistoryMessage[] ---@return string -function M.message_to_text(message) +function M.message_content_item_to_text(item, message, messages) + local lines = M.message_content_item_to_lines(item, message, messages) + if #lines == 0 then return "" end + return table.concat(vim.tbl_map(function(line) return tostring(line) end, lines), "\n") +end + +---@param message avante.HistoryMessage +---@param messages avante.HistoryMessage[] +---@return string +function M.message_to_text(message, messages) local content = message.message.content if type(content) == "string" then return content end if vim.islist(content) then local pieces = {} for _, item in ipairs(content) do - local text = M.message_content_item_to_text(item, message) + local text = M.message_content_item_to_text(item, message, messages) if text ~= "" then table.insert(pieces, text) end end return table.concat(pieces, "\n")