From 487972386e2fe1ba1959947b9f656fb5a2979fb4 Mon Sep 17 00:00:00 2001 From: yetone Date: Thu, 16 Oct 2025 19:05:57 +0800 Subject: [PATCH] feat: support acp slash commands (#2775) --- lua/avante/init.lua | 4 +- lua/avante/libs/acp_client.lua | 13 +++++- lua/avante/llm.lua | 27 +++++++++++++ lua/avante/sidebar.lua | 66 ++++++++++++++++++------------- lua/avante/types.lua | 1 + lua/avante/utils/init.lua | 17 +++++++- lua/avante/utils/promptLogger.lua | 1 + lua/cmp_avante/commands.lua | 11 +++++- 8 files changed, 107 insertions(+), 33 deletions(-) diff --git a/lua/avante/init.lua b/lua/avante/init.lua index 986a0ab..ed721c9 100644 --- a/lua/avante/init.lua +++ b/lua/avante/init.lua @@ -437,6 +437,8 @@ setmetatable(M.toggle, { __call = function() M.toggle_sidebar() end, }) +M.slash_commands_id = nil + ---@param opts? avante.Config function M.setup(opts) ---PERF: we can still allow running require("avante").setup() multiple times to override config if users wish to @@ -496,7 +498,7 @@ function M.setup(opts) local has_cmp, cmp = pcall(require, "cmp") if has_cmp then - cmp.register_source("avante_commands", require("cmp_avante.commands"):new()) + M.slash_commands_id = cmp.register_source("avante_commands", require("cmp_avante.commands"):new()) cmp.register_source("avante_mentions", require("cmp_avante.mentions"):new(Utils.get_chat_mentions)) diff --git a/lua/avante/libs/acp_client.lua b/lua/avante/libs/acp_client.lua index 246d2fe..8ba097c 100644 --- a/lua/avante/libs/acp_client.lua +++ b/lua/avante/libs/acp_client.lua @@ -124,8 +124,13 @@ local Utils = require("avante.utils") ---@class avante.acp.Plan ---@field entries avante.acp.PlanEntry[] +---@class avante.acp.AvailableCommand +---@field name string +---@field description string +---@field input? table + ---@class avante.acp.BaseSessionUpdate ----@field sessionUpdate "user_message_chunk" | "agent_message_chunk" | "agent_thought_chunk" | "tool_call" | "tool_call_update" | "plan" +---@field sessionUpdate "user_message_chunk" | "agent_message_chunk" | "agent_thought_chunk" | "tool_call" | "tool_call_update" | "plan" | "available_commands_update" ---@class avante.acp.UserMessageChunk : avante.acp.BaseSessionUpdate ---@field sessionUpdate "user_message_chunk" @@ -154,6 +159,10 @@ local Utils = require("avante.utils") ---@field sessionUpdate "plan" ---@field entries avante.acp.PlanEntry[] +---@class avante.acp.AvailableCommandsUpdate : avante.acp.BaseSessionUpdate +---@field sessionUpdate "available_commands_update" +---@field availableCommands avante.acp.AvailableCommand[] + ---@class avante.acp.PermissionOption ---@field optionId string ---@field name string @@ -196,7 +205,7 @@ ACPClient.ERROR_CODES = { } ---@class ACPHandlers ----@field on_session_update? fun(update: avante.acp.UserMessageChunk | avante.acp.AgentMessageChunk | avante.acp.AgentThoughtChunk | avante.acp.ToolCallUpdate | avante.acp.PlanUpdate) +---@field on_session_update? fun(update: avante.acp.UserMessageChunk | avante.acp.AgentMessageChunk | avante.acp.AgentThoughtChunk | avante.acp.ToolCallUpdate | avante.acp.PlanUpdate | avante.acp.AvailableCommandsUpdate) ---@field on_request_permission? fun(tool_call: table, options: table[], callback: fun(option_id: string | nil)): nil ---@field on_read_file? fun(path: string, line: integer | nil, limit: integer | nil, callback: fun(content: string)): nil ---@field on_write_file? fun(path: string, content: string, callback: fun(error: string|nil)): nil diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index 482d119..5f5bda3 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -1054,6 +1054,32 @@ function M._stream_acp(opts) if tool_result_message then table.insert(messages, tool_result_message) end on_messages_add(messages) end + if update.sessionUpdate == "available_commands_update" then + local commands = update.availableCommands + local has_cmp, cmp = pcall(require, "cmp") + if has_cmp then + local slash_commands_id = require("avante").slash_commands_id + if slash_commands_id ~= nil then cmp.unregister_source(slash_commands_id) end + for _, command in ipairs(commands) do + local exists = false + for _, command_ in ipairs(Config.slash_commands) do + if command_.name == command.name then + exists = true + break + end + end + if not exists then + table.insert(Config.slash_commands, { + name = command.name, + description = command.description, + details = command.description, + }) + end + end + local avante = require("avante") + avante.slash_commands_id = cmp.register_source("avante_commands", require("cmp_avante.commands"):new()) + end + end end, on_request_permission = function(tool_call, options, callback) local sidebar = require("avante").get() @@ -1171,6 +1197,7 @@ function M._stream_acp(opts) session_id = session_id_ if opts.on_save_acp_session_id then opts.on_save_acp_session_id(session_id) end end + if opts.just_connect_acp_client then return end local prompt = {} local donot_use_builtin_system_prompt = opts.history_messages ~= nil and #opts.history_messages > 0 if donot_use_builtin_system_prompt then diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua index f3bdf1c..8a00162 100644 --- a/lua/avante/sidebar.lua +++ b/lua/avante/sidebar.lua @@ -195,6 +195,9 @@ function Sidebar:open(opts) vim.g.avante_login = true end + local acp_provider = Config.acp_providers[Config.provider] + if acp_provider then self:handle_submit("") end + return self end @@ -2626,6 +2629,20 @@ function Sidebar:get_generate_prompts_options(request, cb) if cb then cb(prompts_opts) end end +function Sidebar:submit_input() + if not vim.g.avante_login then + Utils.warn("Sending message to fast!, API key is not yet set", { title = "Avante" }) + return + end + if not Utils.is_valid_container(self.containers.input) then return end + local lines = api.nvim_buf_get_lines(self.containers.input.bufnr, 0, -1, false) + local request = table.concat(lines, "\n") + if request == "" then return end + api.nvim_buf_set_lines(self.containers.input.bufnr, 0, -1, false, {}) + api.nvim_win_set_cursor(self.containers.input.winid, { 1, 0 }) + self:handle_submit(request) +end + ---@param request string function Sidebar:handle_submit(request) if Config.prompt_logger.enabled then PromptLogger.log_prompt(request) end @@ -2650,16 +2667,18 @@ function Sidebar:handle_submit(request) ---@type AvanteSlashCommand local cmd = vim.iter(cmds):filter(function(cmd) return cmd.name == command end):totable()[1] if cmd then - if command == "lines" then - cmd.callback(self, args, function(args_) - local _, _, question = args_:match("(%d+)-(%d+)%s+(.*)") - request = question - end) - elseif command == "commit" then - cmd.callback(self, args, function(question) request = question end) - else - cmd.callback(self, args) - return + if cmd.callback then + if command == "lines" then + cmd.callback(self, args, function(args_) + local _, _, question = args_:match("(%d+)-(%d+)%s+(.*)") + request = question + end) + elseif command == "commit" then + cmd.callback(self, args, function(question) request = question end) + else + cmd.callback(self, args) + return + end end else self:update_content("Unknown command: " .. command, { focus = false, scroll = false }) @@ -2681,10 +2700,12 @@ function Sidebar:handle_submit(request) content = self.code.selection.content, } - --- HACK: we need to set focus to true and scroll to false to - --- prevent the cursor from jumping to the bottom of the - --- buffer at the beginning - self:update_content("", { focus = true, scroll = false }) + if request ~= "" then + --- HACK: we need to set focus to true and scroll to false to + --- prevent the cursor from jumping to the bottom of the + --- buffer at the beginning + self:update_content("", { focus = true, scroll = false }) + end ---stop scroll when user presses j/k keys local function on_j() @@ -2809,6 +2830,7 @@ function Sidebar:handle_submit(request) ---@type AvanteLLMStreamOptions ---@diagnostic disable-next-line: assign-type-mismatch local stream_options = vim.tbl_deep_extend("force", generate_prompts_options, { + just_connect_acp_client = request == "", on_start = on_start, on_stop = on_stop, on_tool_log = on_tool_log, @@ -2860,7 +2882,7 @@ function Sidebar:handle_submit(request) stream_options.on_memory_summarize = on_memory_summarize - on_state_change("generating") + if request ~= "" then on_state_change("generating") end Llm.stream(stream_options) end) end @@ -2909,19 +2931,7 @@ function Sidebar:create_input_container() size = get_size(), }) - local function on_submit() - if not vim.g.avante_login then - Utils.warn("Sending message to fast!, API key is not yet set", { title = "Avante" }) - return - end - if not Utils.is_valid_container(self.containers.input) then return end - local lines = api.nvim_buf_get_lines(self.containers.input.bufnr, 0, -1, false) - local request = table.concat(lines, "\n") - if request == "" then return end - api.nvim_buf_set_lines(self.containers.input.bufnr, 0, -1, false, {}) - api.nvim_win_set_cursor(self.containers.input.winid, { 1, 0 }) - self:handle_submit(request) - end + local function on_submit() self:submit_input() end self.containers.input:mount() PromptLogger.init() diff --git a/lua/avante/types.lua b/lua/avante/types.lua index 14d2784..57e85a9 100644 --- a/lua/avante/types.lua +++ b/lua/avante/types.lua @@ -417,6 +417,7 @@ vim.g.avante_login = vim.g.avante_login ---@class AvanteLLMStreamOptions: AvanteGeneratePromptsOptions ---@field acp_client? avante.acp.ACPClient ---@field on_save_acp_client? fun(client: avante.acp.ACPClient): nil +---@field just_connect_acp_client? boolean ---@field acp_session_id? string ---@field on_save_acp_session_id? fun(session_id: string): nil ---@field on_start AvanteLLMStartCallback diff --git a/lua/avante/utils/init.lua b/lua/avante/utils/init.lua index 029ad2a..44d09dd 100644 --- a/lua/avante/utils/init.lua +++ b/lua/avante/utils/init.lua @@ -1626,7 +1626,22 @@ function M.get_commands() ) :totable() - return vim.list_extend(builtin_commands, Config.slash_commands) + local commands = {} + local seen = {} + for _, command in ipairs(Config.slash_commands) do + if not seen[command.name] then + table.insert(commands, command) + seen[command.name] = true + end + end + for _, command in ipairs(builtin_commands) do + if not seen[command.name] then + table.insert(commands, command) + seen[command.name] = true + end + end + + return commands end function M.get_timestamp() return tostring(os.date("%Y-%m-%d %H:%M:%S")) end diff --git a/lua/avante/utils/promptLogger.lua b/lua/avante/utils/promptLogger.lua index d05c87b..948c02f 100644 --- a/lua/avante/utils/promptLogger.lua +++ b/lua/avante/utils/promptLogger.lua @@ -39,6 +39,7 @@ function M.init() end function M.log_prompt(request) + if request == "" then return end local log_dir = Config.prompt_logger.log_dir local log_file = Utils.join_paths(log_dir, "avante_prompts.log") diff --git a/lua/cmp_avante/commands.lua b/lua/cmp_avante/commands.lua index d691957..a00adbf 100644 --- a/lua/cmp_avante/commands.lua +++ b/lua/cmp_avante/commands.lua @@ -56,9 +56,18 @@ function CommandsSource:execute(item, callback) local commands = Utils.get_commands() local command = vim.iter(commands):find(function(command) return command.name == item.data.name end) - if not command then return end + if not command then + callback() + return + end local sidebar = require("avante").get() + if not command.callback then + if sidebar then sidebar:submit_input() end + callback() + return + end + command.callback(sidebar, nil, function() local bufnr = sidebar.containers.input.bufnr ---@type integer local content = table.concat(api.nvim_buf_get_lines(bufnr, 0, -1, false), "\n")