# Optimizations for Production RAG Systems
This document describes the optimizations implemented based on the paper "Optimizing LLM Inference and Retrieval: Novel Data Structures and Algorithms for Production RAG Systems".
## Implemented Optimizations
### 1. KV Cache for Efficient Autoregressive Generation
**Location**: `models/optimized_attention.py`
The KV (Key-Value) cache mechanism stores computed keys and values from previous tokens during autoregressive generation, eliminating redundant computations.
**Benefits**:
- Reduces computational cost from O(n²) to O(n) for each new token
- Significantly faster generation for long sequences
- Lower memory bandwidth usage
**Usage**:
```python
from models import TransformerModel, OptimizedInference
model = TransformerModel(...)
optimizer = model.get_optimized_inference()
# Generate with KV caching
generated = optimizer.generate_with_cache(
    input_ids=input_ids,
    max_length=100,
    temperature=0.8,
)
```
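To make the mechanism concrete independently of this project's API, here is a minimal single-head sketch in plain PyTorch (the function name `attend_with_cache` and the cache layout are illustrative assumptions, not part of this codebase): each step appends the new token's key and value to the cache and attends over the cached tensors, so the per-token cost stays linear in the sequence length.

```python
import torch

def attend_with_cache(q, k_new, v_new, cache):
    """Single-head attention step for one new token, reusing cached K/V.

    q, k_new, v_new: (1, d) tensors for the newly generated position.
    cache: dict with "k" and "v" tensors of shape (t, d) from previous steps.
    """
    # Append the new key/value instead of recomputing K/V for all positions
    cache["k"] = torch.cat([cache["k"], k_new], dim=0)  # (t+1, d)
    cache["v"] = torch.cat([cache["v"], v_new], dim=0)  # (t+1, d)

    d = q.size(-1)
    # Attend over all cached positions: O(t) work per new token
    scores = (q @ cache["k"].T) / d ** 0.5   # (1, t+1)
    weights = torch.softmax(scores, dim=-1)  # (1, t+1)
    return weights @ cache["v"]              # (1, d)

# Usage: start with an empty cache, call once per generated token
d_model = 64
cache = {"k": torch.empty(0, d_model), "v": torch.empty(0, d_model)}
q = torch.randn(1, d_model)
out = attend_with_cache(q, torch.randn(1, d_model), torch.randn(1, d_model), cache)
```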
### 2. Optimized Attention Computation
**Location**: `models/optimized_attention.py`
Implements optimized attention computation using PyTorch's `scaled_dot_product_attention` when available (similar to Flash Attention).
**Features**:
- Uses PyTorch's optimized attention implementation
- Supports causal masking efficiently
- Reduces memory usage during attention computation
**Usage**:
```python
from models import TransformerModel
from models.blocks import TransformerBlock
# Use optimized attention in transformer blocks
block = TransformerBlock(
    d_model=512,
    num_heads=8,
    use_optimized_attention=True,  # Enable optimized attention
)
```
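For reference, the optimized path corresponds to PyTorch's built-in fused attention kernel. A minimal standalone call with causal masking looks like this (illustrative, not this project's code):

```python
import torch
import torch.nn.functional as F

batch, num_heads, seq_len, head_dim = 2, 8, 128, 64
q = torch.randn(batch, num_heads, seq_len, head_dim)
k = torch.randn(batch, num_heads, seq_len, head_dim)
v = torch.randn(batch, num_heads, seq_len, head_dim)

# Fused, memory-efficient attention with causal masking applied internally,
# avoiding materializing the full (seq_len, seq_len) attention matrix.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 128, 64])
```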
### 3. Retrieval Cache for Similar Queries
**Location**: `models/optimized_attention.py`
Implements approximate caching for retrieval results, reducing expensive vector database lookups by caching similar queries.
**Features**:
- Cosine similarity-based cache lookup
- Configurable similarity threshold
- Automatic cache eviction when full
**Usage**:
```python
from models.optimized_attention import RetrievalCache
cache = RetrievalCache(max_size=1000, similarity_threshold=0.9)
# Store retrieval results
cache.set(query_hash, query_embedding, retrieved_docs)
# Retrieve cached results
results = cache.get(query_hash, query_embedding)
```
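The lookup logic amounts to a cosine-similarity search over cached query embeddings. The following is a simplified sketch of that idea (the class name `SimpleRetrievalCache`, FIFO eviction, and the exact method signatures are illustrative assumptions, not the actual `RetrievalCache` implementation):

```python
import torch

class SimpleRetrievalCache:
    """Illustrative cache keyed by query-embedding similarity."""

    def __init__(self, max_size=1000, similarity_threshold=0.9):
        self.max_size = max_size
        self.similarity_threshold = similarity_threshold
        self.embeddings = []  # 1-D query embeddings
        self.results = []     # cached retrieval results, aligned with embeddings

    def set(self, query_embedding, retrieved_docs):
        if len(self.embeddings) >= self.max_size:
            # Simple FIFO eviction when the cache is full
            self.embeddings.pop(0)
            self.results.pop(0)
        self.embeddings.append(query_embedding)
        self.results.append(retrieved_docs)

    def get(self, query_embedding):
        if not self.embeddings:
            return None
        stacked = torch.stack(self.embeddings)  # (n, d)
        sims = torch.cosine_similarity(stacked, query_embedding.unsqueeze(0))  # (n,)
        best = int(sims.argmax())
        # Reuse a cached result only if the closest stored query is similar enough
        if sims[best] >= self.similarity_threshold:
            return self.results[best]
        return None

cache = SimpleRetrievalCache()
cache.set(torch.randn(384), ["doc_a", "doc_b"])
hit = cache.get(torch.randn(384))  # None unless similarity >= threshold
```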
### 4. Prefetching Mechanisms
**Location**: `models/prefetching.py`
#### 4.1 PrefetchDataLoader
Prefetches batches in background threads, reducing GPU idle time.
**Usage**:
```python
from models.prefetching import PrefetchDataLoader
from data import create_dataloader
dataloader = create_dataloader(...)
prefetch_loader = PrefetchDataLoader(
    dataloader=dataloader,
    prefetch_factor=2,
    device=device,
)
```
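The pattern behind this is a background producer thread filling a bounded queue while the training loop consumes from it. A minimal sketch of that pattern (illustrative only, not the actual `PrefetchDataLoader` code):

```python
import queue
import threading

class SimplePrefetcher:
    """Iterate over a dataloader while a background thread stays ahead."""

    def __init__(self, dataloader, prefetch_factor=2):
        self.dataloader = dataloader
        self.buffer = queue.Queue(maxsize=prefetch_factor)
        self._sentinel = object()

    def _producer(self):
        for batch in self.dataloader:
            self.buffer.put(batch)       # blocks when the buffer is full
        self.buffer.put(self._sentinel)  # signal end of iteration

    def __iter__(self):
        threading.Thread(target=self._producer, daemon=True).start()
        while True:
            batch = self.buffer.get()
            if batch is self._sentinel:
                break
            yield batch

# Usage with any iterable of batches
for batch in SimplePrefetcher(range(10), prefetch_factor=2):
    pass  # the training step runs while the next batches are being queued
```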
#### 4.2 LookaheadRetriever
Prefetches retrieval results for anticipated queries.
**Usage**:
```python
from models.prefetching import LookaheadRetriever
def retrieve(query: str):
    # Your retrieval function
    return documents

retriever = LookaheadRetriever(
    retrieval_fn=retrieve,
    lookahead_window=3,
)
# Start prefetching
retriever.start_prefetching(query_stream)
# Get results (checks cache first)
results = retriever.get(query)
```
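Conceptually, lookahead retrieval issues retrieval calls for the next few queries in the stream on a background thread and answers later `get()` calls from that cache when possible. A simplified sketch of the idea (not the actual `LookaheadRetriever` implementation):

```python
import threading

class SimpleLookahead:
    """Prefetch retrieval results for upcoming queries in a known stream."""

    def __init__(self, retrieval_fn, lookahead_window=3):
        self.retrieval_fn = retrieval_fn
        self.lookahead_window = lookahead_window
        self.cache = {}
        self.lock = threading.Lock()

    def start_prefetching(self, query_stream):
        upcoming = list(query_stream)[: self.lookahead_window]

        def worker():
            for q in upcoming:
                docs = self.retrieval_fn(q)
                with self.lock:
                    self.cache[q] = docs

        threading.Thread(target=worker, daemon=True).start()

    def get(self, query):
        with self.lock:
            if query in self.cache:
                return self.cache[query]  # cache hit: no retrieval call needed
        return self.retrieval_fn(query)   # cache miss: fall back to retrieval

retriever = SimpleLookahead(lambda q: [f"doc for {q}"], lookahead_window=3)
retriever.start_prefetching(["q1", "q2", "q3", "q4"])
print(retriever.get("q1"))
```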
#### 4.3 BatchPrefetcher
Groups queries into batches for efficient batch retrieval.
**Usage**:
```python
from typing import List
from models.prefetching import BatchPrefetcher

def batch_retrieve(queries: List[str]):
    # Batch retrieval function: return one list of documents per query
    return [retrieve(q) for q in queries]  # retrieve() as defined above

prefetcher = BatchPrefetcher(
    batch_retrieval_fn=batch_retrieve,
    batch_size=8,
)
prefetcher.start_prefetching(query_stream)
results = prefetcher.get(query)
```
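The batching idea itself is just grouping the query stream into fixed-size chunks and issuing one retrieval call per chunk. A minimal, non-threaded sketch of that grouping (illustrative only; `prefetch_in_batches` is not part of this codebase):

```python
from typing import Dict, List

def prefetch_in_batches(queries: List[str], batch_retrieval_fn, batch_size: int = 8) -> Dict[str, list]:
    """Issue one batched retrieval call per chunk and index results by query."""
    results: Dict[str, list] = {}
    for start in range(0, len(queries), batch_size):
        chunk = queries[start:start + batch_size]
        docs_per_query = batch_retrieval_fn(chunk)  # one call for the whole chunk
        for query, docs in zip(chunk, docs_per_query):
            results[query] = docs
    return results

# Usage with a stand-in batch retrieval function
cached = prefetch_in_batches(
    queries=[f"query {i}" for i in range(20)],
    batch_retrieval_fn=lambda qs: [[f"doc for {q}"] for q in qs],
    batch_size=8,
)
print(cached["query 0"])
```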
### 5. Optimized Batch Inference
**Location**: `models/optimized_attention.py`
The `OptimizedInference` class provides batch generation utilities for processing multiple prompts efficiently.
**Features**:
- Batch processing for multiple prompts
- Automatic padding and batching
- Efficient memory usage
**Usage**:
```python
from models import OptimizedInference
optimizer = model.get_optimized_inference()
# Generate for multiple prompts in batches
results = optimizer.batch_generate(
    input_ids_list=[prompt1_ids, prompt2_ids, ...],
    max_length=100,
    batch_size=8,
)
```
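The key detail behind batch generation is padding variable-length prompts into a single tensor together with an attention mask. A minimal sketch of that padding step, assuming right-padding with a pad id of 0 (illustrative, not the actual `OptimizedInference` code):

```python
import torch

def pad_and_batch(input_ids_list, pad_id=0):
    """Pad variable-length prompts to a common length and build a padding mask."""
    max_len = max(len(ids) for ids in input_ids_list)
    batch = torch.full((len(input_ids_list), max_len), pad_id, dtype=torch.long)
    mask = torch.zeros((len(input_ids_list), max_len), dtype=torch.bool)
    for i, ids in enumerate(input_ids_list):
        batch[i, : len(ids)] = torch.tensor(ids, dtype=torch.long)
        mask[i, : len(ids)] = True  # real tokens; everything else is padding
    return batch, mask

prompts = [[5, 17, 42], [8, 9], [3, 14, 15, 92]]
batch, mask = pad_and_batch(prompts)
print(batch.shape, mask.shape)  # torch.Size([3, 4]) torch.Size([3, 4])
```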
## Performance Improvements
These optimizations provide the following benefits:
1. **Faster Inference**: KV caching speeds up generation by roughly 2-5x for long sequences
2. **Reduced Latency**: Prefetching reduces end-to-end latency by overlapping computation and I/O
3. **Lower Costs**: Retrieval caching reduces expensive vector database calls
4. **Better Throughput**: Batch processing increases throughput for multiple requests
## Integration
### Using Optimized Inference in Production
1. **Enable optimized attention** (for inference only):
```python
model = TransformerModel(
    ...,
    use_optimized_attention=True,  # Set in TransformerBlock
)
```
2. **Use optimized inference utility**:
```python
optimizer = model.get_optimized_inference()
generated = optimizer.generate_with_cache(...)
```
3. **Enable prefetching**:
```python
prefetch_loader = PrefetchDataLoader(dataloader, prefetch_factor=2)
```
### CLI Usage
Use the `--optimized` flag when running inference:
```bash
python inference.py \
--checkpoint checkpoints/best_checkpoint.pt \
--prompt "Your prompt here" \
--optimized \
--max-length 100
```
## Example Script
See `example_optimized.py` for complete examples of all optimizations.
## References
Based on optimizations from:
- "Optimizing LLM Inference and Retrieval: Novel Data Structures and Algorithms for Production RAG Systems"
- TeleRAG: Lookahead Retrieval Mechanism
- Flash Attention optimization techniques