Initial commit: LLM-DS optimizer framework with data files excluded
This commit is contained in:
216
llmds/scheduler.py
Normal file
216
llmds/scheduler.py
Normal file
@@ -0,0 +1,216 @@
|
||||
"""Dynamic micro-batching scheduler with priority queue."""
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from llmds.indexed_heap import IndexedHeap
|
||||
|
||||
|
||||
@dataclass
|
||||
class Request:
|
||||
"""Represents a request in the scheduler."""
|
||||
|
||||
request_id: int
|
||||
tokens: int
|
||||
priority: float # Higher = more priority
|
||||
created_at: float
|
||||
slo_ms: Optional[float] = None # Service level objective in milliseconds
|
||||
|
||||
|
||||
class Scheduler:
|
||||
"""
|
||||
Dynamic micro-batching scheduler with priority-based queuing.
|
||||
|
||||
Uses an indexed heap to prioritize sequences by remaining length or SLO.
|
||||
Supports dynamic batching with configurable waiting time vs. throughput trade-offs.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_batch_size: int = 32,
|
||||
max_wait_ms: float = 50.0,
|
||||
priority_fn: Optional[Callable[[Request], float]] = None,
|
||||
):
|
||||
"""
|
||||
Initialize scheduler.
|
||||
|
||||
Args:
|
||||
max_batch_size: Maximum batch size
|
||||
max_wait_ms: Maximum wait time in milliseconds before batching
|
||||
priority_fn: Optional function to compute priority from request.
|
||||
Default: prioritize by remaining tokens (inverse)
|
||||
"""
|
||||
self.max_batch_size = max_batch_size
|
||||
self.max_wait_ms = max_wait_ms
|
||||
self._heap = IndexedHeap(max_heap=True) # Max heap for priority
|
||||
self._requests: dict[int, Request] = {}
|
||||
self._priority_fn = priority_fn or self._default_priority_fn
|
||||
self._request_counter = 0
|
||||
self._batch_count = 0
|
||||
self._total_processed = 0
|
||||
|
||||
def _default_priority_fn(self, req: Request) -> float:
|
||||
"""Default priority: higher priority for shorter sequences (inverse of tokens)."""
|
||||
return 1.0 / (req.tokens + 1.0)
|
||||
|
||||
def _slo_priority_fn(self, req: Request) -> float:
|
||||
"""Priority based on SLO deadline."""
|
||||
if req.slo_ms is None:
|
||||
return self._default_priority_fn(req)
|
||||
|
||||
elapsed_ms = (time.time() - req.created_at) * 1000
|
||||
remaining_ms = req.slo_ms - elapsed_ms
|
||||
if remaining_ms <= 0:
|
||||
return float("inf") # Urgent: past deadline
|
||||
return 1.0 / (remaining_ms + 1.0)
|
||||
|
||||
def submit(self, tokens: int, slo_ms: Optional[float] = None) -> int:
|
||||
"""
|
||||
Submit a request to the scheduler.
|
||||
|
||||
Args:
|
||||
tokens: Estimated token count for the request
|
||||
slo_ms: Optional SLO deadline in milliseconds
|
||||
|
||||
Returns:
|
||||
Request ID
|
||||
"""
|
||||
req_id = self._request_counter
|
||||
self._request_counter += 1
|
||||
|
||||
req = Request(
|
||||
request_id=req_id,
|
||||
tokens=tokens,
|
||||
priority=self._priority_fn(
|
||||
Request(
|
||||
request_id=req_id,
|
||||
tokens=tokens,
|
||||
priority=0.0,
|
||||
created_at=time.time(),
|
||||
slo_ms=slo_ms,
|
||||
)
|
||||
),
|
||||
created_at=time.time(),
|
||||
slo_ms=slo_ms,
|
||||
)
|
||||
|
||||
self._requests[req_id] = req
|
||||
self._heap.push(req_id, req.priority)
|
||||
|
||||
return req_id
|
||||
|
||||
def get_batch(self, force: bool = False) -> Optional[list[int]]:
|
||||
"""
|
||||
Get next batch of requests to process.
|
||||
|
||||
Args:
|
||||
force: If True, return batch even if not full
|
||||
|
||||
Returns:
|
||||
List of request IDs or None if no batch ready
|
||||
"""
|
||||
if self._heap.is_empty():
|
||||
return None
|
||||
|
||||
# Check if oldest request exceeds max wait time
|
||||
oldest_req_id = None
|
||||
oldest_time = float("inf")
|
||||
|
||||
for req_id in self._requests:
|
||||
if self._requests[req_id].created_at < oldest_time:
|
||||
oldest_time = self._requests[req_id].created_at
|
||||
oldest_req_id = req_id
|
||||
|
||||
if oldest_req_id:
|
||||
wait_time_ms = (time.time() - oldest_time) * 1000
|
||||
if not force and wait_time_ms < self.max_wait_ms:
|
||||
return None
|
||||
|
||||
# Build batch from heap
|
||||
batch: list[int] = []
|
||||
temp_heap = IndexedHeap(max_heap=True)
|
||||
|
||||
# Pop top requests
|
||||
while len(batch) < self.max_batch_size and not self._heap.is_empty():
|
||||
_, req_id = self._heap.pop()
|
||||
if req_id in self._requests:
|
||||
batch.append(req_id)
|
||||
else:
|
||||
temp_heap.push(req_id, self._requests[req_id].priority)
|
||||
|
||||
# Restore heap (add back any that weren't used)
|
||||
while not temp_heap.is_empty():
|
||||
_, req_id = temp_heap.pop()
|
||||
self._heap.push(req_id, self._requests[req_id].priority)
|
||||
|
||||
if batch:
|
||||
self._batch_count += 1
|
||||
self._total_processed += len(batch)
|
||||
return batch
|
||||
|
||||
return None
|
||||
|
||||
def complete_batch(self, request_ids: list[int]) -> None:
|
||||
"""
|
||||
Mark a batch as completed and remove requests.
|
||||
|
||||
Args:
|
||||
request_ids: List of completed request IDs
|
||||
"""
|
||||
for req_id in request_ids:
|
||||
if req_id in self._requests:
|
||||
# Try to remove from heap if present
|
||||
if self._heap.contains(req_id):
|
||||
try:
|
||||
self._heap.delete(req_id)
|
||||
except KeyError:
|
||||
pass
|
||||
del self._requests[req_id]
|
||||
|
||||
def update_priority(self, request_id: int, new_tokens: int) -> None:
|
||||
"""
|
||||
Update priority for a request (e.g., after partial processing).
|
||||
|
||||
Args:
|
||||
request_id: Request identifier
|
||||
new_tokens: Updated token count
|
||||
"""
|
||||
if request_id not in self._requests:
|
||||
return
|
||||
|
||||
req = self._requests[request_id]
|
||||
req.tokens = new_tokens
|
||||
new_priority = self._priority_fn(req)
|
||||
|
||||
if self._heap.contains(request_id):
|
||||
old_priority = self._heap.get_score(request_id)
|
||||
if old_priority is not None:
|
||||
if new_priority > old_priority:
|
||||
self._heap.increase_key(request_id, new_priority)
|
||||
else:
|
||||
self._heap.decrease_key(request_id, new_priority)
|
||||
else:
|
||||
self._heap.push(request_id, new_priority)
|
||||
|
||||
def stats(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get scheduler statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with scheduler statistics
|
||||
"""
|
||||
return {
|
||||
"queue_size": len(self._requests),
|
||||
"batch_count": self._batch_count,
|
||||
"total_processed": self._total_processed,
|
||||
"avg_batch_size": (
|
||||
self._total_processed / self._batch_count if self._batch_count > 0 else 0.0
|
||||
),
|
||||
}
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear all pending requests."""
|
||||
self._heap = IndexedHeap(max_heap=True)
|
||||
self._requests.clear()
|
||||
|
||||
Reference in New Issue
Block a user