- Complete transformer implementation from scratch
- Training pipeline with gradient accumulation and mixed precision (see the sketch after this list)
- Optimized inference with KV caching
- Multi-format data processing (PDFs, images, code, text)
- Comprehensive documentation
- Apache 2.0 license
- Example training plots included in docs/images/
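The example file below covers the inference-side features (KV caching, retrieval caching, prefetching); the training pipeline itself is not shown. As a rough sketch of what gradient accumulation with mixed precision typically looks like in plain PyTorch — the `train_epoch` function, `accum_steps` parameter, and the assumption that batches are `(input_ids, targets)` pairs are illustrative, not this repo's actual API:

```python
import torch

def train_epoch(model, dataloader, optimizer, accum_steps=4, device="cuda"):
    """Sketch of gradient accumulation + AMP; names and batch format are assumed."""
    scaler = torch.cuda.amp.GradScaler()  # scales the loss to avoid fp16 gradient underflow
    model.train()
    optimizer.zero_grad(set_to_none=True)

    for step, (input_ids, targets) in enumerate(dataloader):
        input_ids, targets = input_ids.to(device), targets.to(device)

        with torch.cuda.amp.autocast():  # run the forward pass in mixed precision
            logits = model(input_ids)
            loss = torch.nn.functional.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.view(-1)
            )

        # Divide by accum_steps so the accumulated gradient matches one large-batch update
        scaler.scale(loss / accum_steps).backward()

        if (step + 1) % accum_steps == 0:
            scaler.step(optimizer)  # unscales gradients, then applies the optimizer step
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
```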
"""
|
|
Example usage of optimized inference and retrieval mechanisms
|
|
Demonstrates KV caching, retrieval caching, and prefetching for RAG systems
|
|
"""
|
|
import torch
|
|
import sys
|
|
import importlib.util
|
|
from pathlib import Path
|
|
|
|
# Ensure current directory is in path
|
|
project_root = Path(__file__).parent.absolute()
|
|
sys.path.insert(0, str(project_root))
|
|
|
|
# Explicitly import from local data module to avoid conflicts with stdlib 'data' module
|
|
data_module_path = project_root / "data" / "__init__.py"
|
|
spec = importlib.util.spec_from_file_location("sheepop_data", data_module_path)
|
|
sheepop_data = importlib.util.module_from_spec(spec)
|
|
spec.loader.exec_module(sheepop_data)
|
|
SimpleTokenizer = sheepop_data.SimpleTokenizer
|
|
create_dataloader = sheepop_data.create_dataloader
|
|
|
|
from models import TransformerModel, OptimizedInference, RetrievalCache
|
|
from models.prefetching import PrefetchDataLoader, LookaheadRetriever
|
|
|
|
|
|
def example_optimized_inference():
    """Example: Using optimized inference with KV caching."""
    print("=" * 60)
    print("Example: Optimized Inference with KV Caching")
    print("=" * 60)

    # Create model (example configuration)
    model = TransformerModel(
        vocab_size=128,
        d_model=512,
        num_layers=6,
        num_heads=8,
    )

    # Get optimized inference utility
    optimizer = model.get_optimized_inference()

    # Example prompt
    tokenizer = SimpleTokenizer()
    prompt = "The future of AI"
    input_ids = torch.tensor([tokenizer.encode(prompt)])

    # Generate with KV caching (faster for autoregressive generation)
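    # The cache keeps each layer's key/value tensors from earlier steps, so every
    # new token only attends over the cached states instead of re-running the
    # whole prefix through the model at each generation step.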
    generated = optimizer.generate_with_cache(
        input_ids=input_ids,
        max_length=50,
        temperature=0.8,
        top_k=50,
        top_p=0.95,
    )

    print(f"Generated: {tokenizer.decode(generated[0].tolist())}")
    print()


def example_retrieval_caching():
    """Example: Using retrieval cache for similar queries."""
    print("=" * 60)
    print("Example: Retrieval Caching")
    print("=" * 60)

    # Create retrieval cache
    cache = RetrievalCache(max_size=1000, similarity_threshold=0.9)

    # Example: Simulate retrieval function
    def retrieve_documents(query: str):
        """Mock retrieval function."""
        return [
            {"doc_id": "1", "text": f"Document about {query}", "score": 0.95},
            {"doc_id": "2", "text": f"Another document about {query}", "score": 0.92},
        ]

    # Create query embeddings (simplified)
    query1 = "What is machine learning?"
    query1_embedding = torch.randn(128)  # Example embedding

    query2 = "What is deep learning?"  # Similar query
    query2_embedding = torch.randn(128)  # Example embedding (would be similar in practice)
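    # NOTE: these two embeddings are sampled independently at random, so they will
    # almost certainly fall below the 0.9 similarity threshold and the lookup below
    # will miss. With a real encoder, paraphrased queries map to nearby embeddings
    # and hit the cache.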
    # Store first query
    query1_hash = hashlib.md5(query1.encode()).hexdigest()
    results1 = retrieve_documents(query1)
    cache.set(query1_hash, query1_embedding, results1)

    # Retrieve from cache (should find similar query)
    query2_hash = hashlib.md5(query2.encode()).hexdigest()
    cached_results = cache.get(query2_hash, query2_embedding)

    if cached_results:
        print(f"Found cached results for query: {query2}")
        print(f"Retrieved {len(cached_results)} documents")
    else:
        print("Cache miss, performing retrieval...")
        results = retrieve_documents(query2)
        cache.set(query2_hash, query2_embedding, results)

    print()


def example_prefetching():
    """Example: Using prefetching for data loading."""
    print("=" * 60)
    print("Example: Prefetching DataLoader")
    print("=" * 60)

    # Create sample data
    texts = ["This is a sample text."] * 100
    tokenizer = SimpleTokenizer()

    # Create standard dataloader
    dataloader = create_dataloader(
        texts=texts,
        tokenizer=tokenizer,
        batch_size=32,
        max_length=512,
    )

    # Wrap with prefetching
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    prefetch_loader = PrefetchDataLoader(
        dataloader=dataloader,
        prefetch_factor=2,
        device=device,
    )
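    # The wrapper keeps up to `prefetch_factor` batches ready in a background
    # thread (and presumably moves them to `device` ahead of time), so the next
    # batch is available as soon as the previous step finishes. Typical usage,
    # assuming it iterates like a regular DataLoader:
    #     for batch in prefetch_loader:
    #         ...  # forward/backward pass on an already-prefetched batch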
    print(f"Created prefetch loader with {len(prefetch_loader)} batches")
    print("Prefetching batches in background thread...")
    print()


def example_batch_generation():
    """Example: Batch generation for multiple prompts."""
    print("=" * 60)
    print("Example: Batch Generation")
    print("=" * 60)

    # Create model
    model = TransformerModel(
        vocab_size=128,
        d_model=512,
        num_layers=6,
        num_heads=8,
    )

    # Get optimized inference utility
    optimizer = model.get_optimized_inference()

    # Multiple prompts
    tokenizer = SimpleTokenizer()
    prompts = [
        "The future of AI",
        "Machine learning applications",
        "Deep learning advances",
    ]

    input_ids_list = [torch.tensor([tokenizer.encode(p)]) for p in prompts]

    # Generate for all prompts in batches
    results = optimizer.batch_generate(
        input_ids_list=input_ids_list,
        max_length=30,
        temperature=0.8,
        batch_size=2,
    )

    print(f"Generated {len(results)} responses:")
    for i, (prompt, result) in enumerate(zip(prompts, results)):
        # result is a tensor of shape [batch_size, seq_len]; take the first row if batched
        if result.dim() > 1 and result.shape[0] > 1:
            generated_ids = result[0].tolist()
        else:
            generated_ids = result.squeeze(0).tolist() if result.dim() > 1 else result.tolist()
        generated_text = tokenizer.decode(generated_ids)
        print(f"{i+1}. Prompt: {prompt}")
        print(f" Generated: {generated_text[:50]}...")
    print()


if __name__ == '__main__':
    print("\n" + "=" * 60)
    print("Optimized RAG System Examples")
    print("=" * 60 + "\n")

    # Run examples
    example_optimized_inference()
    example_retrieval_caching()
    example_prefetching()
    example_batch_generation()

    print("=" * 60)
    print("All examples completed!")
    print("=" * 60)