feat: request permission button group (#2685)

This commit is contained in:
yetone
2025-09-04 15:35:40 +08:00
committed by GitHub
parent 4ac4c8ed3f
commit 10e0312ec4
6 changed files with 394 additions and 125 deletions

View File

@@ -14,7 +14,7 @@ local Providers = require("avante.providers")
local LLMToolHelpers = require("avante.llm_tools.helpers")
local LLMTools = require("avante.llm_tools")
local History = require("avante.history")
local Selector = require("avante.ui.selector")
local Highlights = require("avante.highlights")
---@class avante.LLM
local M = {}
@@ -801,9 +801,10 @@ end
---@param opts AvanteLLMStreamOptions
function M._stream_acp(opts)
Utils.debug("use ACP", Config.provider)
local Render = require("avante.history.render")
---@type table<string, avante.HistoryMessage>
local tool_call_messages = {}
---@type avante.HistoryMessage
local last_tool_call_message = nil
local acp_provider = Config.acp_providers[Config.provider]
local on_messages_add = function(messages)
if opts.on_messages_add then opts.on_messages_add(messages) end
@@ -816,6 +817,7 @@ function M._stream_acp(opts)
name = update.kind,
input = update.rawInput or {},
})
last_tool_call_message = message
message.acp_tool_call = update
if update.status == "pending" or update.status == "in_progress" then message.is_calling = true end
tool_call_messages[update.toolCallId] = message
@@ -908,6 +910,7 @@ function M._stream_acp(opts)
tool_call_message.acp_tool_call = update
end
if tool_call_message.acp_tool_call then
if update.content and next(update.content) == nil then update.content = nil end
tool_call_message.acp_tool_call = vim.tbl_deep_extend("force", tool_call_message.acp_tool_call, update)
end
tool_call_message.tool_use_logs = tool_call_message.tool_use_logs or {}
@@ -933,83 +936,76 @@ function M._stream_acp(opts)
end
end,
on_request_permission = function(tool_call, options, callback)
local message = tool_call_messages[tool_call.toolCallId]
if not message then
add_tool_call_message(tool_call)
else
if message.acp_tool_call then
message.acp_tool_call = vim.tbl_deep_extend("force", message.acp_tool_call, tool_call)
tool_call = message.acp_tool_call
end
local sidebar = require("avante").get()
if not sidebar then
Utils.error("Avante sidebar not found")
return
end
---@cast tool_call avante.acp.ToolCall
local items = vim
.iter(options)
:map(
function(item)
return {
id = item.optionId,
title = item.name,
}
:map(function(item)
local icon = item.kind == "allow_once" and "" or ""
if item.kind == "allow_always" then icon = "" end
local hl = nil
if item.kind == "reject_once" or item.kind == "reject_always" then
hl = Highlights.BUTTON_DANGER_HOVER
end
)
return {
id = item.optionId,
name = item.name,
icon = icon,
hl = hl,
}
end)
:totable()
local default_item = vim.iter(items):find(function(item) return item.id == options[1].optionId end)
local function on_select(item_ids)
if not item_ids then return end
local choice = vim.iter(items):find(function(item) return item.id == item_ids[1] end)
if not choice then return end
Utils.debug("on_select", choice.id)
callback(choice.id)
sidebar.permission_button_options = items
sidebar.permission_handler = function(id)
callback(id)
sidebar.scroll = true
sidebar.permission_button_options = nil
sidebar.permission_handler = nil
sidebar._history_cache_invalidated = true
sidebar:update_content("")
end
local tool_name, error = Render.get_tool_display_name(message)
if error then
Utils.error(error)
tool_name = message.message.content[1].name
local message = tool_call_messages[tool_call.toolCallId]
if not message then
message = add_tool_call_message(tool_call)
else
if message.acp_tool_call then
if tool_call.content and next(tool_call.content) == nil then tool_call.content = nil end
message.acp_tool_call = vim.tbl_deep_extend("force", message.acp_tool_call, tool_call)
end
end
local selector = Selector:new({
title = tool_name,
items = items,
default_item_id = default_item and default_item.name or nil,
provider = Config.selector.provider,
provider_opts = Config.selector.provider_opts,
on_select = on_select,
get_preview_content = function(_)
local file_content = ""
local filetype = "text"
local content = tool_call.content
if type(content) == "table" then
for _, item in ipairs(content) do
if item.type == "content" then
if type(item.content) == "table" then
if item.content.type == "text" then
file_content = file_content .. item.content.text .. "\n\n"
end
end
end
if item.type == "diff" then
local unified_diff = Utils.get_unified_diff(item.oldText, item.newText, { algorithm = "myers" })
local result = "--- a/" .. item.path .. "\n+++ b/" .. item.path .. "\n" .. unified_diff .. "\n\n"
filetype = "diff"
file_content = file_content .. result
end
end
end
return file_content, filetype
end,
})
selector:open()
on_messages_add({ message })
end,
on_read_file = function(path, line, limit, callback)
local abs_path = Utils.to_absolute_path(path)
local lines = Utils.read_file_from_buf_or_disk(abs_path)
lines = lines or {}
if line ~= nil and limit ~= nil then lines = vim.list_slice(lines, line, line + limit) end
callback(table.concat(lines, "\n"))
local content = table.concat(lines, "\n")
if
last_tool_call_message
and last_tool_call_message.acp_tool_call
and last_tool_call_message.acp_tool_call.kind == "read"
then
if
last_tool_call_message.acp_tool_call.content
and next(last_tool_call_message.acp_tool_call.content) == nil
then
last_tool_call_message.acp_tool_call.content = {
{
type = "content",
content = {
type = "text",
text = content,
},
},
}
end
end
callback(content)
end,
on_write_file = function(path, content, callback)
local abs_path = Utils.to_absolute_path(path)