feat: add stop sequence (#1652)
This commit is contained in:
@@ -611,6 +611,7 @@ The following key bindings are available for use with `avante.nvim`:
|
|||||||
| <kbd>Leader</kbd><kbd>a</kbd><kbd>f</kbd> | switch sidebar focus |
|
| <kbd>Leader</kbd><kbd>a</kbd><kbd>f</kbd> | switch sidebar focus |
|
||||||
| <kbd>Leader</kbd><kbd>a</kbd><kbd>?</kbd> | select model |
|
| <kbd>Leader</kbd><kbd>a</kbd><kbd>?</kbd> | select model |
|
||||||
| <kbd>Leader</kbd><kbd>a</kbd><kbd>e</kbd> | edit selected blocks |
|
| <kbd>Leader</kbd><kbd>a</kbd><kbd>e</kbd> | edit selected blocks |
|
||||||
|
| <kbd>Leader</kbd><kbd>a</kbd><kbd>S</kbd> | stop current AI request |
|
||||||
| <kbd>c</kbd><kbd>o</kbd> | choose ours |
|
| <kbd>c</kbd><kbd>o</kbd> | choose ours |
|
||||||
| <kbd>c</kbd><kbd>t</kbd> | choose theirs |
|
| <kbd>c</kbd><kbd>t</kbd> | choose theirs |
|
||||||
| <kbd>c</kbd><kbd>a</kbd> | choose all theirs |
|
| <kbd>c</kbd><kbd>a</kbd> | choose all theirs |
|
||||||
@@ -687,6 +688,7 @@ return {
|
|||||||
| `:AvanteEdit` | Edit the selected code blocks | |
|
| `:AvanteEdit` | Edit the selected code blocks | |
|
||||||
| `:AvanteFocus` | Switch focus to/from the sidebar | |
|
| `:AvanteFocus` | Switch focus to/from the sidebar | |
|
||||||
| `:AvanteRefresh` | Refresh all Avante windows | |
|
| `:AvanteRefresh` | Refresh all Avante windows | |
|
||||||
|
| `:AvanteStop` | Stop the current AI request | |
|
||||||
| `:AvanteSwitchProvider` | Switch AI provider (e.g. openai) | |
|
| `:AvanteSwitchProvider` | Switch AI provider (e.g. openai) | |
|
||||||
| `:AvanteShowRepoMap` | Show repo map for project's structure | |
|
| `:AvanteShowRepoMap` | Show repo map for project's structure | |
|
||||||
| `:AvanteToggle` | Toggle the Avante sidebar | |
|
| `:AvanteToggle` | Toggle the Avante sidebar | |
|
||||||
|
|||||||
@@ -237,6 +237,8 @@ function M.select_history()
|
|||||||
end)
|
end)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function M.stop() require("avante.llm").cancel_inflight_request() end
|
||||||
|
|
||||||
return setmetatable(M, {
|
return setmetatable(M, {
|
||||||
__index = function(t, k)
|
__index = function(t, k)
|
||||||
local module = require("avante")
|
local module = require("avante")
|
||||||
|
|||||||
@@ -387,6 +387,7 @@ M._defaults = {
|
|||||||
edit = "<leader>ae",
|
edit = "<leader>ae",
|
||||||
refresh = "<leader>ar",
|
refresh = "<leader>ar",
|
||||||
focus = "<leader>af",
|
focus = "<leader>af",
|
||||||
|
stop = "<leader>aS",
|
||||||
toggle = {
|
toggle = {
|
||||||
default = "<leader>at",
|
default = "<leader>at",
|
||||||
debug = "<leader>ad",
|
debug = "<leader>ad",
|
||||||
|
|||||||
@@ -94,6 +94,12 @@ function H.keymaps()
|
|||||||
function() require("avante.api").edit() end,
|
function() require("avante.api").edit() end,
|
||||||
{ desc = "avante: edit" }
|
{ desc = "avante: edit" }
|
||||||
)
|
)
|
||||||
|
Utils.safe_keymap_set(
|
||||||
|
"n",
|
||||||
|
Config.mappings.stop,
|
||||||
|
function() require("avante.api").stop() end,
|
||||||
|
{ desc = "avante: stop" }
|
||||||
|
)
|
||||||
Utils.safe_keymap_set(
|
Utils.safe_keymap_set(
|
||||||
"n",
|
"n",
|
||||||
Config.mappings.refresh,
|
Config.mappings.refresh,
|
||||||
|
|||||||
@@ -385,10 +385,13 @@ function M.curl(opts)
|
|||||||
)
|
)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
active_job = nil
|
active_job = nil
|
||||||
completed = true
|
if not completed then
|
||||||
cleanup()
|
completed = true
|
||||||
handler_opts.on_stop({ reason = "error", error = err })
|
cleanup()
|
||||||
|
handler_opts.on_stop({ reason = "error", error = err })
|
||||||
|
end
|
||||||
end,
|
end,
|
||||||
callback = function(result)
|
callback = function(result)
|
||||||
active_job = nil
|
active_job = nil
|
||||||
@@ -452,9 +455,21 @@ function M.curl(opts)
|
|||||||
callback = function()
|
callback = function()
|
||||||
-- Error: cannot resume dead coroutine
|
-- Error: cannot resume dead coroutine
|
||||||
if active_job then
|
if active_job then
|
||||||
xpcall(function() active_job:shutdown() end, function(err) return err end)
|
-- Mark as completed first to prevent error handler from running
|
||||||
|
completed = true
|
||||||
|
|
||||||
|
-- Attempt to shutdown the active job, but ignore any errors
|
||||||
|
xpcall(function() active_job:shutdown() end, function(err)
|
||||||
|
Utils.debug("Ignored error during job shutdown: " .. vim.inspect(err))
|
||||||
|
return err
|
||||||
|
end)
|
||||||
|
|
||||||
Utils.debug("LLM request cancelled")
|
Utils.debug("LLM request cancelled")
|
||||||
active_job = nil
|
active_job = nil
|
||||||
|
|
||||||
|
-- Clean up and notify of cancellation
|
||||||
|
cleanup()
|
||||||
|
vim.schedule(function() handler_opts.on_stop({ reason = "cancelled" }) end)
|
||||||
end
|
end
|
||||||
end,
|
end,
|
||||||
})
|
})
|
||||||
@@ -464,6 +479,9 @@ end
|
|||||||
|
|
||||||
---@param opts AvanteLLMStreamOptions
|
---@param opts AvanteLLMStreamOptions
|
||||||
function M._stream(opts)
|
function M._stream(opts)
|
||||||
|
-- Reset the cancellation flag at the start of a new request
|
||||||
|
if LLMTools then LLMTools.is_cancelled = false end
|
||||||
|
|
||||||
local provider = opts.provider or Providers[Config.provider]
|
local provider = opts.provider or Providers[Config.provider]
|
||||||
|
|
||||||
---@cast provider AvanteProviderFunctor
|
---@cast provider AvanteProviderFunctor
|
||||||
@@ -500,6 +518,13 @@ function M._stream(opts)
|
|||||||
---@param result string | nil
|
---@param result string | nil
|
||||||
---@param error string | nil
|
---@param error string | nil
|
||||||
local function handle_tool_result(result, error)
|
local function handle_tool_result(result, error)
|
||||||
|
-- Special handling for cancellation signal from tools
|
||||||
|
if error == LLMTools.CANCEL_TOKEN then
|
||||||
|
Utils.debug("Tool execution was cancelled by user")
|
||||||
|
opts.on_chunk("\n*[Request cancelled by user during tool execution.]*\n")
|
||||||
|
return opts.on_stop({ reason = "cancelled", tool_histories = tool_histories })
|
||||||
|
end
|
||||||
|
|
||||||
local tool_result = {
|
local tool_result = {
|
||||||
tool_use_id = tool_use.id,
|
tool_use_id = tool_use.id,
|
||||||
content = error ~= nil and error or result,
|
content = error ~= nil and error or result,
|
||||||
@@ -512,6 +537,10 @@ function M._stream(opts)
|
|||||||
local result, error = LLMTools.process_tool_use(opts.tools, tool_use, opts.on_tool_log, handle_tool_result)
|
local result, error = LLMTools.process_tool_use(opts.tools, tool_use, opts.on_tool_log, handle_tool_result)
|
||||||
if result ~= nil or error ~= nil then return handle_tool_result(result, error) end
|
if result ~= nil or error ~= nil then return handle_tool_result(result, error) end
|
||||||
end
|
end
|
||||||
|
if stop_opts.reason == "cancelled" then
|
||||||
|
opts.on_chunk("\n*[Request cancelled by user.]*\n")
|
||||||
|
return opts.on_stop({ reason = "cancelled", tool_histories = opts.tool_histories })
|
||||||
|
end
|
||||||
if stop_opts.reason == "tool_use" and stop_opts.tool_use_list then
|
if stop_opts.reason == "tool_use" and stop_opts.tool_use_list then
|
||||||
local old_tool_histories = vim.deepcopy(opts.tool_histories) or {}
|
local old_tool_histories = vim.deepcopy(opts.tool_histories) or {}
|
||||||
local sorted_tool_use_list = {} ---@type AvanteLLMToolUse[]
|
local sorted_tool_use_list = {} ---@type AvanteLLMToolUse[]
|
||||||
@@ -667,7 +696,9 @@ function M.stream(opts)
|
|||||||
local original_on_stop = opts.on_stop
|
local original_on_stop = opts.on_stop
|
||||||
opts.on_stop = vim.schedule_wrap(function(stop_opts)
|
opts.on_stop = vim.schedule_wrap(function(stop_opts)
|
||||||
if is_completed then return end
|
if is_completed then return end
|
||||||
if stop_opts.reason == "complete" or stop_opts.reason == "error" then is_completed = true end
|
if stop_opts.reason == "complete" or stop_opts.reason == "error" or stop_opts.reason == "cancelled" then
|
||||||
|
is_completed = true
|
||||||
|
end
|
||||||
return original_on_stop(stop_opts)
|
return original_on_stop(stop_opts)
|
||||||
end)
|
end)
|
||||||
end
|
end
|
||||||
@@ -690,6 +721,14 @@ function M.stream(opts)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
function M.cancel_inflight_request() api.nvim_exec_autocmds("User", { pattern = M.CANCEL_PATTERN }) end
|
function M.cancel_inflight_request()
|
||||||
|
if LLMTools.is_cancelled ~= nil then LLMTools.is_cancelled = true end
|
||||||
|
if LLMTools.confirm_popup ~= nil then
|
||||||
|
LLMTools.confirm_popup:cancel()
|
||||||
|
LLMTools.confirm_popup = nil
|
||||||
|
end
|
||||||
|
|
||||||
|
api.nvim_exec_autocmds("User", { pattern = M.CANCEL_PATTERN })
|
||||||
|
end
|
||||||
|
|
||||||
return M
|
return M
|
||||||
|
|||||||
@@ -9,6 +9,13 @@ local Highlights = require("avante.highlights")
|
|||||||
---@class AvanteRagService
|
---@class AvanteRagService
|
||||||
local M = {}
|
local M = {}
|
||||||
|
|
||||||
|
M.CANCEL_TOKEN = "__CANCELLED__"
|
||||||
|
|
||||||
|
-- Track cancellation state
|
||||||
|
M.is_cancelled = false
|
||||||
|
---@type avante.ui.Confirm
|
||||||
|
M.confirm_popup = nil
|
||||||
|
|
||||||
---@param rel_path string
|
---@param rel_path string
|
||||||
---@return string
|
---@return string
|
||||||
local function get_abs_path(rel_path)
|
local function get_abs_path(rel_path)
|
||||||
@@ -32,9 +39,9 @@ function M.confirm(message, callback, opts)
|
|||||||
return
|
return
|
||||||
end
|
end
|
||||||
local confirm_opts = vim.tbl_deep_extend("force", { container_winid = sidebar.input_container.winid }, opts or {})
|
local confirm_opts = vim.tbl_deep_extend("force", { container_winid = sidebar.input_container.winid }, opts or {})
|
||||||
local confirm = Confirm:new(message, callback, confirm_opts)
|
M.confirm_popup = Confirm:new(message, callback, confirm_opts)
|
||||||
confirm:open()
|
M.confirm_popup:open()
|
||||||
return confirm
|
return M.confirm_popup
|
||||||
end
|
end
|
||||||
|
|
||||||
---@param abs_path string
|
---@param abs_path string
|
||||||
@@ -1634,6 +1641,17 @@ M._tools = {
|
|||||||
---@return string | nil error
|
---@return string | nil error
|
||||||
function M.process_tool_use(tools, tool_use, on_log, on_complete)
|
function M.process_tool_use(tools, tool_use, on_log, on_complete)
|
||||||
Utils.debug("use tool", tool_use.name, tool_use.input_json)
|
Utils.debug("use tool", tool_use.name, tool_use.input_json)
|
||||||
|
|
||||||
|
-- Check if execution is already cancelled
|
||||||
|
if M.is_cancelled then
|
||||||
|
Utils.debug("Tool execution cancelled before starting: " .. tool_use.name)
|
||||||
|
if on_complete then
|
||||||
|
on_complete(nil, M.CANCEL_TOKEN)
|
||||||
|
return
|
||||||
|
end
|
||||||
|
return nil, M.CANCEL_TOKEN
|
||||||
|
end
|
||||||
|
|
||||||
local func
|
local func
|
||||||
if tool_use.name == "str_replace_editor" then
|
if tool_use.name == "str_replace_editor" then
|
||||||
func = M.str_replace_editor
|
func = M.str_replace_editor
|
||||||
@@ -1646,9 +1664,44 @@ function M.process_tool_use(tools, tool_use, on_log, on_complete)
|
|||||||
local input_json = vim.json.decode(tool_use.input_json)
|
local input_json = vim.json.decode(tool_use.input_json)
|
||||||
if not func then return nil, "Tool not found: " .. tool_use.name end
|
if not func then return nil, "Tool not found: " .. tool_use.name end
|
||||||
if on_log then on_log(tool_use.name, "running tool") end
|
if on_log then on_log(tool_use.name, "running tool") end
|
||||||
|
|
||||||
|
-- Set up a timer to periodically check for cancellation
|
||||||
|
local cancel_timer
|
||||||
|
if on_complete then
|
||||||
|
cancel_timer = vim.loop.new_timer()
|
||||||
|
if cancel_timer then
|
||||||
|
cancel_timer:start(
|
||||||
|
100,
|
||||||
|
100,
|
||||||
|
vim.schedule_wrap(function()
|
||||||
|
if M.is_cancelled then
|
||||||
|
Utils.debug("Tool execution cancelled during execution: " .. tool_use.name)
|
||||||
|
if cancel_timer and not cancel_timer:is_closing() then
|
||||||
|
cancel_timer:stop()
|
||||||
|
cancel_timer:close()
|
||||||
|
end
|
||||||
|
on_complete(nil, M.CANCEL_TOKEN)
|
||||||
|
end
|
||||||
|
end)
|
||||||
|
)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
---@param result string | nil | boolean
|
---@param result string | nil | boolean
|
||||||
---@param err string | nil
|
---@param err string | nil
|
||||||
local function handle_result(result, err)
|
local function handle_result(result, err)
|
||||||
|
-- Stop the cancellation timer if it exists
|
||||||
|
if cancel_timer and not cancel_timer:is_closing() then
|
||||||
|
cancel_timer:stop()
|
||||||
|
cancel_timer:close()
|
||||||
|
end
|
||||||
|
|
||||||
|
-- Check for cancellation one more time before processing result
|
||||||
|
if M.is_cancelled then
|
||||||
|
if on_log then on_log(tool_use.name, "cancelled during result handling") end
|
||||||
|
return nil, M.CANCEL_TOKEN
|
||||||
|
end
|
||||||
|
|
||||||
if on_log then on_log(tool_use.name, "tool finished") end
|
if on_log then on_log(tool_use.name, "tool finished") end
|
||||||
-- Utils.debug("result", result)
|
-- Utils.debug("result", result)
|
||||||
-- Utils.debug("error", error)
|
-- Utils.debug("error", error)
|
||||||
@@ -1663,9 +1716,18 @@ function M.process_tool_use(tools, tool_use, on_log, on_complete)
|
|||||||
end
|
end
|
||||||
return result_str, err
|
return result_str, err
|
||||||
end
|
end
|
||||||
|
|
||||||
local result, err = func(input_json, function(log)
|
local result, err = func(input_json, function(log)
|
||||||
|
-- Check for cancellation during logging
|
||||||
|
if M.is_cancelled then return end
|
||||||
if on_log then on_log(tool_use.name, log) end
|
if on_log then on_log(tool_use.name, log) end
|
||||||
end, function(result, err)
|
end, function(result, err)
|
||||||
|
-- Check for cancellation before completing
|
||||||
|
if M.is_cancelled then
|
||||||
|
if on_complete then on_complete(nil, M.CANCEL_TOKEN) end
|
||||||
|
return
|
||||||
|
end
|
||||||
|
|
||||||
result, err = handle_result(result, err)
|
result, err = handle_result(result, err)
|
||||||
if on_complete == nil then
|
if on_complete == nil then
|
||||||
Utils.error("asynchronous tool " .. tool_use.name .. " result not handled")
|
Utils.error("asynchronous tool " .. tool_use.name .. " result not handled")
|
||||||
@@ -1673,6 +1735,7 @@ function M.process_tool_use(tools, tool_use, on_log, on_complete)
|
|||||||
end
|
end
|
||||||
on_complete(result, err)
|
on_complete(result, err)
|
||||||
end)
|
end)
|
||||||
|
|
||||||
-- Result and error being nil means that the tool was executed asynchronously
|
-- Result and error being nil means that the tool was executed asynchronously
|
||||||
if result == nil and err == nil and on_complete then return end
|
if result == nil and err == nil and on_complete then return end
|
||||||
return handle_result(result, err)
|
return handle_result(result, err)
|
||||||
|
|||||||
@@ -245,7 +245,7 @@ vim.g.avante_login = vim.g.avante_login
|
|||||||
---@field usage? AvanteLLMUsage
|
---@field usage? AvanteLLMUsage
|
||||||
---
|
---
|
||||||
---@class AvanteLLMStopCallbackOptions
|
---@class AvanteLLMStopCallbackOptions
|
||||||
---@field reason "complete" | "tool_use" | "error" | "rate_limit"
|
---@field reason "complete" | "tool_use" | "error" | "rate_limit" | "cancelled"
|
||||||
---@field error? string | table
|
---@field error? string | table
|
||||||
---@field usage? AvanteLLMUsage
|
---@field usage? AvanteLLMUsage
|
||||||
---@field tool_use_list? AvanteLLMToolUse[]
|
---@field tool_use_list? AvanteLLMToolUse[]
|
||||||
|
|||||||
@@ -253,6 +253,11 @@ end
|
|||||||
|
|
||||||
function M:unbind_window_focus_keymaps() vim.keymap.del({ "n", "i" }, "<C-w>f") end
|
function M:unbind_window_focus_keymaps() vim.keymap.del({ "n", "i" }, "<C-w>f") end
|
||||||
|
|
||||||
|
function M:cancel()
|
||||||
|
self.callback(false)
|
||||||
|
return self:close()
|
||||||
|
end
|
||||||
|
|
||||||
function M:close()
|
function M:close()
|
||||||
self:unbind_window_focus_keymaps()
|
self:unbind_window_focus_keymaps()
|
||||||
if self._group then
|
if self._group then
|
||||||
|
|||||||
@@ -156,3 +156,4 @@ end, {
|
|||||||
cmd("ShowRepoMap", function() require("avante.repo_map").show() end, { desc = "avante: show repo map" })
|
cmd("ShowRepoMap", function() require("avante.repo_map").show() end, { desc = "avante: show repo map" })
|
||||||
cmd("Models", function() require("avante.model_selector").open() end, { desc = "avante: show models" })
|
cmd("Models", function() require("avante.model_selector").open() end, { desc = "avante: show models" })
|
||||||
cmd("History", function() require("avante.api").select_history() end, { desc = "avante: show histories" })
|
cmd("History", function() require("avante.api").select_history() end, { desc = "avante: show histories" })
|
||||||
|
cmd("Stop", function() require("avante.api").stop() end, { desc = "avante: stop current AI request" })
|
||||||
|
|||||||
Reference in New Issue
Block a user