"""Dynamic micro-batching scheduler with priority queue."""

import time
from dataclasses import dataclass
from typing import Any, Callable, Optional

from llmds.indexed_heap import IndexedHeap


@dataclass
class Request:
    """Represents a request in the scheduler."""

    request_id: int
    tokens: int
    priority: float  # Higher = more priority
    created_at: float
    slo_ms: Optional[float] = None  # Service-level objective in milliseconds


class Scheduler:
    """
    Dynamic micro-batching scheduler with priority-based queuing.

    Uses an indexed heap to prioritize sequences by remaining length or SLO.
    Supports dynamic batching with a configurable trade-off between waiting
    time and throughput.
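
    Example (illustrative sketch; with the default priority function, shorter
    requests are scheduled ahead of longer ones)::

        sched = Scheduler(max_batch_size=2, max_wait_ms=0.0)
        long_id = sched.submit(tokens=100)
        short_id = sched.submit(tokens=10)
        batch = sched.get_batch(force=True)  # [short_id, long_id]
        sched.complete_batch(batch)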
    """

    def __init__(
        self,
        max_batch_size: int = 32,
        max_wait_ms: float = 50.0,
        priority_fn: Optional[Callable[[Request], float]] = None,
    ):
        """
        Initialize scheduler.

        Args:
            max_batch_size: Maximum number of requests returned in a single batch
            max_wait_ms: Maximum time in milliseconds the oldest request may wait
                before a batch is released
            priority_fn: Optional function computing a priority score from a
                request (higher = scheduled sooner).
                Default: prioritize shorter sequences (inverse of token count)
        """
        self.max_batch_size = max_batch_size
        self.max_wait_ms = max_wait_ms
        self._heap = IndexedHeap(max_heap=True)  # Max-heap keyed by priority
        self._requests: dict[int, Request] = {}
        self._priority_fn = priority_fn or self._default_priority_fn
        self._request_counter = 0
        self._batch_count = 0
        self._total_processed = 0

    def _default_priority_fn(self, req: Request) -> float:
        """Default priority: higher priority for shorter sequences (inverse of tokens)."""
        return 1.0 / (req.tokens + 1.0)

    def _slo_priority_fn(self, req: Request) -> float:
        """Priority based on SLO deadline."""
        if req.slo_ms is None:
            return self._default_priority_fn(req)

        elapsed_ms = (time.time() - req.created_at) * 1000
        remaining_ms = req.slo_ms - elapsed_ms
        if remaining_ms <= 0:
            return float("inf")  # Urgent: past deadline
        return 1.0 / (remaining_ms + 1.0)

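    # NOTE: _slo_priority_fn is not wired in by default. One way a caller could
    # opt into SLO-aware ordering (illustrative sketch, not part of the public
    # API) is to rebind the callback on an instance:
    #
    #     sched = Scheduler()
    #     sched._priority_fn = sched._slo_priority_fn
    #
    # Passing Scheduler._slo_priority_fn directly as priority_fn would not work,
    # because the callback is invoked with a single Request argument.
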
    def submit(self, tokens: int, slo_ms: Optional[float] = None) -> int:
        """
        Submit a request to the scheduler.

        Args:
            tokens: Estimated token count for the request
            slo_ms: Optional SLO deadline in milliseconds

        Returns:
            Request ID
        """
        req_id = self._request_counter
        self._request_counter += 1

        # Build the request first, then score it with the configured priority
        # function (avoids constructing a throwaway Request just for scoring).
        req = Request(
            request_id=req_id,
            tokens=tokens,
            priority=0.0,
            created_at=time.time(),
            slo_ms=slo_ms,
        )
        req.priority = self._priority_fn(req)

        self._requests[req_id] = req
        self._heap.push(req_id, req.priority)

        return req_id

    def get_batch(self, force: bool = False) -> Optional[list[int]]:
        """
        Get next batch of requests to process.

        Args:
            force: If True, return a batch even if the wait threshold has not
                been reached

        Returns:
            List of request IDs or None if no batch ready
        """
        if self._heap.is_empty():
            return None

        # Check whether the oldest pending request has exceeded the max wait time.
        oldest_req_id = None
        oldest_time = float("inf")

        for req_id in self._requests:
            if self._requests[req_id].created_at < oldest_time:
                oldest_time = self._requests[req_id].created_at
                oldest_req_id = req_id

        if oldest_req_id is not None:  # note: request id 0 is a valid id
            wait_time_ms = (time.time() - oldest_time) * 1000
            if not force and wait_time_ms < self.max_wait_ms:
                return None

        # Build batch by popping the highest-priority requests from the heap.
        batch: list[int] = []
        while len(batch) < self.max_batch_size and not self._heap.is_empty():
            _, req_id = self._heap.pop()
            if req_id in self._requests:
                batch.append(req_id)
            # Stale heap entries (requests already completed) are simply dropped.

        if batch:
            self._batch_count += 1
            self._total_processed += len(batch)
            return batch

        return None

    def complete_batch(self, request_ids: list[int]) -> None:
        """
        Mark a batch as completed and remove requests.

        Args:
            request_ids: List of completed request IDs
        """
        for req_id in request_ids:
            if req_id in self._requests:
                # Try to remove from heap if present
                if self._heap.contains(req_id):
                    try:
                        self._heap.delete(req_id)
                    except KeyError:
                        pass
                del self._requests[req_id]

    def update_priority(self, request_id: int, new_tokens: int) -> None:
        """
        Update priority for a request (e.g., after partial processing).

        Args:
            request_id: Request identifier
            new_tokens: Updated token count
        """
        if request_id not in self._requests:
            return

        req = self._requests[request_id]
        req.tokens = new_tokens
        new_priority = self._priority_fn(req)
        req.priority = new_priority  # keep the stored request in sync with the heap

        if self._heap.contains(request_id):
            old_priority = self._heap.get_score(request_id)
            if old_priority is not None:
                if new_priority > old_priority:
                    self._heap.increase_key(request_id, new_priority)
                else:
                    self._heap.decrease_key(request_id, new_priority)
        else:
            self._heap.push(request_id, new_priority)

    def stats(self) -> dict[str, Any]:
        """
        Get scheduler statistics.

        Returns:
            Dictionary with scheduler statistics
        """
        return {
            "queue_size": len(self._requests),
            "batch_count": self._batch_count,
            "total_processed": self._total_processed,
            "avg_batch_size": (
                self._total_processed / self._batch_count if self._batch_count > 0 else 0.0
            ),
        }

    def clear(self) -> None:
        """Clear all pending requests."""
        self._heap = IndexedHeap(max_heap=True)
        self._requests.clear()
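

if __name__ == "__main__":
    # Minimal illustrative driver (not part of the scheduler API): walks the
    # submit -> get_batch -> complete_batch cycle and prints the final stats.
    # Assumes llmds.indexed_heap.IndexedHeap is importable in this environment.
    sched = Scheduler(max_batch_size=4, max_wait_ms=10.0)

    # Queue a few requests of varying length; one carries an SLO deadline.
    for tokens in (128, 16, 512, 64):
        sched.submit(tokens=tokens)
    sched.submit(tokens=256, slo_ms=100.0)

    # Drain the queue; force=True releases batches without waiting max_wait_ms.
    while True:
        batch = sched.get_batch(force=True)
        if batch is None:
            break
        print("processing batch:", batch)
        sched.complete_batch(batch)

    print("stats:", sched.stats())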