refactor: tool calling ui (#1959)

This commit is contained in:
yetone
2025-05-01 19:16:15 +08:00
committed by GitHub
parent 448efbb842
commit 3a43621e17
4 changed files with 289 additions and 33 deletions

View File

@@ -966,6 +966,36 @@ function M.get_or_create_buffer_with_filepath(filepath)
return buf
end
---@param old_lines avante.ui.Line[]
---@param new_lines avante.ui.Line[]
---@return { start_line: integer, end_line: integer, content: avante.ui.Line[] }[]
local function get_lines_diff(old_lines, new_lines)
local diffs = {}
local prev_diff_idx = nil
for i, line in ipairs(new_lines) do
if line ~= old_lines[i] then
if prev_diff_idx == nil then prev_diff_idx = i end
else
if prev_diff_idx ~= nil then
local content = vim.list_slice(new_lines, prev_diff_idx, i - 1)
table.insert(diffs, { start_line = prev_diff_idx, end_line = i, content = content })
prev_diff_idx = nil
end
end
end
if prev_diff_idx ~= nil then
table.insert(
diffs,
{ start_line = prev_diff_idx, end_line = #new_lines + 1, content = vim.list_slice(new_lines, prev_diff_idx) }
)
end
if #new_lines < #old_lines then
table.insert(diffs, { start_line = #new_lines + 1, end_line = #old_lines + 1, content = {} })
end
table.sort(diffs, function(a, b) return a.start_line > b.start_line end)
return diffs
end
---@param bufnr integer
---@param new_lines string[]
---@return { start_line: integer, end_line: integer, content: string[] }[]
@@ -1008,6 +1038,24 @@ function M.update_buffer_content(bufnr, new_lines)
end
end
---@param ns_id number
---@param bufnr integer
---@param old_lines avante.ui.Line[]
---@param new_lines avante.ui.Line[]
function M.update_buffer_lines(ns_id, bufnr, old_lines, new_lines)
local diffs = get_lines_diff(old_lines, new_lines)
if #diffs == 0 then return end
for _, diff in ipairs(diffs) do
local lines = diff.content
-- M.debug("lines", lines)
local text_lines = vim.tbl_map(function(line) return tostring(line) end, lines)
vim.api.nvim_buf_set_lines(bufnr, diff.start_line - 1, diff.end_line - 1, false, text_lines)
for i, line in ipairs(lines) do
line:set_highlights(ns_id, bufnr, diff.start_line + i - 2)
end
end
end
local severity = {
[1] = "ERROR",
[2] = "WARNING",
@@ -1365,35 +1413,140 @@ function M.uuid()
end)
end
---@param item AvanteLLMMessageContentItem
---@param message avante.HistoryMessage
---@return string
function M.message_content_item_to_text(item, message)
if type(item) == "string" then return item end
if type(item) == "table" then
if item.type == "text" then return item.text end
if item.type == "image" then return "![image](" .. item.source.media_type .. ": " .. item.source.data .. ")" end
if item.type == "tool_use" then
local pieces = {}
table.insert(pieces, string.format("[%s]: calling", item.name))
for _, log in ipairs(message.tool_use_logs or {}) do
table.insert(pieces, log)
---@param messages avante.HistoryMessage[]
---@return avante.HistoryMessage | nil
function M.get_tool_result_message(message, messages)
local content = message.message.content
if type(content) == "string" then return nil end
if vim.islist(content) then
local tool_id = nil
for _, item in ipairs(content) do
if item.type == "tool_use" then
tool_id = item.id
break
end
end
if not tool_id then return nil end
for _, message_ in ipairs(messages) do
local content_ = message_.message.content
if type(content_) == "table" and content_[1].type == "tool_result" and content_[1].tool_use_id == tool_id then
return message_
end
return table.concat(pieces, "\n")
end
end
return ""
return nil
end
---@param text string
---@param hl string | nil
---@return avante.ui.Line[]
function M.text_to_lines(text, hl)
local Line = require("avante.ui.line")
local text_lines = vim.split(text, "\n")
local lines = {}
for _, text_line in ipairs(text_lines) do
local piece = { text_line }
if hl then table.insert(piece, hl) end
table.insert(lines, Line:new({ piece }))
end
return lines
end
---@param item AvanteLLMMessageContentItem
---@param message avante.HistoryMessage
---@param messages avante.HistoryMessage[]
---@return avante.ui.Line[]
function M.message_content_item_to_lines(item, message, messages)
local Line = require("avante.ui.line")
if type(item) == "string" then return M.text_to_lines(item) end
if type(item) == "table" then
if item.type == "text" then return M.text_to_lines(item.text) end
if item.type == "image" then
return { Line:new({ { "![image](" .. item.source.media_type .. ": " .. item.source.data .. ")" } }) }
end
if item.type == "tool_use" then
local lines = {}
local state = "generating"
local hl = "AvanteStateSpinnerToolCalling"
if message.state == "generated" then
local tool_result_message = M.get_tool_result_message(message, messages)
if tool_result_message then
local tool_result = tool_result_message.message.content[1]
if tool_result.is_error then
state = "failed"
hl = "AvanteStateSpinnerFailed"
else
state = "succeeded"
hl = "AvanteStateSpinnerSucceeded"
end
end
end
table.insert(
lines,
Line:new({ { "╭─" }, { " " }, { string.format(" %s ", item.name), hl }, { string.format(" %s", state) } })
)
for idx, log in ipairs(message.tool_use_logs or {}) do
local log_ = M.trim(log, { prefix = string.format("[%s]: ", item.name) })
local lines_ = vim.split(log_, "\n")
if idx ~= #(message.tool_use_logs or {}) then
for _, line_ in ipairs(lines_) do
table.insert(lines, Line:new({ { "" }, { string.format(" %s", line_) } }))
end
else
for idx_, line_ in ipairs(lines_) do
if idx_ ~= #lines_ then
table.insert(lines, Line:new({ { "" }, { string.format(" %s", line_) } }))
else
table.insert(lines, Line:new({ { "╰─" }, { string.format(" %s", line_) } }))
end
end
end
end
return lines
end
end
return {}
end
---@param message avante.HistoryMessage
---@param messages avante.HistoryMessage[]
---@return avante.ui.Line[]
function M.message_to_lines(message, messages)
local Line = require("avante.ui.line")
local content = message.message.content
if type(content) == "string" then return { Line:new({ { content } }) } end
if vim.islist(content) then
local lines = {}
for _, item in ipairs(content) do
local lines_ = M.message_content_item_to_lines(item, message, messages)
lines = vim.list_extend(lines, lines_)
end
return lines
end
return {}
end
---@param item AvanteLLMMessageContentItem
---@param message avante.HistoryMessage
---@param messages avante.HistoryMessage[]
---@return string
function M.message_to_text(message)
function M.message_content_item_to_text(item, message, messages)
local lines = M.message_content_item_to_lines(item, message, messages)
if #lines == 0 then return "" end
return table.concat(vim.tbl_map(function(line) return tostring(line) end, lines), "\n")
end
---@param message avante.HistoryMessage
---@param messages avante.HistoryMessage[]
---@return string
function M.message_to_text(message, messages)
local content = message.message.content
if type(content) == "string" then return content end
if vim.islist(content) then
local pieces = {}
for _, item in ipairs(content) do
local text = M.message_content_item_to_text(item, message)
local text = M.message_content_item_to_text(item, message, messages)
if text ~= "" then table.insert(pieces, text) end
end
return table.concat(pieces, "\n")