feat: add stop sequence (#1652)

This commit is contained in:
Omar Crespo
2025-03-21 06:34:33 -05:00
committed by GitHub
parent d1fa11ec41
commit bae5275705
9 changed files with 129 additions and 10 deletions

View File

@@ -9,6 +9,13 @@ local Highlights = require("avante.highlights")
---@class AvanteRagService
local M = {}
M.CANCEL_TOKEN = "__CANCELLED__"
-- Track cancellation state
M.is_cancelled = false
---@type avante.ui.Confirm
M.confirm_popup = nil
---@param rel_path string
---@return string
local function get_abs_path(rel_path)
@@ -32,9 +39,9 @@ function M.confirm(message, callback, opts)
return
end
local confirm_opts = vim.tbl_deep_extend("force", { container_winid = sidebar.input_container.winid }, opts or {})
local confirm = Confirm:new(message, callback, confirm_opts)
confirm:open()
return confirm
M.confirm_popup = Confirm:new(message, callback, confirm_opts)
M.confirm_popup:open()
return M.confirm_popup
end
---@param abs_path string
@@ -1634,6 +1641,17 @@ M._tools = {
---@return string | nil error
function M.process_tool_use(tools, tool_use, on_log, on_complete)
Utils.debug("use tool", tool_use.name, tool_use.input_json)
-- Check if execution is already cancelled
if M.is_cancelled then
Utils.debug("Tool execution cancelled before starting: " .. tool_use.name)
if on_complete then
on_complete(nil, M.CANCEL_TOKEN)
return
end
return nil, M.CANCEL_TOKEN
end
local func
if tool_use.name == "str_replace_editor" then
func = M.str_replace_editor
@@ -1646,9 +1664,44 @@ function M.process_tool_use(tools, tool_use, on_log, on_complete)
local input_json = vim.json.decode(tool_use.input_json)
if not func then return nil, "Tool not found: " .. tool_use.name end
if on_log then on_log(tool_use.name, "running tool") end
-- Set up a timer to periodically check for cancellation
local cancel_timer
if on_complete then
cancel_timer = vim.loop.new_timer()
if cancel_timer then
cancel_timer:start(
100,
100,
vim.schedule_wrap(function()
if M.is_cancelled then
Utils.debug("Tool execution cancelled during execution: " .. tool_use.name)
if cancel_timer and not cancel_timer:is_closing() then
cancel_timer:stop()
cancel_timer:close()
end
on_complete(nil, M.CANCEL_TOKEN)
end
end)
)
end
end
---@param result string | nil | boolean
---@param err string | nil
local function handle_result(result, err)
-- Stop the cancellation timer if it exists
if cancel_timer and not cancel_timer:is_closing() then
cancel_timer:stop()
cancel_timer:close()
end
-- Check for cancellation one more time before processing result
if M.is_cancelled then
if on_log then on_log(tool_use.name, "cancelled during result handling") end
return nil, M.CANCEL_TOKEN
end
if on_log then on_log(tool_use.name, "tool finished") end
-- Utils.debug("result", result)
-- Utils.debug("error", error)
@@ -1663,9 +1716,18 @@ function M.process_tool_use(tools, tool_use, on_log, on_complete)
end
return result_str, err
end
local result, err = func(input_json, function(log)
-- Check for cancellation during logging
if M.is_cancelled then return end
if on_log then on_log(tool_use.name, log) end
end, function(result, err)
-- Check for cancellation before completing
if M.is_cancelled then
if on_complete then on_complete(nil, M.CANCEL_TOKEN) end
return
end
result, err = handle_result(result, err)
if on_complete == nil then
Utils.error("asynchronous tool " .. tool_use.name .. " result not handled")
@@ -1673,6 +1735,7 @@ function M.process_tool_use(tools, tool_use, on_log, on_complete)
end
on_complete(result, err)
end)
-- Result and error being nil means that the tool was executed asynchronously
if result == nil and err == nil and on_complete then return end
return handle_result(result, err)