feat: implement execution control (pause/resume/abort) for agent mode
Adds execution control system per GitHub issue #113: - Ctrl+P: Toggle pause/resume during agent execution - Ctrl+Z: Abort with rollback (undo file changes) - Ctrl+Shift+S: Toggle step-by-step mode - Enter: Advance one step when in step mode New files: - src/types/execution-control.ts: Type definitions - src/services/execution-control.ts: Control implementation with rollback - src/constants/execution-control.ts: Keyboard shortcuts and messages Modified: - agent-stream.ts: Integrated execution control into agent loop - message-handler.ts: Added control functions and callbacks - app.tsx: Added keyboard shortcut handlers - help-content.ts: Added help topics for new shortcuts Closes #113
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
*
|
||||
* Agent loop that streams LLM responses in real-time to the TUI.
|
||||
* Handles tool call accumulation mid-stream.
|
||||
* Supports pause/resume, step-by-step mode, and abort with rollback.
|
||||
*/
|
||||
|
||||
import { v4 as uuidv4 } from "uuid";
|
||||
@@ -20,11 +21,16 @@ import type {
|
||||
PartialToolCall,
|
||||
StreamCallbacks,
|
||||
} from "@/types/streaming";
|
||||
import type { ExecutionControlEvents } from "@/types/execution-control";
|
||||
import { chatStream } from "@providers/core/chat";
|
||||
import { getTool, getToolsForApi, refreshMCPTools } from "@tools/index";
|
||||
import { initializePermissions } from "@services/core/permissions";
|
||||
import { MAX_ITERATIONS, MAX_CONSECUTIVE_ERRORS } from "@constants/agent";
|
||||
import { createStreamAccumulator } from "@/types/streaming";
|
||||
import {
|
||||
createExecutionControl,
|
||||
captureFileState,
|
||||
} from "@services/execution-control";
|
||||
|
||||
// =============================================================================
|
||||
// Types
|
||||
@@ -36,6 +42,21 @@ interface StreamAgentState {
|
||||
abort: AbortController;
|
||||
options: AgentOptions;
|
||||
callbacks: Partial<StreamCallbacks>;
|
||||
executionControl: ReturnType<typeof createExecutionControl>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Extended stream callbacks with execution control events
|
||||
*/
|
||||
export interface ExtendedStreamCallbacks extends StreamCallbacks {
|
||||
onPause?: () => void;
|
||||
onResume?: () => void;
|
||||
onStepModeEnabled?: () => void;
|
||||
onStepModeDisabled?: () => void;
|
||||
onWaitingForStep?: (toolName: string, toolArgs: Record<string, unknown>) => void;
|
||||
onAbort?: (rollbackCount: number) => void;
|
||||
onRollback?: (action: { type: string; description: string }) => void;
|
||||
onRollbackComplete?: (actionsRolledBack: number) => void;
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
@@ -45,14 +66,34 @@ interface StreamAgentState {
|
||||
const createStreamAgentState = (
|
||||
workingDir: string,
|
||||
options: AgentOptions,
|
||||
callbacks: Partial<StreamCallbacks>,
|
||||
): StreamAgentState => ({
|
||||
sessionId: uuidv4(),
|
||||
workingDir,
|
||||
abort: new AbortController(),
|
||||
options,
|
||||
callbacks,
|
||||
});
|
||||
callbacks: Partial<ExtendedStreamCallbacks>,
|
||||
): StreamAgentState => {
|
||||
const extendedCallbacks = callbacks as Partial<ExtendedStreamCallbacks>;
|
||||
|
||||
const executionControlEvents: ExecutionControlEvents = {
|
||||
onPause: extendedCallbacks.onPause,
|
||||
onResume: extendedCallbacks.onResume,
|
||||
onStepModeEnabled: extendedCallbacks.onStepModeEnabled,
|
||||
onStepModeDisabled: extendedCallbacks.onStepModeDisabled,
|
||||
onWaitingForStep: extendedCallbacks.onWaitingForStep,
|
||||
onAbort: extendedCallbacks.onAbort,
|
||||
onRollback: (action) =>
|
||||
extendedCallbacks.onRollback?.({
|
||||
type: action.type,
|
||||
description: action.description,
|
||||
}),
|
||||
onRollbackComplete: extendedCallbacks.onRollbackComplete,
|
||||
};
|
||||
|
||||
return {
|
||||
sessionId: uuidv4(),
|
||||
workingDir,
|
||||
abort: new AbortController(),
|
||||
options,
|
||||
callbacks,
|
||||
executionControl: createExecutionControl(executionControlEvents),
|
||||
};
|
||||
};
|
||||
|
||||
// =============================================================================
|
||||
// Tool Call Accumulation
|
||||
@@ -251,10 +292,46 @@ const finalizeToolCall = (partial: PartialToolCall): ToolCall => {
|
||||
// Tool Execution
|
||||
// =============================================================================
|
||||
|
||||
/**
|
||||
* Tools that modify files and support rollback
|
||||
*/
|
||||
const ROLLBACK_CAPABLE_TOOLS: Record<string, "file_write" | "file_edit" | "file_delete" | "bash_command"> = {
|
||||
write: "file_write",
|
||||
edit: "file_edit",
|
||||
delete: "file_delete",
|
||||
bash: "bash_command",
|
||||
};
|
||||
|
||||
const executeTool = async (
|
||||
state: StreamAgentState,
|
||||
toolCall: ToolCall,
|
||||
): Promise<ToolResult> => {
|
||||
// Check if execution was aborted
|
||||
if (state.executionControl.getState() === "aborted") {
|
||||
return {
|
||||
success: false,
|
||||
title: "Aborted",
|
||||
output: "",
|
||||
error: "Execution was aborted",
|
||||
};
|
||||
}
|
||||
|
||||
// Wait if paused
|
||||
await state.executionControl.waitIfPaused();
|
||||
|
||||
// Wait for step confirmation if in step mode
|
||||
await state.executionControl.waitForStep(toolCall.name, toolCall.arguments);
|
||||
|
||||
// Check again after waiting (might have been aborted while waiting)
|
||||
if (state.executionControl.getState() === "aborted") {
|
||||
return {
|
||||
success: false,
|
||||
title: "Aborted",
|
||||
output: "",
|
||||
error: "Execution was aborted",
|
||||
};
|
||||
}
|
||||
|
||||
// Check for debug error markers from truncated/malformed JSON
|
||||
const debugError = toolCall.arguments.__debug_error as string | undefined;
|
||||
if (debugError) {
|
||||
@@ -292,9 +369,32 @@ const executeTool = async (
|
||||
onMetadata: () => {},
|
||||
};
|
||||
|
||||
// Capture file state for rollback if this is a modifying tool
|
||||
const rollbackType = ROLLBACK_CAPABLE_TOOLS[toolCall.name];
|
||||
let originalState: { filePath: string; content: string } | null = null;
|
||||
|
||||
if (rollbackType && (rollbackType === "file_edit" || rollbackType === "file_delete")) {
|
||||
const filePath = toolCall.arguments.file_path as string | undefined;
|
||||
if (filePath) {
|
||||
originalState = await captureFileState(filePath);
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
const validatedArgs = tool.parameters.parse(toolCall.arguments);
|
||||
return await tool.execute(validatedArgs, ctx);
|
||||
const result = await tool.execute(validatedArgs, ctx);
|
||||
|
||||
// Record action for rollback if successful and modifying
|
||||
if (result.success && rollbackType) {
|
||||
const filePath = toolCall.arguments.file_path as string | undefined;
|
||||
state.executionControl.recordAction({
|
||||
type: rollbackType,
|
||||
description: `${toolCall.name}: ${filePath ?? "unknown file"}`,
|
||||
originalState: originalState ?? (filePath ? { filePath, content: "" } : undefined),
|
||||
});
|
||||
}
|
||||
|
||||
return result;
|
||||
} catch (error: unknown) {
|
||||
const receivedArgs = JSON.stringify(toolCall.arguments);
|
||||
const errorMessage = error instanceof Error ? error.message : String(error);
|
||||
@@ -496,6 +596,19 @@ export const runAgentLoopStream = async (
|
||||
const agentMessages: AgentMessage[] = [...messages];
|
||||
|
||||
while (iterations < maxIterations) {
|
||||
// Check for abort at start of each iteration
|
||||
if (state.executionControl.getState() === "aborted") {
|
||||
return {
|
||||
success: false,
|
||||
finalResponse: "Execution aborted by user",
|
||||
iterations,
|
||||
toolCalls: allToolCalls,
|
||||
};
|
||||
}
|
||||
|
||||
// Wait if paused
|
||||
await state.executionControl.waitIfPaused();
|
||||
|
||||
iterations++;
|
||||
|
||||
try {
|
||||
@@ -611,7 +724,7 @@ export const runStreamingAgent = async (
|
||||
prompt: string,
|
||||
systemPrompt: string,
|
||||
options: AgentOptions,
|
||||
callbacks: Partial<StreamCallbacks> = {},
|
||||
callbacks: Partial<ExtendedStreamCallbacks> = {},
|
||||
): Promise<AgentResult> => {
|
||||
const messages: Message[] = [
|
||||
{ role: "system", content: systemPrompt },
|
||||
@@ -623,24 +736,63 @@ export const runStreamingAgent = async (
|
||||
};
|
||||
|
||||
/**
|
||||
* Create a streaming agent instance with stop capability
|
||||
* Streaming agent instance with full execution control
|
||||
*/
|
||||
export interface StreamingAgentInstance {
|
||||
/** Run the agent with given messages */
|
||||
run: (messages: Message[]) => Promise<AgentResult>;
|
||||
/** Stop the agent (abort without rollback) */
|
||||
stop: () => void;
|
||||
/** Update callbacks */
|
||||
updateCallbacks: (newCallbacks: Partial<ExtendedStreamCallbacks>) => void;
|
||||
/** Pause execution */
|
||||
pause: () => void;
|
||||
/** Resume execution */
|
||||
resume: () => void;
|
||||
/** Abort with optional rollback */
|
||||
abort: (rollback?: boolean) => Promise<void>;
|
||||
/** Enable/disable step-by-step mode */
|
||||
stepMode: (enabled: boolean) => void;
|
||||
/** Advance one step in step mode */
|
||||
step: () => void;
|
||||
/** Get current execution state */
|
||||
getExecutionState: () => "running" | "paused" | "stepping" | "aborted" | "completed";
|
||||
/** Check if waiting for step confirmation */
|
||||
isWaitingForStep: () => boolean;
|
||||
/** Get count of rollback actions available */
|
||||
getRollbackCount: () => number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a streaming agent instance with full execution control
|
||||
*/
|
||||
export const createStreamingAgent = (
|
||||
workingDir: string,
|
||||
options: AgentOptions,
|
||||
callbacks: Partial<StreamCallbacks> = {},
|
||||
): {
|
||||
run: (messages: Message[]) => Promise<AgentResult>;
|
||||
stop: () => void;
|
||||
updateCallbacks: (newCallbacks: Partial<StreamCallbacks>) => void;
|
||||
} => {
|
||||
callbacks: Partial<ExtendedStreamCallbacks> = {},
|
||||
): StreamingAgentInstance => {
|
||||
const state = createStreamAgentState(workingDir, options, callbacks);
|
||||
const control = state.executionControl;
|
||||
|
||||
return {
|
||||
run: (messages: Message[]) => runAgentLoopStream(state, messages),
|
||||
stop: () => state.abort.abort(),
|
||||
updateCallbacks: (newCallbacks: Partial<StreamCallbacks>) => {
|
||||
stop: () => {
|
||||
state.abort.abort();
|
||||
control.abort(false);
|
||||
},
|
||||
updateCallbacks: (newCallbacks: Partial<ExtendedStreamCallbacks>) => {
|
||||
Object.assign(state.callbacks, newCallbacks);
|
||||
},
|
||||
pause: () => control.pause(),
|
||||
resume: () => control.resume(),
|
||||
abort: (rollback = false) => control.abort(rollback),
|
||||
stepMode: (enabled: boolean) => control.stepMode(enabled),
|
||||
step: () => control.step(),
|
||||
getExecutionState: () => control.getState(),
|
||||
isWaitingForStep: () => control.isWaitingForStep(),
|
||||
getRollbackCount: () => control.getRollbackActions().length,
|
||||
};
|
||||
};
|
||||
|
||||
// Re-export types for external use
|
||||
export type { ExecutionControl, ExecutionControlEvents } from "@/types/execution-control";
|
||||
|
||||
Reference in New Issue
Block a user