Optimizations for Production RAG Systems
This document describes the optimizations implemented in this repository, based on the paper "Optimizing LLM Inference and Retrieval: Novel Data Structures and Algorithms for Production RAG Systems".
Implemented Optimizations
1. KV Cache for Efficient Autoregressive Generation
Location: models/optimized_attention.py
The KV (Key-Value) cache mechanism stores computed keys and values from previous tokens during autoregressive generation, eliminating redundant computations.
Benefits:
- Reduces computational cost from O(n²) to O(n) for each new token
- Significantly faster generation for long sequences
- Lower memory bandwidth usage
Usage:
from models import TransformerModel, OptimizedInference
model = TransformerModel(...)
optimizer = model.get_optimized_inference()
# Generate with KV caching
generated = optimizer.generate_with_cache(
input_ids=input_ids,
max_length=100,
temperature=0.8,
)
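For intuition, the caching step can be sketched in a few lines (a simplified illustration, not the repository's exact code): each decoding step appends the new token's key and value to the cache, and only the single new query attends over everything accumulated so far.

from typing import Optional, Tuple
import torch

def attend_with_cache(
    q: torch.Tensor,       # (batch, heads, 1, head_dim) -- current token only
    k_new: torch.Tensor,   # (batch, heads, 1, head_dim)
    v_new: torch.Tensor,   # (batch, heads, 1, head_dim)
    cache: Optional[Tuple[torch.Tensor, torch.Tensor]],
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    # Reuse cached keys/values instead of recomputing them for every past token
    if cache is not None:
        k_past, v_past = cache
        k = torch.cat([k_past, k_new], dim=2)  # append along the sequence axis
        v = torch.cat([v_past, v_new], dim=2)
    else:
        k, v = k_new, v_new
    scores = q @ k.transpose(-2, -1) / (q.size(-1) ** 0.5)  # (batch, heads, 1, seq)
    out = torch.softmax(scores, dim=-1) @ v
    return out, (k, v)  # the updated cache is passed back in on the next step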
2. Optimized Attention Computation
Location: models/optimized_attention.py
Computes attention with PyTorch's scaled_dot_product_attention when available, which can dispatch to fused Flash-Attention-style kernels.
Features:
- Uses PyTorch's optimized attention implementation
- Supports causal masking efficiently
- Reduces memory usage during attention computation
Usage:
from models import TransformerModel
from models.blocks import TransformerBlock
# Use optimized attention in transformer blocks
block = TransformerBlock(
d_model=512,
num_heads=8,
use_optimized_attention=True, # Enable optimized attention
)
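Under the hood, the optimized path maps onto PyTorch's fused kernel. A minimal standalone example (requires PyTorch 2.0+; the shapes are illustrative):

import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 128, 64)  # (batch, heads, seq_len, head_dim)
k = torch.randn(2, 8, 128, 64)
v = torch.randn(2, 8, 128, 64)

# Fused attention with causal masking; PyTorch dispatches to a
# Flash-Attention-style backend when hardware and dtype allow it.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)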
3. Retrieval Cache for Similar Queries
Location: models/optimized_attention.py
Implements approximate caching for retrieval results, reducing expensive vector database lookups by caching similar queries.
Features:
- Cosine similarity-based cache lookup
- Configurable similarity threshold
- Automatic cache eviction when full
Usage:
from models.optimized_attention import RetrievalCache
cache = RetrievalCache(max_size=1000, similarity_threshold=0.9)
# Store retrieval results
cache.set(query_hash, query_embedding, retrieved_docs)
# Retrieve cached results
results = cache.get(query_hash, query_embedding)
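Internally, such a cache combines an exact hash lookup with an approximate cosine-similarity scan, evicting least-recently-used entries when full. A simplified sketch assuming 1-D embedding tensors (names are illustrative, not the repository's exact implementation):

from collections import OrderedDict
import torch

class SimpleRetrievalCache:
    def __init__(self, max_size: int = 1000, similarity_threshold: float = 0.9):
        self.max_size = max_size
        self.threshold = similarity_threshold
        self.entries = OrderedDict()  # query_hash -> (embedding, documents)

    def set(self, query_hash, embedding, documents):
        if len(self.entries) >= self.max_size:
            self.entries.popitem(last=False)  # evict the least recently used entry
        self.entries[query_hash] = (embedding, documents)

    def get(self, query_hash, embedding):
        if query_hash in self.entries:  # exact hit
            self.entries.move_to_end(query_hash)
            return self.entries[query_hash][1]
        for key, (emb, documents) in self.entries.items():  # approximate hit
            if torch.cosine_similarity(embedding, emb, dim=0) >= self.threshold:
                self.entries.move_to_end(key)
                return documents
        return None  # cache miss: fall back to the vector database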
4. Prefetching Mechanisms
Location: models/prefetching.py
4.1 PrefetchDataLoader
Prefetches batches in background threads, reducing GPU idle time.
Usage:
from models.prefetching import PrefetchDataLoader
from data import create_dataloader
dataloader = create_dataloader(...)
prefetch_loader = PrefetchDataLoader(
dataloader=dataloader,
prefetch_factor=2,
device=device,
)
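The underlying pattern is a bounded queue fed by a background thread, so the GPU is never left waiting for host-side batch preparation. A minimal sketch (the real class also moves tensors to the target device, omitted here):

import queue
import threading

class SimplePrefetcher:
    def __init__(self, dataloader, prefetch_factor: int = 2):
        self._queue = queue.Queue(maxsize=prefetch_factor)
        self._thread = threading.Thread(
            target=self._worker, args=(dataloader,), daemon=True
        )
        self._thread.start()

    def _worker(self, dataloader):
        for batch in dataloader:
            self._queue.put(batch)  # blocks when the queue is full
        self._queue.put(None)       # sentinel: no more batches

    def __iter__(self):
        while (batch := self._queue.get()) is not None:
            yield batch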
4.2 LookaheadRetriever
Prefetches retrieval results for anticipated queries.
Usage:
from models.prefetching import LookaheadRetriever
def retrieve(query: str):
    # Your retrieval function
    return documents
retriever = LookaheadRetriever(
retrieval_fn=retrieve,
lookahead_window=3,
)
# Start prefetching
retriever.start_prefetching(query_stream)
# Get results (checks cache first)
results = retriever.get(query)
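Conceptually, lookahead retrieval submits upcoming queries to a worker pool before they are needed and falls back to synchronous retrieval on a miss. A simplified sketch using concurrent.futures (names are illustrative):

from concurrent.futures import ThreadPoolExecutor

class SimpleLookaheadRetriever:
    def __init__(self, retrieval_fn, lookahead_window: int = 3):
        self.retrieval_fn = retrieval_fn
        self.pool = ThreadPoolExecutor(max_workers=lookahead_window)
        self.futures = {}

    def start_prefetching(self, query_stream):
        for query in query_stream:  # e.g. the next few anticipated queries
            self.futures[query] = self.pool.submit(self.retrieval_fn, query)

    def get(self, query):
        future = self.futures.pop(query, None)
        if future is not None:
            return future.result()       # already running or done in the background
        return self.retrieval_fn(query)  # miss: retrieve synchronously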
4.3 BatchPrefetcher
Groups queries into batches for efficient batch retrieval.
Usage:
from models.prefetching import BatchPrefetcher
from typing import List

def batch_retrieve(queries: List[str]) -> List[list]:
    # Batch retrieval function: returns one document list per query
    return [retrieve(q) for q in queries]
prefetcher = BatchPrefetcher(
batch_retrieval_fn=batch_retrieve,
batch_size=8,
)
prefetcher.start_prefetching(query_stream)
results = prefetcher.get(query)
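The batching itself amounts to chunking the query stream into fixed-size groups so that one batched call replaces many single calls. A small sketch, reusing the batch_retrieve function and query_stream from the usage above:

from itertools import islice

def chunked(iterable, size):
    # Yield fixed-size batches; the last batch may be smaller
    it = iter(iterable)
    while batch := list(islice(it, size)):
        yield batch

for batch in chunked(query_stream, 8):
    docs_per_query = batch_retrieve(batch)  # one call retrieves for the whole batch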
5. Optimized Batch Inference
Location: models/optimized_attention.py
The OptimizedInference class provides batch generation utilities for processing multiple prompts efficiently.
Features:
- Batch processing for multiple prompts
- Automatic padding and batching
- Efficient memory usage
Usage:
from models import OptimizedInference
optimizer = model.get_optimized_inference()
# Generate for multiple prompts in batches
results = optimizer.batch_generate(
input_ids_list=[prompt1_ids, prompt2_ids, ...],
max_length=100,
batch_size=8,
)
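Padding a ragged list of prompts into one rectangular batch is the key step here; torch.nn.utils.rnn.pad_sequence handles it directly. A minimal sketch (PAD_ID is a hypothetical pad token id; in practice it comes from the tokenizer or model config):

import torch
from torch.nn.utils.rnn import pad_sequence

PAD_ID = 0  # hypothetical pad token id

prompts = [torch.tensor([5, 17, 42]), torch.tensor([7, 9])]
# Right-pad to a (batch, max_len) tensor, plus a mask so attention
# ignores the padding positions.
batch = pad_sequence(prompts, batch_first=True, padding_value=PAD_ID)
attention_mask = (batch != PAD_ID).long()

For decoder-only generation, left padding is often preferable so that the last position of every row is a real token; the repository's batch_generate presumably handles that choice internally.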
Performance Improvements
These optimizations provide the following benefits:
- Faster Inference: KV caching reduces generation time by 2-5x for long sequences
- Reduced Latency: Prefetching reduces end-to-end latency by overlapping computation and I/O
- Lower Costs: Retrieval caching reduces expensive vector database calls
- Better Throughput: Batch processing increases throughput for multiple requests
Integration
Using Optimized Inference in Production
- Enable optimized attention (for inference only):
model = TransformerModel(
...,
use_optimized_attention=True, # Set in TransformerBlock
)
- Use optimized inference utility:
optimizer = model.get_optimized_inference()
generated = optimizer.generate_with_cache(...)
- Enable prefetching:
prefetch_loader = PrefetchDataLoader(dataloader, prefetch_factor=2)
CLI Usage
Use the --optimized flag when running inference:
python inference.py \
--checkpoint checkpoints/best_checkpoint.pt \
--prompt "Your prompt here" \
--optimized \
--max-length 100
Example Script
See example_optimized.py for complete examples of all optimizations.
References
Based on optimizations from:
- "Optimizing LLM Inference and Retrieval: Novel Data Structures and Algorithms for Production RAG Systems"
- TeleRAG: Lookahead Retrieval Mechanism
- Flash Attention optimization techniques