# Optimizations for Production RAG Systems

This document describes the optimizations implemented based on the paper "Optimizing LLM Inference and Retrieval: Novel Data Structures and Algorithms for Production RAG Systems".

## Implemented Optimizations

### 1. KV Cache for Efficient Autoregressive Generation

**Location**: `models/optimized_attention.py`

The KV (key-value) cache stores the keys and values computed for previous tokens during autoregressive generation, so each decoding step reuses them instead of recomputing projections for the whole sequence.

**Benefits**:

- Reduces the attention cost of each new token from O(n²) to O(n)
- Significantly faster generation for long sequences
- Lower memory bandwidth usage

**Usage**:

```python
from models import TransformerModel, OptimizedInference

model = TransformerModel(...)
optimizer = model.get_optimized_inference()

# Generate with KV caching
generated = optimizer.generate_with_cache(
    input_ids=input_ids,
    max_length=100,
    temperature=0.8,
)
```
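For intuition, here is a minimal sketch of the caching step itself, independent of this repo's classes (tensor names and the dict-based cache are illustrative): each step appends the new token's key/value to the cache and attends over the full cached sequence.

```python
import torch

def attend_with_cache(q_new, k_new, v_new, cache):
    """One decoding step: append the new K/V to the cache, attend over everything.

    q_new/k_new/v_new: (batch, heads, 1, head_dim) projections for the new token.
    cache: dict holding previously computed keys/values; empty on the first step.
    """
    if "k" in cache:
        k = torch.cat([cache["k"], k_new], dim=2)  # (batch, heads, seq, head_dim)
        v = torch.cat([cache["v"], v_new], dim=2)
    else:
        k, v = k_new, v_new
    cache["k"], cache["v"] = k, v  # reused on the next step

    scores = q_new @ k.transpose(-2, -1) / k.shape[-1] ** 0.5  # (batch, heads, 1, seq)
    return torch.softmax(scores, dim=-1) @ v                   # (batch, heads, 1, head_dim)
```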
### 2. Optimized Attention Computation

**Location**: `models/optimized_attention.py`

Implements attention using PyTorch's `scaled_dot_product_attention` when available, which dispatches to fused kernels similar to Flash Attention.

**Features**:

- Uses PyTorch's optimized attention implementation
- Supports causal masking efficiently
- Reduces memory usage during attention computation

**Usage**:

```python
from models import TransformerModel
from models.blocks import TransformerBlock

# Use optimized attention in transformer blocks
block = TransformerBlock(
    d_model=512,
    num_heads=8,
    use_optimized_attention=True,  # Enable optimized attention
)
```
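Under the hood, the core call is PyTorch's built-in operator (available since PyTorch 2.0). A minimal standalone sketch, with illustrative shapes:

```python
import torch
import torch.nn.functional as F

batch, heads, seq, head_dim = 2, 8, 128, 64
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
v = torch.randn(batch, heads, seq, head_dim)

# Dispatches to a fused (Flash-style) kernel when one is available for the
# device/dtype; is_causal=True applies the causal mask without materializing it.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```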
### 3. Retrieval Cache for Similar Queries

**Location**: `models/optimized_attention.py`

Implements approximate caching for retrieval results: queries whose embeddings are sufficiently similar to a cached query reuse its results, avoiding expensive vector database lookups.

**Features**:

- Cosine similarity-based cache lookup
- Configurable similarity threshold
- Automatic cache eviction when full

**Usage**:

```python
from models.optimized_attention import RetrievalCache

cache = RetrievalCache(max_size=1000, similarity_threshold=0.9)

# Store retrieval results
cache.set(query_hash, query_embedding, retrieved_docs)

# Retrieve cached results
results = cache.get(query_hash, query_embedding)
```
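For reference, a minimal sketch of such a cache — not the repo's implementation; the class name and the FIFO eviction policy are assumptions:

```python
from collections import OrderedDict

import torch.nn.functional as F

class SimpleRetrievalCache:
    """Approximate cache: a query hits if its embedding is close enough
    (by cosine similarity) to any cached query's embedding."""

    def __init__(self, max_size=1000, similarity_threshold=0.9):
        self.max_size = max_size
        self.threshold = similarity_threshold
        self.entries = OrderedDict()  # key -> (embedding, results)

    def get(self, key, embedding):
        if key in self.entries:  # exact hit
            return self.entries[key][1]
        for cached_emb, results in self.entries.values():  # approximate hit
            if F.cosine_similarity(embedding, cached_emb, dim=0) >= self.threshold:
                return results
        return None

    def set(self, key, embedding, results):
        if len(self.entries) >= self.max_size:
            self.entries.popitem(last=False)  # evict the oldest entry (FIFO)
        self.entries[key] = (embedding, results)
```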
### 4. Prefetching Mechanisms

**Location**: `models/prefetching.py`

#### 4.1 PrefetchDataLoader

Prefetches batches in background threads, reducing GPU idle time.

**Usage**:

```python
from models.prefetching import PrefetchDataLoader
from data import create_dataloader

dataloader = create_dataloader(...)
prefetch_loader = PrefetchDataLoader(
    dataloader=dataloader,
    prefetch_factor=2,
    device=device,
)
```
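The underlying pattern is a bounded producer/consumer queue. A minimal sketch, independent of the repo's class (which additionally moves tensors to `device`):

```python
import queue
import threading

class SimplePrefetcher:
    """Wraps an iterable and fills a bounded queue from a background thread,
    so the consumer overlaps data loading with computation."""

    def __init__(self, iterable, prefetch_factor=2):
        self.queue = queue.Queue(maxsize=prefetch_factor)
        self.sentinel = object()  # marks the end of the stream
        self.thread = threading.Thread(target=self._fill, args=(iterable,), daemon=True)
        self.thread.start()

    def _fill(self, iterable):
        for item in iterable:
            self.queue.put(item)  # blocks while the queue is full
        self.queue.put(self.sentinel)

    def __iter__(self):
        while True:
            item = self.queue.get()
            if item is self.sentinel:
                return
            yield item
```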
#### 4.2 LookaheadRetriever

Prefetches retrieval results for anticipated queries so they are ready by the time they are requested.

**Usage**:

```python
from models.prefetching import LookaheadRetriever

def retrieve(query: str):
    # Your retrieval function: return the list of documents for this query
    ...

retriever = LookaheadRetriever(
    retrieval_fn=retrieve,
    lookahead_window=3,
)

# Start prefetching
retriever.start_prefetching(query_stream)

# Get results (checks cache first)
results = retriever.get(query)
```
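As a sketch of the lookahead idea (hypothetical names, not the repo's API): the next few queries in the stream are submitted to a thread pool, and `get` either returns a finished future or falls back to a synchronous call.

```python
from concurrent.futures import ThreadPoolExecutor

class SimpleLookahead:
    """Eagerly submits upcoming queries to a thread pool so results are
    often already computed when they are requested."""

    def __init__(self, retrieval_fn, lookahead_window=3):
        self.retrieval_fn = retrieval_fn
        self.window = lookahead_window
        self.pool = ThreadPoolExecutor(max_workers=lookahead_window)
        self.futures = {}

    def prefetch(self, upcoming_queries):
        # Kick off retrieval for the next `window` queries we expect to see.
        for q in upcoming_queries[: self.window]:
            if q not in self.futures:
                self.futures[q] = self.pool.submit(self.retrieval_fn, q)

    def get(self, query):
        future = self.futures.pop(query, None)
        if future is not None:
            return future.result()       # prefetched (or still in flight)
        return self.retrieval_fn(query)  # miss: retrieve synchronously
```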
#### 4.3 BatchPrefetcher

Groups queries into batches so the retrieval backend can serve them in a single, more efficient call.

**Usage**:

```python
from typing import List

from models.prefetching import BatchPrefetcher

def batch_retrieve(queries: List[str]):
    # Batch retrieval function: return one list of documents per query
    ...

prefetcher = BatchPrefetcher(
    batch_retrieval_fn=batch_retrieve,
    batch_size=8,
)

prefetcher.start_prefetching(query_stream)
results = prefetcher.get(query)
```
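The core of the batching step is just chunking the incoming stream; a minimal sketch (function names are illustrative):

```python
from typing import Iterable, Iterator, List

def iter_batches(queries: Iterable[str], batch_size: int) -> Iterator[List[str]]:
    """Accumulate queries into fixed-size batches; flush the remainder."""
    batch: List[str] = []
    for q in queries:
        batch.append(q)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch

# Each batch then goes to the backend in one call instead of batch_size calls:
# for batch in iter_batches(query_stream, batch_size=8):
#     docs_per_query = batch_retrieve(batch)
```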
### 5. Optimized Batch Inference

**Location**: `models/optimized_attention.py`

The `OptimizedInference` class provides batch generation utilities for processing multiple prompts efficiently.

**Features**:

- Batch processing for multiple prompts
- Automatic padding and batching
- Efficient memory usage

**Usage**:

```python
from models import OptimizedInference

optimizer = model.get_optimized_inference()

# Generate for multiple prompts in batches
results = optimizer.batch_generate(
    input_ids_list=[prompt1_ids, prompt2_ids, ...],
    max_length=100,
    batch_size=8,
)
```
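Batching variable-length prompts requires padding them to a common length. A minimal sketch of that step (the pad token id and right-padding are assumptions; the repo's utility may differ):

```python
import torch

def pad_batch(input_ids_list, pad_token_id=0):
    """Right-pad a list of 1-D token tensors to the longest prompt and
    build the matching attention mask (1 = real token, 0 = padding)."""
    max_len = max(ids.size(0) for ids in input_ids_list)
    batch = torch.full((len(input_ids_list), max_len), pad_token_id, dtype=torch.long)
    mask = torch.zeros(len(input_ids_list), max_len, dtype=torch.long)
    for i, ids in enumerate(input_ids_list):
        batch[i, : ids.size(0)] = ids
        mask[i, : ids.size(0)] = 1
    return batch, mask
```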
## Performance Improvements

These optimizations provide the following benefits:

1. **Faster Inference**: KV caching reduces generation time by 2-5x for long sequences
2. **Reduced Latency**: prefetching reduces end-to-end latency by overlapping computation and I/O
3. **Lower Costs**: retrieval caching reduces expensive vector database calls
4. **Better Throughput**: batch processing increases throughput for multiple requests
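Actual speedups depend on hardware, model size, and sequence length. A simple way to check the KV-cache gain on your own setup — assuming the model also exposes an uncached `generate`, which is an assumption about this repo's API:

```python
import time

def timed(fn, *args, **kwargs):
    # Wall-clock a single call; average several runs for stable numbers.
    start = time.perf_counter()
    fn(*args, **kwargs)
    return time.perf_counter() - start

baseline = timed(model.generate, input_ids=input_ids, max_length=512)  # assumed API
cached = timed(optimizer.generate_with_cache, input_ids=input_ids, max_length=512)
print(f"speedup: {baseline / cached:.2f}x")
```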
## Integration

### Using Optimized Inference in Production

1. **Enable optimized attention** (for inference only):

```python
model = TransformerModel(
    ...,
    use_optimized_attention=True,  # Set in TransformerBlock
)
```

2. **Use the optimized inference utility**:

```python
optimizer = model.get_optimized_inference()
generated = optimizer.generate_with_cache(...)
```

3. **Enable prefetching**:

```python
prefetch_loader = PrefetchDataLoader(dataloader, prefetch_factor=2)
```
### CLI Usage

Use the `--optimized` flag when running inference:

```bash
python inference.py \
    --checkpoint checkpoints/best_checkpoint.pt \
    --prompt "Your prompt here" \
    --optimized \
    --max-length 100
```
## Example Script

See `example_optimized.py` for complete examples of all optimizations.
## References

Based on optimizations from:

- "Optimizing LLM Inference and Retrieval: Novel Data Structures and Algorithms for Production RAG Systems"
- TeleRAG: Lookahead Retrieval Mechanism
- Flash Attention optimization techniques