Initial commit: SheepOp LLM - Transformer-based language model implementation
- Complete transformer implementation from scratch
- Training pipeline with gradient accumulation and mixed precision
- Optimized inference with KV caching
- Multi-format data processing (PDFs, images, code, text)
- Comprehensive documentation
- Apache 2.0 license
- Example training plots included in docs/images/
models/prefetching.py (new file, 267 lines)
@@ -0,0 +1,267 @@
"""
Prefetching mechanisms for parallel data loading and processing.

Optimizes RAG systems by prefetching retrieval results.
"""
import torch
from torch.utils.data import DataLoader
from typing import List, Dict, Optional, Callable
from threading import Thread
from queue import Queue, Empty, Full


class PrefetchDataLoader:
    """
    DataLoader wrapper that prefetches batches in a background thread.

    Reduces GPU idle time by overlapping data loading with computation.
    """

    def __init__(
        self,
        dataloader: DataLoader,
        prefetch_factor: int = 2,
        device: Optional[torch.device] = None,
    ):
        """
        Args:
            dataloader: Base DataLoader to wrap
            prefetch_factor: Number of batches to prefetch
            device: Device to move prefetched batches to (optional)
        """
        self.dataloader = dataloader
        self.prefetch_factor = prefetch_factor
        self.device = device

        self.queue: Queue = Queue(maxsize=prefetch_factor)
        self.thread: Optional[Thread] = None
        self._stop_thread = False

    def _prefetch_worker(self):
        """Worker thread that loads batches and enqueues them."""
        for batch in self.dataloader:
            if self._stop_thread:
                break

            # Move dict-style batches to the target device; non_blocking
            # overlaps the copy with computation when pinned memory is used.
            if self.device is not None:
                batch = {k: v.to(self.device, non_blocking=True)
                         for k, v in batch.items()}

            self.queue.put(batch)

        self.queue.put(None)  # Sentinel: signal end of data

    def __iter__(self):
        """Start the prefetching thread and return the iterator."""
        self._stop_thread = False
        self.thread = Thread(target=self._prefetch_worker, daemon=True)
        self.thread.start()
        return self

    def __next__(self):
        """Block until the next prefetched batch is available."""
        batch = self.queue.get()
        if batch is None:
            raise StopIteration
        return batch

    def __len__(self):
        """Return the length of the underlying dataloader."""
        return len(self.dataloader)

    def stop(self):
        """Stop the prefetching thread."""
        self._stop_thread = True
        # Drain the queue so a worker blocked on a full queue can exit.
        while not self.queue.empty():
            try:
                self.queue.get_nowait()
            except Empty:
                break
        if self.thread is not None:
            self.thread.join()

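
Not part of the committed file, but a minimal usage sketch for PrefetchDataLoader, assuming it runs alongside the module above. The toy dataset and dict-style collate helper are illustrative assumptions, and the mean() call stands in for a real training step:

from torch.utils.data import TensorDataset

# Dict-style collate so batches match the {key: tensor} shape the
# prefetch worker expects when moving data to a device.
def collate(samples):
    return {"input_ids": torch.stack([s[0] for s in samples])}

data = TensorDataset(torch.arange(64).reshape(16, 4))
loader = DataLoader(data, batch_size=4, collate_fn=collate)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
prefetcher = PrefetchDataLoader(loader, prefetch_factor=2, device=device)
for batch in prefetcher:
    _ = batch["input_ids"].float().mean()  # stand-in for a training step
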
class LookaheadRetriever:
    """
    Lookahead retrieval mechanism for RAG systems.

    Prefetches retrieval results for anticipated queries.
    """

    def __init__(
        self,
        retrieval_fn: Callable[[str], List[Dict]],
        lookahead_window: int = 3,
        prefetch_queue_size: int = 10,
    ):
        """
        Args:
            retrieval_fn: Function that takes a query and returns retrieved documents
            lookahead_window: Number of queries to look ahead
            prefetch_queue_size: Maximum size of the prefetch queue
        """
        self.retrieval_fn = retrieval_fn
        self.lookahead_window = lookahead_window
        self.prefetch_queue_size = prefetch_queue_size

        self.prefetch_queue: Queue = Queue(maxsize=prefetch_queue_size)
        self.prefetch_thread: Optional[Thread] = None
        self._stop_thread = False

    def _prefetch_worker(self, query_queue: Queue):
        """Worker thread that prefetches retrieval results."""
        while not self._stop_thread:
            try:
                query = query_queue.get(timeout=1.0)
            except Empty:
                continue
            if query is None:
                break

            # Perform retrieval and enqueue the result
            results = self.retrieval_fn(query)
            try:
                self.prefetch_queue.put((query, results), timeout=0.1)
            except Full:
                pass  # Queue full, drop this result

    def start_prefetching(self, query_stream: List[str]):
        """Start prefetching retrieval results for a stream of queries."""
        query_queue: Queue = Queue()

        # Enqueue the anticipated queries, followed by an end-of-stream sentinel
        for query in query_stream:
            query_queue.put(query)
        query_queue.put(None)  # Signal end

        self._stop_thread = False
        self.prefetch_thread = Thread(
            target=self._prefetch_worker, args=(query_queue,), daemon=True
        )
        self.prefetch_thread.start()

    def get(self, query: str, timeout: float = 1.0) -> List[Dict]:
        """
        Get retrieval results, checking the prefetch queue first.

        Args:
            query: Query string
            timeout: Timeout for checking the prefetch queue

        Returns:
            Retrieved documents (prefetched if available, otherwise fetched directly)
        """
        # Scan at most the current queue contents once; re-enqueueing
        # non-matching entries without a bound would loop forever.
        for _ in range(self.prefetch_queue.qsize()):
            try:
                cached_query, results = self.prefetch_queue.get(timeout=timeout)
            except Empty:
                break
            if cached_query == query:
                return results
            # Put back if not matching
            self.prefetch_queue.put((cached_query, results))

        # Fall back to direct retrieval
        return self.retrieval_fn(query)

    def stop(self):
        """Stop the prefetching thread."""
        self._stop_thread = True
        if self.prefetch_thread is not None:
            self.prefetch_thread.join()

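
An illustrative driver for LookaheadRetriever, again assuming it runs alongside the module above; fake_retrieve is a stub standing in for a real vector-store lookup:

def fake_retrieve(query: str) -> List[Dict]:
    # Stub retrieval: one synthetic document per query.
    return [{"text": f"doc for {query}", "score": 1.0}]

retriever = LookaheadRetriever(fake_retrieve, lookahead_window=3)
queries = ["what is attention?", "how does kv caching work?"]
retriever.start_prefetching(queries)

for q in queries:
    docs = retriever.get(q, timeout=0.5)  # prefetched hit or direct fallback
retriever.stop()
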
class BatchPrefetcher:
    """
    Batched prefetching for multiple queries.

    Groups queries into batches for efficient retrieval.
    """

    def __init__(
        self,
        batch_retrieval_fn: Callable[[List[str]], List[List[Dict]]],
        batch_size: int = 8,
        prefetch_factor: int = 2,
    ):
        """
        Args:
            batch_retrieval_fn: Function that takes a list of queries and returns a list of results
            batch_size: Number of queries per retrieval batch
            prefetch_factor: Number of batches to prefetch
        """
        self.batch_retrieval_fn = batch_retrieval_fn
        self.batch_size = batch_size
        self.prefetch_factor = prefetch_factor

        # The queue holds individual (query, results) pairs, so size it in
        # whole batches; sizing it at prefetch_factor alone would drop most
        # results of every batch.
        self.prefetch_queue: Queue = Queue(maxsize=prefetch_factor * batch_size)
        self.prefetch_thread: Optional[Thread] = None
        self._stop_thread = False

    def _prefetch_worker(self, query_queue: Queue):
        """Worker thread that prefetches batches of retrieval results."""
        batch: List[str] = []

        while not self._stop_thread:
            try:
                query = query_queue.get(timeout=1.0)
            except Empty:
                continue

            if query is None:
                # Flush the final partial batch
                if batch:
                    results = self.batch_retrieval_fn(batch)
                    for q, r in zip(batch, results):
                        self.prefetch_queue.put((q, r))
                break

            batch.append(query)

            # Retrieve as soon as a full batch has accumulated
            if len(batch) >= self.batch_size:
                results = self.batch_retrieval_fn(batch)
                for q, r in zip(batch, results):
                    try:
                        self.prefetch_queue.put((q, r), timeout=0.1)
                    except Full:
                        pass  # Queue full, drop this result
                batch = []

    def start_prefetching(self, query_stream: List[str]):
        """Start prefetching retrieval results for a stream of queries."""
        query_queue: Queue = Queue()

        for query in query_stream:
            query_queue.put(query)
        query_queue.put(None)  # Signal end

        self._stop_thread = False
        self.prefetch_thread = Thread(
            target=self._prefetch_worker, args=(query_queue,), daemon=True
        )
        self.prefetch_thread.start()

    def get(self, query: str, timeout: float = 1.0) -> Optional[List[Dict]]:
        """
        Get retrieval results from the prefetch queue.

        Args:
            query: Query string
            timeout: Timeout for checking the prefetch queue

        Returns:
            Retrieved documents, or None if the query was not prefetched
        """
        # Scan at most the current queue contents once; re-enqueueing
        # non-matching entries without a bound would loop forever.
        for _ in range(self.prefetch_queue.qsize()):
            try:
                cached_query, results = self.prefetch_queue.get(timeout=timeout)
            except Empty:
                break
            if cached_query == query:
                return results
            # Put back if not matching
            self.prefetch_queue.put((cached_query, results))

        return None

    def stop(self):
        """Stop the prefetching thread."""
        self._stop_thread = True
        if self.prefetch_thread is not None:
            self.prefetch_thread.join()
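
A similar sketch for BatchPrefetcher, under the same assumptions; fake_batch_retrieve is a stub, and the short sleep simply gives the background worker time to fill the queue before the first lookup. Unlike LookaheadRetriever, get() here returns None on a miss, so the caller supplies its own fallback:

import time

def fake_batch_retrieve(queries: List[str]) -> List[List[Dict]]:
    # Stub batch retrieval: one synthetic result list per query.
    return [[{"text": f"doc for {q}"}] for q in queries]

prefetcher = BatchPrefetcher(fake_batch_retrieve, batch_size=2)
queries = ["query a", "query b", "query c"]
prefetcher.start_prefetching(queries)

time.sleep(0.2)  # let the worker fill the prefetch queue
for q in queries:
    docs = prefetcher.get(q, timeout=0.5) or fake_batch_retrieve([q])[0]
prefetcher.stop()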