feat: support acp slash commands (#2775)
This commit is contained in:
@@ -437,6 +437,8 @@ setmetatable(M.toggle, {
|
|||||||
__call = function() M.toggle_sidebar() end,
|
__call = function() M.toggle_sidebar() end,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
M.slash_commands_id = nil
|
||||||
|
|
||||||
---@param opts? avante.Config
|
---@param opts? avante.Config
|
||||||
function M.setup(opts)
|
function M.setup(opts)
|
||||||
---PERF: we can still allow running require("avante").setup() multiple times to override config if users wish to
|
---PERF: we can still allow running require("avante").setup() multiple times to override config if users wish to
|
||||||
@@ -496,7 +498,7 @@ function M.setup(opts)
|
|||||||
|
|
||||||
local has_cmp, cmp = pcall(require, "cmp")
|
local has_cmp, cmp = pcall(require, "cmp")
|
||||||
if has_cmp then
|
if has_cmp then
|
||||||
cmp.register_source("avante_commands", require("cmp_avante.commands"):new())
|
M.slash_commands_id = cmp.register_source("avante_commands", require("cmp_avante.commands"):new())
|
||||||
|
|
||||||
cmp.register_source("avante_mentions", require("cmp_avante.mentions"):new(Utils.get_chat_mentions))
|
cmp.register_source("avante_mentions", require("cmp_avante.mentions"):new(Utils.get_chat_mentions))
|
||||||
|
|
||||||
|
|||||||
@@ -124,8 +124,13 @@ local Utils = require("avante.utils")
|
|||||||
---@class avante.acp.Plan
|
---@class avante.acp.Plan
|
||||||
---@field entries avante.acp.PlanEntry[]
|
---@field entries avante.acp.PlanEntry[]
|
||||||
|
|
||||||
|
---@class avante.acp.AvailableCommand
|
||||||
|
---@field name string
|
||||||
|
---@field description string
|
||||||
|
---@field input? table<string, any>
|
||||||
|
|
||||||
---@class avante.acp.BaseSessionUpdate
|
---@class avante.acp.BaseSessionUpdate
|
||||||
---@field sessionUpdate "user_message_chunk" | "agent_message_chunk" | "agent_thought_chunk" | "tool_call" | "tool_call_update" | "plan"
|
---@field sessionUpdate "user_message_chunk" | "agent_message_chunk" | "agent_thought_chunk" | "tool_call" | "tool_call_update" | "plan" | "available_commands_update"
|
||||||
|
|
||||||
---@class avante.acp.UserMessageChunk : avante.acp.BaseSessionUpdate
|
---@class avante.acp.UserMessageChunk : avante.acp.BaseSessionUpdate
|
||||||
---@field sessionUpdate "user_message_chunk"
|
---@field sessionUpdate "user_message_chunk"
|
||||||
@@ -154,6 +159,10 @@ local Utils = require("avante.utils")
|
|||||||
---@field sessionUpdate "plan"
|
---@field sessionUpdate "plan"
|
||||||
---@field entries avante.acp.PlanEntry[]
|
---@field entries avante.acp.PlanEntry[]
|
||||||
|
|
||||||
|
---@class avante.acp.AvailableCommandsUpdate : avante.acp.BaseSessionUpdate
|
||||||
|
---@field sessionUpdate "available_commands_update"
|
||||||
|
---@field availableCommands avante.acp.AvailableCommand[]
|
||||||
|
|
||||||
---@class avante.acp.PermissionOption
|
---@class avante.acp.PermissionOption
|
||||||
---@field optionId string
|
---@field optionId string
|
||||||
---@field name string
|
---@field name string
|
||||||
@@ -196,7 +205,7 @@ ACPClient.ERROR_CODES = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
---@class ACPHandlers
|
---@class ACPHandlers
|
||||||
---@field on_session_update? fun(update: avante.acp.UserMessageChunk | avante.acp.AgentMessageChunk | avante.acp.AgentThoughtChunk | avante.acp.ToolCallUpdate | avante.acp.PlanUpdate)
|
---@field on_session_update? fun(update: avante.acp.UserMessageChunk | avante.acp.AgentMessageChunk | avante.acp.AgentThoughtChunk | avante.acp.ToolCallUpdate | avante.acp.PlanUpdate | avante.acp.AvailableCommandsUpdate)
|
||||||
---@field on_request_permission? fun(tool_call: table, options: table[], callback: fun(option_id: string | nil)): nil
|
---@field on_request_permission? fun(tool_call: table, options: table[], callback: fun(option_id: string | nil)): nil
|
||||||
---@field on_read_file? fun(path: string, line: integer | nil, limit: integer | nil, callback: fun(content: string)): nil
|
---@field on_read_file? fun(path: string, line: integer | nil, limit: integer | nil, callback: fun(content: string)): nil
|
||||||
---@field on_write_file? fun(path: string, content: string, callback: fun(error: string|nil)): nil
|
---@field on_write_file? fun(path: string, content: string, callback: fun(error: string|nil)): nil
|
||||||
|
|||||||
@@ -1054,6 +1054,32 @@ function M._stream_acp(opts)
|
|||||||
if tool_result_message then table.insert(messages, tool_result_message) end
|
if tool_result_message then table.insert(messages, tool_result_message) end
|
||||||
on_messages_add(messages)
|
on_messages_add(messages)
|
||||||
end
|
end
|
||||||
|
if update.sessionUpdate == "available_commands_update" then
|
||||||
|
local commands = update.availableCommands
|
||||||
|
local has_cmp, cmp = pcall(require, "cmp")
|
||||||
|
if has_cmp then
|
||||||
|
local slash_commands_id = require("avante").slash_commands_id
|
||||||
|
if slash_commands_id ~= nil then cmp.unregister_source(slash_commands_id) end
|
||||||
|
for _, command in ipairs(commands) do
|
||||||
|
local exists = false
|
||||||
|
for _, command_ in ipairs(Config.slash_commands) do
|
||||||
|
if command_.name == command.name then
|
||||||
|
exists = true
|
||||||
|
break
|
||||||
|
end
|
||||||
|
end
|
||||||
|
if not exists then
|
||||||
|
table.insert(Config.slash_commands, {
|
||||||
|
name = command.name,
|
||||||
|
description = command.description,
|
||||||
|
details = command.description,
|
||||||
|
})
|
||||||
|
end
|
||||||
|
end
|
||||||
|
local avante = require("avante")
|
||||||
|
avante.slash_commands_id = cmp.register_source("avante_commands", require("cmp_avante.commands"):new())
|
||||||
|
end
|
||||||
|
end
|
||||||
end,
|
end,
|
||||||
on_request_permission = function(tool_call, options, callback)
|
on_request_permission = function(tool_call, options, callback)
|
||||||
local sidebar = require("avante").get()
|
local sidebar = require("avante").get()
|
||||||
@@ -1171,6 +1197,7 @@ function M._stream_acp(opts)
|
|||||||
session_id = session_id_
|
session_id = session_id_
|
||||||
if opts.on_save_acp_session_id then opts.on_save_acp_session_id(session_id) end
|
if opts.on_save_acp_session_id then opts.on_save_acp_session_id(session_id) end
|
||||||
end
|
end
|
||||||
|
if opts.just_connect_acp_client then return end
|
||||||
local prompt = {}
|
local prompt = {}
|
||||||
local donot_use_builtin_system_prompt = opts.history_messages ~= nil and #opts.history_messages > 0
|
local donot_use_builtin_system_prompt = opts.history_messages ~= nil and #opts.history_messages > 0
|
||||||
if donot_use_builtin_system_prompt then
|
if donot_use_builtin_system_prompt then
|
||||||
|
|||||||
@@ -195,6 +195,9 @@ function Sidebar:open(opts)
|
|||||||
vim.g.avante_login = true
|
vim.g.avante_login = true
|
||||||
end
|
end
|
||||||
|
|
||||||
|
local acp_provider = Config.acp_providers[Config.provider]
|
||||||
|
if acp_provider then self:handle_submit("") end
|
||||||
|
|
||||||
return self
|
return self
|
||||||
end
|
end
|
||||||
|
|
||||||
@@ -2626,6 +2629,20 @@ function Sidebar:get_generate_prompts_options(request, cb)
|
|||||||
if cb then cb(prompts_opts) end
|
if cb then cb(prompts_opts) end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function Sidebar:submit_input()
|
||||||
|
if not vim.g.avante_login then
|
||||||
|
Utils.warn("Sending message to fast!, API key is not yet set", { title = "Avante" })
|
||||||
|
return
|
||||||
|
end
|
||||||
|
if not Utils.is_valid_container(self.containers.input) then return end
|
||||||
|
local lines = api.nvim_buf_get_lines(self.containers.input.bufnr, 0, -1, false)
|
||||||
|
local request = table.concat(lines, "\n")
|
||||||
|
if request == "" then return end
|
||||||
|
api.nvim_buf_set_lines(self.containers.input.bufnr, 0, -1, false, {})
|
||||||
|
api.nvim_win_set_cursor(self.containers.input.winid, { 1, 0 })
|
||||||
|
self:handle_submit(request)
|
||||||
|
end
|
||||||
|
|
||||||
---@param request string
|
---@param request string
|
||||||
function Sidebar:handle_submit(request)
|
function Sidebar:handle_submit(request)
|
||||||
if Config.prompt_logger.enabled then PromptLogger.log_prompt(request) end
|
if Config.prompt_logger.enabled then PromptLogger.log_prompt(request) end
|
||||||
@@ -2650,16 +2667,18 @@ function Sidebar:handle_submit(request)
|
|||||||
---@type AvanteSlashCommand
|
---@type AvanteSlashCommand
|
||||||
local cmd = vim.iter(cmds):filter(function(cmd) return cmd.name == command end):totable()[1]
|
local cmd = vim.iter(cmds):filter(function(cmd) return cmd.name == command end):totable()[1]
|
||||||
if cmd then
|
if cmd then
|
||||||
if command == "lines" then
|
if cmd.callback then
|
||||||
cmd.callback(self, args, function(args_)
|
if command == "lines" then
|
||||||
local _, _, question = args_:match("(%d+)-(%d+)%s+(.*)")
|
cmd.callback(self, args, function(args_)
|
||||||
request = question
|
local _, _, question = args_:match("(%d+)-(%d+)%s+(.*)")
|
||||||
end)
|
request = question
|
||||||
elseif command == "commit" then
|
end)
|
||||||
cmd.callback(self, args, function(question) request = question end)
|
elseif command == "commit" then
|
||||||
else
|
cmd.callback(self, args, function(question) request = question end)
|
||||||
cmd.callback(self, args)
|
else
|
||||||
return
|
cmd.callback(self, args)
|
||||||
|
return
|
||||||
|
end
|
||||||
end
|
end
|
||||||
else
|
else
|
||||||
self:update_content("Unknown command: " .. command, { focus = false, scroll = false })
|
self:update_content("Unknown command: " .. command, { focus = false, scroll = false })
|
||||||
@@ -2681,10 +2700,12 @@ function Sidebar:handle_submit(request)
|
|||||||
content = self.code.selection.content,
|
content = self.code.selection.content,
|
||||||
}
|
}
|
||||||
|
|
||||||
--- HACK: we need to set focus to true and scroll to false to
|
if request ~= "" then
|
||||||
--- prevent the cursor from jumping to the bottom of the
|
--- HACK: we need to set focus to true and scroll to false to
|
||||||
--- buffer at the beginning
|
--- prevent the cursor from jumping to the bottom of the
|
||||||
self:update_content("", { focus = true, scroll = false })
|
--- buffer at the beginning
|
||||||
|
self:update_content("", { focus = true, scroll = false })
|
||||||
|
end
|
||||||
|
|
||||||
---stop scroll when user presses j/k keys
|
---stop scroll when user presses j/k keys
|
||||||
local function on_j()
|
local function on_j()
|
||||||
@@ -2809,6 +2830,7 @@ function Sidebar:handle_submit(request)
|
|||||||
---@type AvanteLLMStreamOptions
|
---@type AvanteLLMStreamOptions
|
||||||
---@diagnostic disable-next-line: assign-type-mismatch
|
---@diagnostic disable-next-line: assign-type-mismatch
|
||||||
local stream_options = vim.tbl_deep_extend("force", generate_prompts_options, {
|
local stream_options = vim.tbl_deep_extend("force", generate_prompts_options, {
|
||||||
|
just_connect_acp_client = request == "",
|
||||||
on_start = on_start,
|
on_start = on_start,
|
||||||
on_stop = on_stop,
|
on_stop = on_stop,
|
||||||
on_tool_log = on_tool_log,
|
on_tool_log = on_tool_log,
|
||||||
@@ -2860,7 +2882,7 @@ function Sidebar:handle_submit(request)
|
|||||||
|
|
||||||
stream_options.on_memory_summarize = on_memory_summarize
|
stream_options.on_memory_summarize = on_memory_summarize
|
||||||
|
|
||||||
on_state_change("generating")
|
if request ~= "" then on_state_change("generating") end
|
||||||
Llm.stream(stream_options)
|
Llm.stream(stream_options)
|
||||||
end)
|
end)
|
||||||
end
|
end
|
||||||
@@ -2909,19 +2931,7 @@ function Sidebar:create_input_container()
|
|||||||
size = get_size(),
|
size = get_size(),
|
||||||
})
|
})
|
||||||
|
|
||||||
local function on_submit()
|
local function on_submit() self:submit_input() end
|
||||||
if not vim.g.avante_login then
|
|
||||||
Utils.warn("Sending message to fast!, API key is not yet set", { title = "Avante" })
|
|
||||||
return
|
|
||||||
end
|
|
||||||
if not Utils.is_valid_container(self.containers.input) then return end
|
|
||||||
local lines = api.nvim_buf_get_lines(self.containers.input.bufnr, 0, -1, false)
|
|
||||||
local request = table.concat(lines, "\n")
|
|
||||||
if request == "" then return end
|
|
||||||
api.nvim_buf_set_lines(self.containers.input.bufnr, 0, -1, false, {})
|
|
||||||
api.nvim_win_set_cursor(self.containers.input.winid, { 1, 0 })
|
|
||||||
self:handle_submit(request)
|
|
||||||
end
|
|
||||||
|
|
||||||
self.containers.input:mount()
|
self.containers.input:mount()
|
||||||
PromptLogger.init()
|
PromptLogger.init()
|
||||||
|
|||||||
@@ -417,6 +417,7 @@ vim.g.avante_login = vim.g.avante_login
|
|||||||
---@class AvanteLLMStreamOptions: AvanteGeneratePromptsOptions
|
---@class AvanteLLMStreamOptions: AvanteGeneratePromptsOptions
|
||||||
---@field acp_client? avante.acp.ACPClient
|
---@field acp_client? avante.acp.ACPClient
|
||||||
---@field on_save_acp_client? fun(client: avante.acp.ACPClient): nil
|
---@field on_save_acp_client? fun(client: avante.acp.ACPClient): nil
|
||||||
|
---@field just_connect_acp_client? boolean
|
||||||
---@field acp_session_id? string
|
---@field acp_session_id? string
|
||||||
---@field on_save_acp_session_id? fun(session_id: string): nil
|
---@field on_save_acp_session_id? fun(session_id: string): nil
|
||||||
---@field on_start AvanteLLMStartCallback
|
---@field on_start AvanteLLMStartCallback
|
||||||
|
|||||||
@@ -1626,7 +1626,22 @@ function M.get_commands()
|
|||||||
)
|
)
|
||||||
:totable()
|
:totable()
|
||||||
|
|
||||||
return vim.list_extend(builtin_commands, Config.slash_commands)
|
local commands = {}
|
||||||
|
local seen = {}
|
||||||
|
for _, command in ipairs(Config.slash_commands) do
|
||||||
|
if not seen[command.name] then
|
||||||
|
table.insert(commands, command)
|
||||||
|
seen[command.name] = true
|
||||||
|
end
|
||||||
|
end
|
||||||
|
for _, command in ipairs(builtin_commands) do
|
||||||
|
if not seen[command.name] then
|
||||||
|
table.insert(commands, command)
|
||||||
|
seen[command.name] = true
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
return commands
|
||||||
end
|
end
|
||||||
|
|
||||||
function M.get_timestamp() return tostring(os.date("%Y-%m-%d %H:%M:%S")) end
|
function M.get_timestamp() return tostring(os.date("%Y-%m-%d %H:%M:%S")) end
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ function M.init()
|
|||||||
end
|
end
|
||||||
|
|
||||||
function M.log_prompt(request)
|
function M.log_prompt(request)
|
||||||
|
if request == "" then return end
|
||||||
local log_dir = Config.prompt_logger.log_dir
|
local log_dir = Config.prompt_logger.log_dir
|
||||||
local log_file = Utils.join_paths(log_dir, "avante_prompts.log")
|
local log_file = Utils.join_paths(log_dir, "avante_prompts.log")
|
||||||
|
|
||||||
|
|||||||
@@ -56,9 +56,18 @@ function CommandsSource:execute(item, callback)
|
|||||||
local commands = Utils.get_commands()
|
local commands = Utils.get_commands()
|
||||||
local command = vim.iter(commands):find(function(command) return command.name == item.data.name end)
|
local command = vim.iter(commands):find(function(command) return command.name == item.data.name end)
|
||||||
|
|
||||||
if not command then return end
|
if not command then
|
||||||
|
callback()
|
||||||
|
return
|
||||||
|
end
|
||||||
|
|
||||||
local sidebar = require("avante").get()
|
local sidebar = require("avante").get()
|
||||||
|
if not command.callback then
|
||||||
|
if sidebar then sidebar:submit_input() end
|
||||||
|
callback()
|
||||||
|
return
|
||||||
|
end
|
||||||
|
|
||||||
command.callback(sidebar, nil, function()
|
command.callback(sidebar, nil, function()
|
||||||
local bufnr = sidebar.containers.input.bufnr ---@type integer
|
local bufnr = sidebar.containers.input.bufnr ---@type integer
|
||||||
local content = table.concat(api.nvim_buf_get_lines(bufnr, 0, -1, false), "\n")
|
local content = table.concat(api.nvim_buf_get_lines(bufnr, 0, -1, false), "\n")
|
||||||
|
|||||||
Reference in New Issue
Block a user