refactor: tool calling ui (#1959)

This commit is contained in:
yetone
2025-05-01 19:16:15 +08:00
committed by GitHub
parent 448efbb842
commit 3a43621e17
4 changed files with 289 additions and 33 deletions

View File

@@ -81,7 +81,7 @@ function M.summarize_memory(prev_memory, history_messages, cb)
if type(content) == "table" and content[1].type == "tool_use" then return false end
return true
end)
:map(function(msg) return msg.message.role .. ": " .. Utils.message_to_text(msg) end)
:map(function(msg) return msg.message.role .. ": " .. Utils.message_to_text(msg, history_messages) end)
:totable()
local conversation_text = table.concat(conversation_items, "\n")
local user_prompt = "Here is the conversation so far:\n"

View File

@@ -16,6 +16,7 @@ local RepoMap = require("avante.repo_map")
local FileSelector = require("avante.file_selector")
local LLMTools = require("avante.llm_tools")
local HistoryMessage = require("avante.history_message")
local Line = require("avante.ui.line")
local RESULT_BUF_NAME = "AVANTE_RESULT"
local VIEW_BUFFER_UPDATED_PATTERN = "AvanteViewBufferUpdated"
@@ -25,6 +26,7 @@ local SELECTED_FILES_HINT_NAMESPACE = api.nvim_create_namespace("AVANTE_SELECTED
local SELECTED_FILES_ICON_NAMESPACE = api.nvim_create_namespace("AVANTE_SELECTED_FILES_ICON")
local INPUT_HINT_NAMESPACE = api.nvim_create_namespace("AVANTE_INPUT_HINT")
local STATE_NAMESPACE = api.nvim_create_namespace("AVANTE_STATE")
local RESULT_BUF_HL_NAMESPACE = api.nvim_create_namespace("AVANTE_RESULT_BUF_HL")
local PRIORITY = (vim.hl or vim.highlight).priorities.user
@@ -59,6 +61,7 @@ Sidebar.__index = Sidebar
---@field scroll boolean
---@field input_hint_window integer | nil
---@field ask_opts AskOptions
---@field old_result_lines avante.ui.Line[]
---@param id integer the tabpage id retrieved from api.nvim_get_current_tabpage()
function Sidebar:new(id)
@@ -86,6 +89,7 @@ function Sidebar:new(id)
scroll = true,
input_hint_window = nil,
ask_opts = {},
old_result_lines = {},
}, Sidebar)
end
@@ -121,6 +125,7 @@ function Sidebar:reset()
self.selected_files_container = nil
self.input_container = nil
self.scroll = true
self.old_result_lines = {}
end
---@class SidebarOpenOptions: AskOptions
@@ -1510,18 +1515,24 @@ end
function Sidebar:update_content(content, opts)
if not self.result_container or not self.result_container.bufnr then return end
opts = vim.tbl_deep_extend("force", { focus = false, scroll = self.scroll, callback = nil }, opts or {})
local history_content = self.render_history_content(self.chat_history)
local contents = { history_content, content }
contents = vim.iter(contents):filter(function(item) return item ~= nil and item ~= "" end):totable()
content = table.concat(contents, "\n\n")
local history_lines = self.get_history_lines(self.chat_history)
if content ~= nil and content ~= "" then
table.insert(history_lines, Line:new({ { "" } }))
table.insert(history_lines, Line:new({ { content } }))
end
vim.defer_fn(function()
self:clear_state()
local f = function()
if not Utils.is_valid_container(self.result_container) then return end
local lines = vim.split(content, "\n")
Utils.unlock_buf(self.result_container.bufnr)
Utils.update_buffer_content(self.result_container.bufnr, lines)
Utils.update_buffer_lines(
RESULT_BUF_HL_NAMESPACE,
self.result_container.bufnr,
self.old_result_lines,
history_lines
)
Utils.lock_buf(self.result_container.bufnr)
self.old_result_lines = history_lines
api.nvim_set_option_value("filetype", "Avante", { buf = self.result_container.bufnr })
vim.schedule(function() vim.cmd("redraw") end)
if opts.focus and not self:is_focused_on_result() then
@@ -1593,11 +1604,96 @@ function Sidebar:get_layout()
end
---@param message avante.HistoryMessage
---@param messages avante.HistoryMessage[]
---@param ctx table
---@return avante.ui.Line[]
local function get_message_lines(message, messages, ctx)
if message.visible == false then return {} end
local lines = Utils.message_to_lines(message, messages)
if message.is_user_submission then
ctx.selected_filepaths = message.selected_filepaths
local text = table.concat(vim.tbl_map(function(line) return tostring(line) end, lines), "\n")
local prefix = render_chat_record_prefix(
message.timestamp,
message.provider,
message.model,
text,
message.selected_filepaths,
message.selected_code
)
local res = {}
for _, line_ in ipairs(vim.split(prefix, "\n")) do
table.insert(res, Line:new({ { line_ } }))
end
return res
end
if message.message.role == "user" then
local res = {}
for _, line_ in ipairs(lines) do
local sections = { { "> " } }
sections = vim.list_extend(sections, line_.sections)
table.insert(res, Line:new(sections))
end
return res
end
if message.message.role == "assistant" then
local content = message.message.content
if type(content) == "table" and content[1].type == "tool_use" then return lines end
local text = table.concat(vim.tbl_map(function(line) return tostring(line) end, lines), "\n")
local transformed = transform_result_content(text, ctx.prev_filepath)
ctx.prev_filepath = transformed.current_filepath
local displayed_content = generate_display_content(transformed)
local res = {}
for _, line_ in ipairs(vim.split(displayed_content, "\n")) do
table.insert(res, Line:new({ { line_ } }))
end
return res
end
return lines
end
---@param history avante.ChatHistory
---@return avante.ui.Line[]
function Sidebar.get_history_lines(history)
local history_messages = Utils.get_history_messages(history)
local ctx = {}
---@type avante.ui.Line[][]
local group = {}
for _, message in ipairs(history_messages) do
local lines = get_message_lines(message, history_messages, ctx)
if #lines == 0 then goto continue end
if message.is_user_submission then table.insert(group, {}) end
local last_item = group[#group]
if last_item == nil then
table.insert(group, {})
last_item = group[#group]
end
if message.message.role == "assistant" and not message.just_for_display and tostring(lines[1]) ~= "" then
table.insert(lines, 1, Line:new({ { "" } }))
table.insert(lines, 1, Line:new({ { "" } }))
end
last_item = vim.list_extend(last_item, lines)
group[#group] = last_item
::continue::
end
local res = {}
for idx, item in ipairs(group) do
if idx ~= 1 and idx ~= #group then
res = vim.list_extend(res, { Line:new({ { "" } }), Line:new({ { RESP_SEPARATOR } }), Line:new({ { "" } }) })
end
res = vim.list_extend(res, item)
end
table.insert(res, Line:new({ { "" } }))
return res
end
---@param message avante.HistoryMessage
---@param messages avante.HistoryMessage[]
---@param ctx table
---@return string | nil
local function render_message(message, ctx)
local function render_message(message, messages, ctx)
if message.visible == false then return nil end
local text = Utils.message_to_text(message)
local text = Utils.message_to_text(message, messages)
if text == "" then return nil end
if message.is_user_submission then
ctx.selected_filepaths = message.selected_filepaths
@@ -1633,7 +1729,7 @@ function Sidebar.render_history_content(history)
local ctx = {}
local group = {}
for _, message in ipairs(history_messages) do
local text = render_message(message, ctx)
local text = render_message(message, history_messages, ctx)
if text == nil then goto continue end
if message.is_user_submission then table.insert(group, {}) end
local last_item = group[#group]
@@ -1695,6 +1791,7 @@ function Sidebar:get_content_between_separators()
end
function Sidebar:clear_history(args, cb)
self.current_state = nil
local chat_history = Path.history.load(self.code.bufnr)
if next(chat_history) ~= nil then
chat_history.messages = {}
@@ -1763,6 +1860,7 @@ function Sidebar:new_chat(args, cb)
local history = Path.history.new(self.code.bufnr)
Path.history.save(self.code.bufnr, history)
self:reload_chat_history()
self.current_state = nil
self:update_content("New chat", { focus = false, scroll = false, callback = function() self:focus_input() end })
if cb then cb(args) end
end
@@ -2144,10 +2242,10 @@ function Sidebar:create_input_container()
end
end
local model = Config.has_provider(Config.provider) and Config.get_provider_config(Config.provider).model
or "default"
local timestamp = Utils.get_timestamp()
-- local model = Config.has_provider(Config.provider) and Config.get_provider_config(Config.provider).model
-- or "default"
--
-- local timestamp = Utils.get_timestamp()
local selected_filepaths = self.file_selector:get_selected_filepaths()
@@ -2161,14 +2259,14 @@ function Sidebar:create_input_container()
}
end
local content_prefix =
render_chat_record_prefix(timestamp, Config.provider, model, request, selected_filepaths, selected_code)
-- local content_prefix =
-- render_chat_record_prefix(timestamp, Config.provider, model, request, selected_filepaths, selected_code)
--- HACK: we need to set focus to true and scroll to false to
--- prevent the cursor from jumping to the bottom of the
--- buffer at the beginning
self:update_content("", { focus = true, scroll = false })
self:update_content(content_prefix)
-- self:update_content(content_prefix)
---stop scroll when user presses j/k keys
local function on_j()

View File

@@ -54,4 +54,9 @@ function M:__tostring()
return table.concat(content, "")
end
function M:__eq(other)
if not other or type(other) ~= "table" or not other.sections then return false end
return vim.deep_equal(self.sections, other.sections)
end
return M

View File

@@ -966,6 +966,36 @@ function M.get_or_create_buffer_with_filepath(filepath)
return buf
end
---@param old_lines avante.ui.Line[]
---@param new_lines avante.ui.Line[]
---@return { start_line: integer, end_line: integer, content: avante.ui.Line[] }[]
local function get_lines_diff(old_lines, new_lines)
local diffs = {}
local prev_diff_idx = nil
for i, line in ipairs(new_lines) do
if line ~= old_lines[i] then
if prev_diff_idx == nil then prev_diff_idx = i end
else
if prev_diff_idx ~= nil then
local content = vim.list_slice(new_lines, prev_diff_idx, i - 1)
table.insert(diffs, { start_line = prev_diff_idx, end_line = i, content = content })
prev_diff_idx = nil
end
end
end
if prev_diff_idx ~= nil then
table.insert(
diffs,
{ start_line = prev_diff_idx, end_line = #new_lines + 1, content = vim.list_slice(new_lines, prev_diff_idx) }
)
end
if #new_lines < #old_lines then
table.insert(diffs, { start_line = #new_lines + 1, end_line = #old_lines + 1, content = {} })
end
table.sort(diffs, function(a, b) return a.start_line > b.start_line end)
return diffs
end
---@param bufnr integer
---@param new_lines string[]
---@return { start_line: integer, end_line: integer, content: string[] }[]
@@ -1008,6 +1038,24 @@ function M.update_buffer_content(bufnr, new_lines)
end
end
---@param ns_id number
---@param bufnr integer
---@param old_lines avante.ui.Line[]
---@param new_lines avante.ui.Line[]
function M.update_buffer_lines(ns_id, bufnr, old_lines, new_lines)
local diffs = get_lines_diff(old_lines, new_lines)
if #diffs == 0 then return end
for _, diff in ipairs(diffs) do
local lines = diff.content
-- M.debug("lines", lines)
local text_lines = vim.tbl_map(function(line) return tostring(line) end, lines)
vim.api.nvim_buf_set_lines(bufnr, diff.start_line - 1, diff.end_line - 1, false, text_lines)
for i, line in ipairs(lines) do
line:set_highlights(ns_id, bufnr, diff.start_line + i - 2)
end
end
end
local severity = {
[1] = "ERROR",
[2] = "WARNING",
@@ -1365,35 +1413,140 @@ function M.uuid()
end)
end
---@param item AvanteLLMMessageContentItem
---@param message avante.HistoryMessage
---@return string
function M.message_content_item_to_text(item, message)
if type(item) == "string" then return item end
if type(item) == "table" then
if item.type == "text" then return item.text end
if item.type == "image" then return "![image](" .. item.source.media_type .. ": " .. item.source.data .. ")" end
if item.type == "tool_use" then
local pieces = {}
table.insert(pieces, string.format("[%s]: calling", item.name))
for _, log in ipairs(message.tool_use_logs or {}) do
table.insert(pieces, log)
---@param messages avante.HistoryMessage[]
---@return avante.HistoryMessage | nil
function M.get_tool_result_message(message, messages)
local content = message.message.content
if type(content) == "string" then return nil end
if vim.islist(content) then
local tool_id = nil
for _, item in ipairs(content) do
if item.type == "tool_use" then
tool_id = item.id
break
end
end
if not tool_id then return nil end
for _, message_ in ipairs(messages) do
local content_ = message_.message.content
if type(content_) == "table" and content_[1].type == "tool_result" and content_[1].tool_use_id == tool_id then
return message_
end
return table.concat(pieces, "\n")
end
end
return ""
return nil
end
---@param text string
---@param hl string | nil
---@return avante.ui.Line[]
function M.text_to_lines(text, hl)
local Line = require("avante.ui.line")
local text_lines = vim.split(text, "\n")
local lines = {}
for _, text_line in ipairs(text_lines) do
local piece = { text_line }
if hl then table.insert(piece, hl) end
table.insert(lines, Line:new({ piece }))
end
return lines
end
---@param item AvanteLLMMessageContentItem
---@param message avante.HistoryMessage
---@param messages avante.HistoryMessage[]
---@return avante.ui.Line[]
function M.message_content_item_to_lines(item, message, messages)
local Line = require("avante.ui.line")
if type(item) == "string" then return M.text_to_lines(item) end
if type(item) == "table" then
if item.type == "text" then return M.text_to_lines(item.text) end
if item.type == "image" then
return { Line:new({ { "![image](" .. item.source.media_type .. ": " .. item.source.data .. ")" } }) }
end
if item.type == "tool_use" then
local lines = {}
local state = "generating"
local hl = "AvanteStateSpinnerToolCalling"
if message.state == "generated" then
local tool_result_message = M.get_tool_result_message(message, messages)
if tool_result_message then
local tool_result = tool_result_message.message.content[1]
if tool_result.is_error then
state = "failed"
hl = "AvanteStateSpinnerFailed"
else
state = "succeeded"
hl = "AvanteStateSpinnerSucceeded"
end
end
end
table.insert(
lines,
Line:new({ { "╭─" }, { " " }, { string.format(" %s ", item.name), hl }, { string.format(" %s", state) } })
)
for idx, log in ipairs(message.tool_use_logs or {}) do
local log_ = M.trim(log, { prefix = string.format("[%s]: ", item.name) })
local lines_ = vim.split(log_, "\n")
if idx ~= #(message.tool_use_logs or {}) then
for _, line_ in ipairs(lines_) do
table.insert(lines, Line:new({ { "" }, { string.format(" %s", line_) } }))
end
else
for idx_, line_ in ipairs(lines_) do
if idx_ ~= #lines_ then
table.insert(lines, Line:new({ { "" }, { string.format(" %s", line_) } }))
else
table.insert(lines, Line:new({ { "╰─" }, { string.format(" %s", line_) } }))
end
end
end
end
return lines
end
end
return {}
end
---@param message avante.HistoryMessage
---@param messages avante.HistoryMessage[]
---@return avante.ui.Line[]
function M.message_to_lines(message, messages)
local Line = require("avante.ui.line")
local content = message.message.content
if type(content) == "string" then return { Line:new({ { content } }) } end
if vim.islist(content) then
local lines = {}
for _, item in ipairs(content) do
local lines_ = M.message_content_item_to_lines(item, message, messages)
lines = vim.list_extend(lines, lines_)
end
return lines
end
return {}
end
---@param item AvanteLLMMessageContentItem
---@param message avante.HistoryMessage
---@param messages avante.HistoryMessage[]
---@return string
function M.message_to_text(message)
function M.message_content_item_to_text(item, message, messages)
local lines = M.message_content_item_to_lines(item, message, messages)
if #lines == 0 then return "" end
return table.concat(vim.tbl_map(function(line) return tostring(line) end, lines), "\n")
end
---@param message avante.HistoryMessage
---@param messages avante.HistoryMessage[]
---@return string
function M.message_to_text(message, messages)
local content = message.message.content
if type(content) == "string" then return content end
if vim.islist(content) then
local pieces = {}
for _, item in ipairs(content) do
local text = M.message_content_item_to_text(item, message)
local text = M.message_content_item_to_text(item, message, messages)
if text ~= "" then table.insert(pieces, text) end
end
return table.concat(pieces, "\n")