refactor: tool calling ui (#1959)
This commit is contained in:
@@ -81,7 +81,7 @@ function M.summarize_memory(prev_memory, history_messages, cb)
|
||||
if type(content) == "table" and content[1].type == "tool_use" then return false end
|
||||
return true
|
||||
end)
|
||||
:map(function(msg) return msg.message.role .. ": " .. Utils.message_to_text(msg) end)
|
||||
:map(function(msg) return msg.message.role .. ": " .. Utils.message_to_text(msg, history_messages) end)
|
||||
:totable()
|
||||
local conversation_text = table.concat(conversation_items, "\n")
|
||||
local user_prompt = "Here is the conversation so far:\n"
|
||||
|
||||
@@ -16,6 +16,7 @@ local RepoMap = require("avante.repo_map")
|
||||
local FileSelector = require("avante.file_selector")
|
||||
local LLMTools = require("avante.llm_tools")
|
||||
local HistoryMessage = require("avante.history_message")
|
||||
local Line = require("avante.ui.line")
|
||||
|
||||
local RESULT_BUF_NAME = "AVANTE_RESULT"
|
||||
local VIEW_BUFFER_UPDATED_PATTERN = "AvanteViewBufferUpdated"
|
||||
@@ -25,6 +26,7 @@ local SELECTED_FILES_HINT_NAMESPACE = api.nvim_create_namespace("AVANTE_SELECTED
|
||||
local SELECTED_FILES_ICON_NAMESPACE = api.nvim_create_namespace("AVANTE_SELECTED_FILES_ICON")
|
||||
local INPUT_HINT_NAMESPACE = api.nvim_create_namespace("AVANTE_INPUT_HINT")
|
||||
local STATE_NAMESPACE = api.nvim_create_namespace("AVANTE_STATE")
|
||||
local RESULT_BUF_HL_NAMESPACE = api.nvim_create_namespace("AVANTE_RESULT_BUF_HL")
|
||||
|
||||
local PRIORITY = (vim.hl or vim.highlight).priorities.user
|
||||
|
||||
@@ -59,6 +61,7 @@ Sidebar.__index = Sidebar
|
||||
---@field scroll boolean
|
||||
---@field input_hint_window integer | nil
|
||||
---@field ask_opts AskOptions
|
||||
---@field old_result_lines avante.ui.Line[]
|
||||
|
||||
---@param id integer the tabpage id retrieved from api.nvim_get_current_tabpage()
|
||||
function Sidebar:new(id)
|
||||
@@ -86,6 +89,7 @@ function Sidebar:new(id)
|
||||
scroll = true,
|
||||
input_hint_window = nil,
|
||||
ask_opts = {},
|
||||
old_result_lines = {},
|
||||
}, Sidebar)
|
||||
end
|
||||
|
||||
@@ -121,6 +125,7 @@ function Sidebar:reset()
|
||||
self.selected_files_container = nil
|
||||
self.input_container = nil
|
||||
self.scroll = true
|
||||
self.old_result_lines = {}
|
||||
end
|
||||
|
||||
---@class SidebarOpenOptions: AskOptions
|
||||
@@ -1510,18 +1515,24 @@ end
|
||||
function Sidebar:update_content(content, opts)
|
||||
if not self.result_container or not self.result_container.bufnr then return end
|
||||
opts = vim.tbl_deep_extend("force", { focus = false, scroll = self.scroll, callback = nil }, opts or {})
|
||||
local history_content = self.render_history_content(self.chat_history)
|
||||
local contents = { history_content, content }
|
||||
contents = vim.iter(contents):filter(function(item) return item ~= nil and item ~= "" end):totable()
|
||||
content = table.concat(contents, "\n\n")
|
||||
local history_lines = self.get_history_lines(self.chat_history)
|
||||
if content ~= nil and content ~= "" then
|
||||
table.insert(history_lines, Line:new({ { "" } }))
|
||||
table.insert(history_lines, Line:new({ { content } }))
|
||||
end
|
||||
vim.defer_fn(function()
|
||||
self:clear_state()
|
||||
local f = function()
|
||||
if not Utils.is_valid_container(self.result_container) then return end
|
||||
local lines = vim.split(content, "\n")
|
||||
Utils.unlock_buf(self.result_container.bufnr)
|
||||
Utils.update_buffer_content(self.result_container.bufnr, lines)
|
||||
Utils.update_buffer_lines(
|
||||
RESULT_BUF_HL_NAMESPACE,
|
||||
self.result_container.bufnr,
|
||||
self.old_result_lines,
|
||||
history_lines
|
||||
)
|
||||
Utils.lock_buf(self.result_container.bufnr)
|
||||
self.old_result_lines = history_lines
|
||||
api.nvim_set_option_value("filetype", "Avante", { buf = self.result_container.bufnr })
|
||||
vim.schedule(function() vim.cmd("redraw") end)
|
||||
if opts.focus and not self:is_focused_on_result() then
|
||||
@@ -1593,11 +1604,96 @@ function Sidebar:get_layout()
|
||||
end
|
||||
|
||||
---@param message avante.HistoryMessage
|
||||
---@param messages avante.HistoryMessage[]
|
||||
---@param ctx table
|
||||
---@return avante.ui.Line[]
|
||||
local function get_message_lines(message, messages, ctx)
|
||||
if message.visible == false then return {} end
|
||||
local lines = Utils.message_to_lines(message, messages)
|
||||
if message.is_user_submission then
|
||||
ctx.selected_filepaths = message.selected_filepaths
|
||||
local text = table.concat(vim.tbl_map(function(line) return tostring(line) end, lines), "\n")
|
||||
local prefix = render_chat_record_prefix(
|
||||
message.timestamp,
|
||||
message.provider,
|
||||
message.model,
|
||||
text,
|
||||
message.selected_filepaths,
|
||||
message.selected_code
|
||||
)
|
||||
local res = {}
|
||||
for _, line_ in ipairs(vim.split(prefix, "\n")) do
|
||||
table.insert(res, Line:new({ { line_ } }))
|
||||
end
|
||||
return res
|
||||
end
|
||||
if message.message.role == "user" then
|
||||
local res = {}
|
||||
for _, line_ in ipairs(lines) do
|
||||
local sections = { { "> " } }
|
||||
sections = vim.list_extend(sections, line_.sections)
|
||||
table.insert(res, Line:new(sections))
|
||||
end
|
||||
return res
|
||||
end
|
||||
if message.message.role == "assistant" then
|
||||
local content = message.message.content
|
||||
if type(content) == "table" and content[1].type == "tool_use" then return lines end
|
||||
local text = table.concat(vim.tbl_map(function(line) return tostring(line) end, lines), "\n")
|
||||
local transformed = transform_result_content(text, ctx.prev_filepath)
|
||||
ctx.prev_filepath = transformed.current_filepath
|
||||
local displayed_content = generate_display_content(transformed)
|
||||
local res = {}
|
||||
for _, line_ in ipairs(vim.split(displayed_content, "\n")) do
|
||||
table.insert(res, Line:new({ { line_ } }))
|
||||
end
|
||||
return res
|
||||
end
|
||||
return lines
|
||||
end
|
||||
|
||||
---@param history avante.ChatHistory
|
||||
---@return avante.ui.Line[]
|
||||
function Sidebar.get_history_lines(history)
|
||||
local history_messages = Utils.get_history_messages(history)
|
||||
local ctx = {}
|
||||
---@type avante.ui.Line[][]
|
||||
local group = {}
|
||||
for _, message in ipairs(history_messages) do
|
||||
local lines = get_message_lines(message, history_messages, ctx)
|
||||
if #lines == 0 then goto continue end
|
||||
if message.is_user_submission then table.insert(group, {}) end
|
||||
local last_item = group[#group]
|
||||
if last_item == nil then
|
||||
table.insert(group, {})
|
||||
last_item = group[#group]
|
||||
end
|
||||
if message.message.role == "assistant" and not message.just_for_display and tostring(lines[1]) ~= "" then
|
||||
table.insert(lines, 1, Line:new({ { "" } }))
|
||||
table.insert(lines, 1, Line:new({ { "" } }))
|
||||
end
|
||||
last_item = vim.list_extend(last_item, lines)
|
||||
group[#group] = last_item
|
||||
::continue::
|
||||
end
|
||||
local res = {}
|
||||
for idx, item in ipairs(group) do
|
||||
if idx ~= 1 and idx ~= #group then
|
||||
res = vim.list_extend(res, { Line:new({ { "" } }), Line:new({ { RESP_SEPARATOR } }), Line:new({ { "" } }) })
|
||||
end
|
||||
res = vim.list_extend(res, item)
|
||||
end
|
||||
table.insert(res, Line:new({ { "" } }))
|
||||
return res
|
||||
end
|
||||
|
||||
---@param message avante.HistoryMessage
|
||||
---@param messages avante.HistoryMessage[]
|
||||
---@param ctx table
|
||||
---@return string | nil
|
||||
local function render_message(message, ctx)
|
||||
local function render_message(message, messages, ctx)
|
||||
if message.visible == false then return nil end
|
||||
local text = Utils.message_to_text(message)
|
||||
local text = Utils.message_to_text(message, messages)
|
||||
if text == "" then return nil end
|
||||
if message.is_user_submission then
|
||||
ctx.selected_filepaths = message.selected_filepaths
|
||||
@@ -1633,7 +1729,7 @@ function Sidebar.render_history_content(history)
|
||||
local ctx = {}
|
||||
local group = {}
|
||||
for _, message in ipairs(history_messages) do
|
||||
local text = render_message(message, ctx)
|
||||
local text = render_message(message, history_messages, ctx)
|
||||
if text == nil then goto continue end
|
||||
if message.is_user_submission then table.insert(group, {}) end
|
||||
local last_item = group[#group]
|
||||
@@ -1695,6 +1791,7 @@ function Sidebar:get_content_between_separators()
|
||||
end
|
||||
|
||||
function Sidebar:clear_history(args, cb)
|
||||
self.current_state = nil
|
||||
local chat_history = Path.history.load(self.code.bufnr)
|
||||
if next(chat_history) ~= nil then
|
||||
chat_history.messages = {}
|
||||
@@ -1763,6 +1860,7 @@ function Sidebar:new_chat(args, cb)
|
||||
local history = Path.history.new(self.code.bufnr)
|
||||
Path.history.save(self.code.bufnr, history)
|
||||
self:reload_chat_history()
|
||||
self.current_state = nil
|
||||
self:update_content("New chat", { focus = false, scroll = false, callback = function() self:focus_input() end })
|
||||
if cb then cb(args) end
|
||||
end
|
||||
@@ -2144,10 +2242,10 @@ function Sidebar:create_input_container()
|
||||
end
|
||||
end
|
||||
|
||||
local model = Config.has_provider(Config.provider) and Config.get_provider_config(Config.provider).model
|
||||
or "default"
|
||||
|
||||
local timestamp = Utils.get_timestamp()
|
||||
-- local model = Config.has_provider(Config.provider) and Config.get_provider_config(Config.provider).model
|
||||
-- or "default"
|
||||
--
|
||||
-- local timestamp = Utils.get_timestamp()
|
||||
|
||||
local selected_filepaths = self.file_selector:get_selected_filepaths()
|
||||
|
||||
@@ -2161,14 +2259,14 @@ function Sidebar:create_input_container()
|
||||
}
|
||||
end
|
||||
|
||||
local content_prefix =
|
||||
render_chat_record_prefix(timestamp, Config.provider, model, request, selected_filepaths, selected_code)
|
||||
-- local content_prefix =
|
||||
-- render_chat_record_prefix(timestamp, Config.provider, model, request, selected_filepaths, selected_code)
|
||||
|
||||
--- HACK: we need to set focus to true and scroll to false to
|
||||
--- prevent the cursor from jumping to the bottom of the
|
||||
--- buffer at the beginning
|
||||
self:update_content("", { focus = true, scroll = false })
|
||||
self:update_content(content_prefix)
|
||||
-- self:update_content(content_prefix)
|
||||
|
||||
---stop scroll when user presses j/k keys
|
||||
local function on_j()
|
||||
|
||||
@@ -54,4 +54,9 @@ function M:__tostring()
|
||||
return table.concat(content, "")
|
||||
end
|
||||
|
||||
function M:__eq(other)
|
||||
if not other or type(other) ~= "table" or not other.sections then return false end
|
||||
return vim.deep_equal(self.sections, other.sections)
|
||||
end
|
||||
|
||||
return M
|
||||
|
||||
@@ -966,6 +966,36 @@ function M.get_or_create_buffer_with_filepath(filepath)
|
||||
return buf
|
||||
end
|
||||
|
||||
---@param old_lines avante.ui.Line[]
|
||||
---@param new_lines avante.ui.Line[]
|
||||
---@return { start_line: integer, end_line: integer, content: avante.ui.Line[] }[]
|
||||
local function get_lines_diff(old_lines, new_lines)
|
||||
local diffs = {}
|
||||
local prev_diff_idx = nil
|
||||
for i, line in ipairs(new_lines) do
|
||||
if line ~= old_lines[i] then
|
||||
if prev_diff_idx == nil then prev_diff_idx = i end
|
||||
else
|
||||
if prev_diff_idx ~= nil then
|
||||
local content = vim.list_slice(new_lines, prev_diff_idx, i - 1)
|
||||
table.insert(diffs, { start_line = prev_diff_idx, end_line = i, content = content })
|
||||
prev_diff_idx = nil
|
||||
end
|
||||
end
|
||||
end
|
||||
if prev_diff_idx ~= nil then
|
||||
table.insert(
|
||||
diffs,
|
||||
{ start_line = prev_diff_idx, end_line = #new_lines + 1, content = vim.list_slice(new_lines, prev_diff_idx) }
|
||||
)
|
||||
end
|
||||
if #new_lines < #old_lines then
|
||||
table.insert(diffs, { start_line = #new_lines + 1, end_line = #old_lines + 1, content = {} })
|
||||
end
|
||||
table.sort(diffs, function(a, b) return a.start_line > b.start_line end)
|
||||
return diffs
|
||||
end
|
||||
|
||||
---@param bufnr integer
|
||||
---@param new_lines string[]
|
||||
---@return { start_line: integer, end_line: integer, content: string[] }[]
|
||||
@@ -1008,6 +1038,24 @@ function M.update_buffer_content(bufnr, new_lines)
|
||||
end
|
||||
end
|
||||
|
||||
---@param ns_id number
|
||||
---@param bufnr integer
|
||||
---@param old_lines avante.ui.Line[]
|
||||
---@param new_lines avante.ui.Line[]
|
||||
function M.update_buffer_lines(ns_id, bufnr, old_lines, new_lines)
|
||||
local diffs = get_lines_diff(old_lines, new_lines)
|
||||
if #diffs == 0 then return end
|
||||
for _, diff in ipairs(diffs) do
|
||||
local lines = diff.content
|
||||
-- M.debug("lines", lines)
|
||||
local text_lines = vim.tbl_map(function(line) return tostring(line) end, lines)
|
||||
vim.api.nvim_buf_set_lines(bufnr, diff.start_line - 1, diff.end_line - 1, false, text_lines)
|
||||
for i, line in ipairs(lines) do
|
||||
line:set_highlights(ns_id, bufnr, diff.start_line + i - 2)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
local severity = {
|
||||
[1] = "ERROR",
|
||||
[2] = "WARNING",
|
||||
@@ -1365,35 +1413,140 @@ function M.uuid()
|
||||
end)
|
||||
end
|
||||
|
||||
---@param item AvanteLLMMessageContentItem
|
||||
---@param message avante.HistoryMessage
|
||||
---@return string
|
||||
function M.message_content_item_to_text(item, message)
|
||||
if type(item) == "string" then return item end
|
||||
if type(item) == "table" then
|
||||
if item.type == "text" then return item.text end
|
||||
if item.type == "image" then return "" end
|
||||
if item.type == "tool_use" then
|
||||
local pieces = {}
|
||||
table.insert(pieces, string.format("[%s]: calling", item.name))
|
||||
for _, log in ipairs(message.tool_use_logs or {}) do
|
||||
table.insert(pieces, log)
|
||||
---@param messages avante.HistoryMessage[]
|
||||
---@return avante.HistoryMessage | nil
|
||||
function M.get_tool_result_message(message, messages)
|
||||
local content = message.message.content
|
||||
if type(content) == "string" then return nil end
|
||||
if vim.islist(content) then
|
||||
local tool_id = nil
|
||||
for _, item in ipairs(content) do
|
||||
if item.type == "tool_use" then
|
||||
tool_id = item.id
|
||||
break
|
||||
end
|
||||
end
|
||||
if not tool_id then return nil end
|
||||
for _, message_ in ipairs(messages) do
|
||||
local content_ = message_.message.content
|
||||
if type(content_) == "table" and content_[1].type == "tool_result" and content_[1].tool_use_id == tool_id then
|
||||
return message_
|
||||
end
|
||||
return table.concat(pieces, "\n")
|
||||
end
|
||||
end
|
||||
return ""
|
||||
return nil
|
||||
end
|
||||
|
||||
---@param text string
|
||||
---@param hl string | nil
|
||||
---@return avante.ui.Line[]
|
||||
function M.text_to_lines(text, hl)
|
||||
local Line = require("avante.ui.line")
|
||||
local text_lines = vim.split(text, "\n")
|
||||
local lines = {}
|
||||
for _, text_line in ipairs(text_lines) do
|
||||
local piece = { text_line }
|
||||
if hl then table.insert(piece, hl) end
|
||||
table.insert(lines, Line:new({ piece }))
|
||||
end
|
||||
return lines
|
||||
end
|
||||
|
||||
---@param item AvanteLLMMessageContentItem
|
||||
---@param message avante.HistoryMessage
|
||||
---@param messages avante.HistoryMessage[]
|
||||
---@return avante.ui.Line[]
|
||||
function M.message_content_item_to_lines(item, message, messages)
|
||||
local Line = require("avante.ui.line")
|
||||
if type(item) == "string" then return M.text_to_lines(item) end
|
||||
if type(item) == "table" then
|
||||
if item.type == "text" then return M.text_to_lines(item.text) end
|
||||
if item.type == "image" then
|
||||
return { Line:new({ { "" } }) }
|
||||
end
|
||||
if item.type == "tool_use" then
|
||||
local lines = {}
|
||||
local state = "generating"
|
||||
local hl = "AvanteStateSpinnerToolCalling"
|
||||
if message.state == "generated" then
|
||||
local tool_result_message = M.get_tool_result_message(message, messages)
|
||||
if tool_result_message then
|
||||
local tool_result = tool_result_message.message.content[1]
|
||||
if tool_result.is_error then
|
||||
state = "failed"
|
||||
hl = "AvanteStateSpinnerFailed"
|
||||
else
|
||||
state = "succeeded"
|
||||
hl = "AvanteStateSpinnerSucceeded"
|
||||
end
|
||||
end
|
||||
end
|
||||
table.insert(
|
||||
lines,
|
||||
Line:new({ { "╭─" }, { " " }, { string.format(" %s ", item.name), hl }, { string.format(" %s", state) } })
|
||||
)
|
||||
for idx, log in ipairs(message.tool_use_logs or {}) do
|
||||
local log_ = M.trim(log, { prefix = string.format("[%s]: ", item.name) })
|
||||
local lines_ = vim.split(log_, "\n")
|
||||
if idx ~= #(message.tool_use_logs or {}) then
|
||||
for _, line_ in ipairs(lines_) do
|
||||
table.insert(lines, Line:new({ { "│" }, { string.format(" %s", line_) } }))
|
||||
end
|
||||
else
|
||||
for idx_, line_ in ipairs(lines_) do
|
||||
if idx_ ~= #lines_ then
|
||||
table.insert(lines, Line:new({ { "│" }, { string.format(" %s", line_) } }))
|
||||
else
|
||||
table.insert(lines, Line:new({ { "╰─" }, { string.format(" %s", line_) } }))
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
return lines
|
||||
end
|
||||
end
|
||||
return {}
|
||||
end
|
||||
|
||||
---@param message avante.HistoryMessage
|
||||
---@param messages avante.HistoryMessage[]
|
||||
---@return avante.ui.Line[]
|
||||
function M.message_to_lines(message, messages)
|
||||
local Line = require("avante.ui.line")
|
||||
local content = message.message.content
|
||||
if type(content) == "string" then return { Line:new({ { content } }) } end
|
||||
if vim.islist(content) then
|
||||
local lines = {}
|
||||
for _, item in ipairs(content) do
|
||||
local lines_ = M.message_content_item_to_lines(item, message, messages)
|
||||
lines = vim.list_extend(lines, lines_)
|
||||
end
|
||||
return lines
|
||||
end
|
||||
return {}
|
||||
end
|
||||
|
||||
---@param item AvanteLLMMessageContentItem
|
||||
---@param message avante.HistoryMessage
|
||||
---@param messages avante.HistoryMessage[]
|
||||
---@return string
|
||||
function M.message_to_text(message)
|
||||
function M.message_content_item_to_text(item, message, messages)
|
||||
local lines = M.message_content_item_to_lines(item, message, messages)
|
||||
if #lines == 0 then return "" end
|
||||
return table.concat(vim.tbl_map(function(line) return tostring(line) end, lines), "\n")
|
||||
end
|
||||
|
||||
---@param message avante.HistoryMessage
|
||||
---@param messages avante.HistoryMessage[]
|
||||
---@return string
|
||||
function M.message_to_text(message, messages)
|
||||
local content = message.message.content
|
||||
if type(content) == "string" then return content end
|
||||
if vim.islist(content) then
|
||||
local pieces = {}
|
||||
for _, item in ipairs(content) do
|
||||
local text = M.message_content_item_to_text(item, message)
|
||||
local text = M.message_content_item_to_text(item, message, messages)
|
||||
if text ~= "" then table.insert(pieces, text) end
|
||||
end
|
||||
return table.concat(pieces, "\n")
|
||||
|
||||
Reference in New Issue
Block a user