diff --git a/README.md b/README.md index 5a966f0..9f63cd7 100644 --- a/README.md +++ b/README.md @@ -611,6 +611,7 @@ The following key bindings are available for use with `avante.nvim`: | Leaderaf | switch sidebar focus | | Leadera? | select model | | Leaderae | edit selected blocks | +| LeaderaS | stop current AI request | | co | choose ours | | ct | choose theirs | | ca | choose all theirs | @@ -687,6 +688,7 @@ return { | `:AvanteEdit` | Edit the selected code blocks | | | `:AvanteFocus` | Switch focus to/from the sidebar | | | `:AvanteRefresh` | Refresh all Avante windows | | +| `:AvanteStop` | Stop the current AI request | | | `:AvanteSwitchProvider` | Switch AI provider (e.g. openai) | | | `:AvanteShowRepoMap` | Show repo map for project's structure | | | `:AvanteToggle` | Toggle the Avante sidebar | | diff --git a/lua/avante/api.lua b/lua/avante/api.lua index 6e8592a..9ac892b 100644 --- a/lua/avante/api.lua +++ b/lua/avante/api.lua @@ -237,6 +237,8 @@ function M.select_history() end) end +function M.stop() require("avante.llm").cancel_inflight_request() end + return setmetatable(M, { __index = function(t, k) local module = require("avante") diff --git a/lua/avante/config.lua b/lua/avante/config.lua index 44856ae..d8a31c1 100644 --- a/lua/avante/config.lua +++ b/lua/avante/config.lua @@ -387,6 +387,7 @@ M._defaults = { edit = "ae", refresh = "ar", focus = "af", + stop = "aS", toggle = { default = "at", debug = "ad", diff --git a/lua/avante/init.lua b/lua/avante/init.lua index 7c2c7da..0a55e8c 100644 --- a/lua/avante/init.lua +++ b/lua/avante/init.lua @@ -94,6 +94,12 @@ function H.keymaps() function() require("avante.api").edit() end, { desc = "avante: edit" } ) + Utils.safe_keymap_set( + "n", + Config.mappings.stop, + function() require("avante.api").stop() end, + { desc = "avante: stop" } + ) Utils.safe_keymap_set( "n", Config.mappings.refresh, diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index eac1628..11609da 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -385,10 +385,13 @@ function M.curl(opts) ) end end + active_job = nil - completed = true - cleanup() - handler_opts.on_stop({ reason = "error", error = err }) + if not completed then + completed = true + cleanup() + handler_opts.on_stop({ reason = "error", error = err }) + end end, callback = function(result) active_job = nil @@ -452,9 +455,21 @@ function M.curl(opts) callback = function() -- Error: cannot resume dead coroutine if active_job then - xpcall(function() active_job:shutdown() end, function(err) return err end) + -- Mark as completed first to prevent error handler from running + completed = true + + -- Attempt to shutdown the active job, but ignore any errors + xpcall(function() active_job:shutdown() end, function(err) + Utils.debug("Ignored error during job shutdown: " .. vim.inspect(err)) + return err + end) + Utils.debug("LLM request cancelled") active_job = nil + + -- Clean up and notify of cancellation + cleanup() + vim.schedule(function() handler_opts.on_stop({ reason = "cancelled" }) end) end end, }) @@ -464,6 +479,9 @@ end ---@param opts AvanteLLMStreamOptions function M._stream(opts) + -- Reset the cancellation flag at the start of a new request + if LLMTools then LLMTools.is_cancelled = false end + local provider = opts.provider or Providers[Config.provider] ---@cast provider AvanteProviderFunctor @@ -500,6 +518,13 @@ function M._stream(opts) ---@param result string | nil ---@param error string | nil local function handle_tool_result(result, error) + -- Special handling for cancellation signal from tools + if error == LLMTools.CANCEL_TOKEN then + Utils.debug("Tool execution was cancelled by user") + opts.on_chunk("\n*[Request cancelled by user during tool execution.]*\n") + return opts.on_stop({ reason = "cancelled", tool_histories = tool_histories }) + end + local tool_result = { tool_use_id = tool_use.id, content = error ~= nil and error or result, @@ -512,6 +537,10 @@ function M._stream(opts) local result, error = LLMTools.process_tool_use(opts.tools, tool_use, opts.on_tool_log, handle_tool_result) if result ~= nil or error ~= nil then return handle_tool_result(result, error) end end + if stop_opts.reason == "cancelled" then + opts.on_chunk("\n*[Request cancelled by user.]*\n") + return opts.on_stop({ reason = "cancelled", tool_histories = opts.tool_histories }) + end if stop_opts.reason == "tool_use" and stop_opts.tool_use_list then local old_tool_histories = vim.deepcopy(opts.tool_histories) or {} local sorted_tool_use_list = {} ---@type AvanteLLMToolUse[] @@ -667,7 +696,9 @@ function M.stream(opts) local original_on_stop = opts.on_stop opts.on_stop = vim.schedule_wrap(function(stop_opts) if is_completed then return end - if stop_opts.reason == "complete" or stop_opts.reason == "error" then is_completed = true end + if stop_opts.reason == "complete" or stop_opts.reason == "error" or stop_opts.reason == "cancelled" then + is_completed = true + end return original_on_stop(stop_opts) end) end @@ -690,6 +721,14 @@ function M.stream(opts) end end -function M.cancel_inflight_request() api.nvim_exec_autocmds("User", { pattern = M.CANCEL_PATTERN }) end +function M.cancel_inflight_request() + if LLMTools.is_cancelled ~= nil then LLMTools.is_cancelled = true end + if LLMTools.confirm_popup ~= nil then + LLMTools.confirm_popup:cancel() + LLMTools.confirm_popup = nil + end + + api.nvim_exec_autocmds("User", { pattern = M.CANCEL_PATTERN }) +end return M diff --git a/lua/avante/llm_tools.lua b/lua/avante/llm_tools.lua index a65eb75..3a533fc 100644 --- a/lua/avante/llm_tools.lua +++ b/lua/avante/llm_tools.lua @@ -9,6 +9,13 @@ local Highlights = require("avante.highlights") ---@class AvanteRagService local M = {} +M.CANCEL_TOKEN = "__CANCELLED__" + +-- Track cancellation state +M.is_cancelled = false +---@type avante.ui.Confirm +M.confirm_popup = nil + ---@param rel_path string ---@return string local function get_abs_path(rel_path) @@ -32,9 +39,9 @@ function M.confirm(message, callback, opts) return end local confirm_opts = vim.tbl_deep_extend("force", { container_winid = sidebar.input_container.winid }, opts or {}) - local confirm = Confirm:new(message, callback, confirm_opts) - confirm:open() - return confirm + M.confirm_popup = Confirm:new(message, callback, confirm_opts) + M.confirm_popup:open() + return M.confirm_popup end ---@param abs_path string @@ -1634,6 +1641,17 @@ M._tools = { ---@return string | nil error function M.process_tool_use(tools, tool_use, on_log, on_complete) Utils.debug("use tool", tool_use.name, tool_use.input_json) + + -- Check if execution is already cancelled + if M.is_cancelled then + Utils.debug("Tool execution cancelled before starting: " .. tool_use.name) + if on_complete then + on_complete(nil, M.CANCEL_TOKEN) + return + end + return nil, M.CANCEL_TOKEN + end + local func if tool_use.name == "str_replace_editor" then func = M.str_replace_editor @@ -1646,9 +1664,44 @@ function M.process_tool_use(tools, tool_use, on_log, on_complete) local input_json = vim.json.decode(tool_use.input_json) if not func then return nil, "Tool not found: " .. tool_use.name end if on_log then on_log(tool_use.name, "running tool") end + + -- Set up a timer to periodically check for cancellation + local cancel_timer + if on_complete then + cancel_timer = vim.loop.new_timer() + if cancel_timer then + cancel_timer:start( + 100, + 100, + vim.schedule_wrap(function() + if M.is_cancelled then + Utils.debug("Tool execution cancelled during execution: " .. tool_use.name) + if cancel_timer and not cancel_timer:is_closing() then + cancel_timer:stop() + cancel_timer:close() + end + on_complete(nil, M.CANCEL_TOKEN) + end + end) + ) + end + end + ---@param result string | nil | boolean ---@param err string | nil local function handle_result(result, err) + -- Stop the cancellation timer if it exists + if cancel_timer and not cancel_timer:is_closing() then + cancel_timer:stop() + cancel_timer:close() + end + + -- Check for cancellation one more time before processing result + if M.is_cancelled then + if on_log then on_log(tool_use.name, "cancelled during result handling") end + return nil, M.CANCEL_TOKEN + end + if on_log then on_log(tool_use.name, "tool finished") end -- Utils.debug("result", result) -- Utils.debug("error", error) @@ -1663,9 +1716,18 @@ function M.process_tool_use(tools, tool_use, on_log, on_complete) end return result_str, err end + local result, err = func(input_json, function(log) + -- Check for cancellation during logging + if M.is_cancelled then return end if on_log then on_log(tool_use.name, log) end end, function(result, err) + -- Check for cancellation before completing + if M.is_cancelled then + if on_complete then on_complete(nil, M.CANCEL_TOKEN) end + return + end + result, err = handle_result(result, err) if on_complete == nil then Utils.error("asynchronous tool " .. tool_use.name .. " result not handled") @@ -1673,6 +1735,7 @@ function M.process_tool_use(tools, tool_use, on_log, on_complete) end on_complete(result, err) end) + -- Result and error being nil means that the tool was executed asynchronously if result == nil and err == nil and on_complete then return end return handle_result(result, err) diff --git a/lua/avante/types.lua b/lua/avante/types.lua index fde6116..686f88d 100644 --- a/lua/avante/types.lua +++ b/lua/avante/types.lua @@ -245,7 +245,7 @@ vim.g.avante_login = vim.g.avante_login ---@field usage? AvanteLLMUsage --- ---@class AvanteLLMStopCallbackOptions ----@field reason "complete" | "tool_use" | "error" | "rate_limit" +---@field reason "complete" | "tool_use" | "error" | "rate_limit" | "cancelled" ---@field error? string | table ---@field usage? AvanteLLMUsage ---@field tool_use_list? AvanteLLMToolUse[] diff --git a/lua/avante/ui/confirm.lua b/lua/avante/ui/confirm.lua index f4e9f68..c358143 100644 --- a/lua/avante/ui/confirm.lua +++ b/lua/avante/ui/confirm.lua @@ -253,6 +253,11 @@ end function M:unbind_window_focus_keymaps() vim.keymap.del({ "n", "i" }, "f") end +function M:cancel() + self.callback(false) + return self:close() +end + function M:close() self:unbind_window_focus_keymaps() if self._group then diff --git a/plugin/avante.lua b/plugin/avante.lua index 3eadc70..77b9c7b 100644 --- a/plugin/avante.lua +++ b/plugin/avante.lua @@ -156,3 +156,4 @@ end, { cmd("ShowRepoMap", function() require("avante.repo_map").show() end, { desc = "avante: show repo map" }) cmd("Models", function() require("avante.model_selector").open() end, { desc = "avante: show models" }) cmd("History", function() require("avante.api").select_history() end, { desc = "avante: show histories" }) +cmd("Stop", function() require("avante.api").stop() end, { desc = "avante: stop current AI request" })