Initial commit: SheepOp LLM - Transformer-based language model implementation
- Complete transformer implementation from scratch
- Training pipeline with gradient accumulation and mixed precision
- Optimized inference with KV caching
- Multi-format data processing (PDFs, images, code, text)
- Comprehensive documentation
- Apache 2.0 license
- Example training plots included in docs/images/
docs/OPTIMIZATIONS.md
# Optimizations for Production RAG Systems

This document describes the optimizations implemented based on the paper "Optimizing LLM Inference and Retrieval: Novel Data Structures and Algorithms for Production RAG Systems".

## Implemented Optimizations
### 1. KV Cache for Efficient Autoregressive Generation

**Location**: `models/optimized_attention.py`

The KV (Key-Value) cache stores the keys and values computed for previous tokens during autoregressive generation, so they do not have to be recomputed at every decoding step.

**Benefits**:
- Reduces the per-token cost from O(n²) (recomputing attention over the full prefix) to O(n) (attending over cached keys and values)
- Significantly faster generation for long sequences
- Lower memory bandwidth usage

**Usage**:
```python
from models import TransformerModel, OptimizedInference

model = TransformerModel(...)
optimizer = model.get_optimized_inference()

# Generate with KV caching
generated = optimizer.generate_with_cache(
    input_ids=input_ids,
    max_length=100,
    temperature=0.8,
)
```
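For intuition, the following is a minimal sketch of what cached attention looks like for a single head during decoding. It is illustrative only and not the exact code in `models/optimized_attention.py`; the tensor names, shapes, and cache layout are assumptions.

```python
import torch
import torch.nn.functional as F

def cached_attention_step(x_new, w_q, w_k, w_v, cache):
    """One decoding step: attend from the newest token over the cached prefix.

    x_new: (batch, 1, d_model) embedding of the most recent token
    cache: dict with "k" and "v" tensors from earlier steps (empty on step 0)
    """
    q = x_new @ w_q      # (batch, 1, d_head)
    k_new = x_new @ w_k  # (batch, 1, d_head)
    v_new = x_new @ w_v  # (batch, 1, d_head)

    if "k" in cache:
        # Append the new key/value instead of recomputing the whole prefix
        cache["k"] = torch.cat([cache["k"], k_new], dim=1)
        cache["v"] = torch.cat([cache["v"], v_new], dim=1)
    else:
        cache["k"], cache["v"] = k_new, v_new

    # The new query attends over all cached positions: O(n) per token, not O(n²)
    out = F.scaled_dot_product_attention(q, cache["k"], cache["v"])
    return out, cache
```

Each step only computes projections for the single new token; the quadratic cost of re-encoding the prefix is avoided because the cached keys and values are reused.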
### 2. Optimized Attention Computation

**Location**: `models/optimized_attention.py`

Implements optimized attention computation using PyTorch's `scaled_dot_product_attention` when available (similar to Flash Attention).

**Features**:
- Uses PyTorch's optimized attention implementation
- Supports causal masking efficiently
- Reduces memory usage during attention computation

**Usage**:
```python
from models import TransformerModel
from models.blocks import TransformerBlock

# Use optimized attention in transformer blocks
block = TransformerBlock(
    d_model=512,
    num_heads=8,
    use_optimized_attention=True,  # Enable optimized attention
)
```
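As a rough illustration of this dispatch pattern (not the repository's exact implementation), an attention layer can prefer PyTorch's fused kernel and fall back to explicit attention otherwise; the function below is a hypothetical sketch.

```python
import math
import torch
import torch.nn.functional as F

def attention(q, k, v, causal: bool = True):
    # Prefer PyTorch's fused kernel (available in PyTorch >= 2.0)
    if hasattr(F, "scaled_dot_product_attention"):
        return F.scaled_dot_product_attention(q, k, v, is_causal=causal)

    # Fallback: explicit attention with a causal mask
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if causal:
        n = q.size(-2)
        mask = torch.triu(
            torch.ones(n, n, dtype=torch.bool, device=q.device), diagonal=1
        )
        scores = scores.masked_fill(mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v
```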
### 3. Retrieval Cache for Similar Queries

**Location**: `models/optimized_attention.py`

Implements approximate caching for retrieval results, reducing expensive vector database lookups by caching the results of similar queries.

**Features**:
- Cosine similarity-based cache lookup
- Configurable similarity threshold
- Automatic cache eviction when full

**Usage**:
```python
from models.optimized_attention import RetrievalCache

cache = RetrievalCache(max_size=1000, similarity_threshold=0.9)

# Store retrieval results
cache.set(query_hash, query_embedding, retrieved_docs)

# Retrieve cached results
results = cache.get(query_hash, query_embedding)
```
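To make the idea concrete, here is a minimal sketch of a similarity-based cache. It is illustrative only, not the actual `RetrievalCache` class: it omits the `query_hash` fast path shown above and assumes a simple FIFO eviction policy.

```python
import torch
import torch.nn.functional as F

class SimpleRetrievalCache:
    def __init__(self, max_size: int = 1000, similarity_threshold: float = 0.9):
        self.max_size = max_size
        self.similarity_threshold = similarity_threshold
        self.entries = []  # list of (query_embedding, documents) pairs

    def get(self, query_embedding):
        # Return cached documents if any stored query is similar enough
        for emb, docs in self.entries:
            sim = F.cosine_similarity(query_embedding, emb, dim=-1)
            if sim.item() >= self.similarity_threshold:
                return docs
        return None  # cache miss: caller falls back to the vector database

    def set(self, query_embedding, documents):
        # Evict the oldest entry once the cache is full
        if len(self.entries) >= self.max_size:
            self.entries.pop(0)
        self.entries.append((query_embedding, documents))
```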
### 4. Prefetching Mechanisms

**Location**: `models/prefetching.py`

#### 4.1 PrefetchDataLoader

Prefetches batches in background threads, reducing GPU idle time.

**Usage**:
```python
from models.prefetching import PrefetchDataLoader
from data import create_dataloader

dataloader = create_dataloader(...)
prefetch_loader = PrefetchDataLoader(
    dataloader=dataloader,
    prefetch_factor=2,
    device=device,
)
```
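The underlying pattern is a worker thread that stages upcoming batches in a bounded queue, optionally already moved to the target device, while the GPU works on the current batch. The sketch below is illustrative, not the actual `PrefetchDataLoader` implementation, and assumes each batch is a dict of tensors.

```python
import queue
import threading

class SimplePrefetcher:
    """Wrap an iterable dataloader and keep up to `prefetch_factor` batches ready."""

    def __init__(self, dataloader, prefetch_factor=2, device="cuda"):
        self.buffer = queue.Queue(maxsize=prefetch_factor)
        self.device = device
        self.worker = threading.Thread(
            target=self._fill, args=(dataloader,), daemon=True
        )
        self.worker.start()

    def _fill(self, dataloader):
        for batch in dataloader:
            # Move tensors to the device ahead of time; block if the buffer is full
            self.buffer.put({k: v.to(self.device) for k, v in batch.items()})
        self.buffer.put(None)  # sentinel marking the end of the epoch

    def __iter__(self):
        while True:
            batch = self.buffer.get()
            if batch is None:
                break
            yield batch
```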
#### 4.2 LookaheadRetriever

Prefetches retrieval results for anticipated queries.

**Usage**:
```python
from models.prefetching import LookaheadRetriever

def retrieve(query: str):
    # Your retrieval function: return the documents for this query
    return []

retriever = LookaheadRetriever(
    retrieval_fn=retrieve,
    lookahead_window=3,
)

# Start prefetching
retriever.start_prefetching(query_stream)

# Get results (checks cache first)
results = retriever.get(query)
```
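A rough sketch of how lookahead retrieval can work: queries expected to arrive soon are retrieved in the background and served from a cache when requested. This is a hypothetical illustration, not the actual `LookaheadRetriever` code.

```python
import threading

class SimpleLookaheadRetriever:
    def __init__(self, retrieval_fn, lookahead_window=3):
        self.retrieval_fn = retrieval_fn
        self.lookahead_window = lookahead_window
        self.cache = {}
        self.lock = threading.Lock()

    def start_prefetching(self, query_stream):
        # Retrieve the next few anticipated queries before they are requested
        def worker():
            for query in list(query_stream)[: self.lookahead_window]:
                docs = self.retrieval_fn(query)
                with self.lock:
                    self.cache[query] = docs

        threading.Thread(target=worker, daemon=True).start()

    def get(self, query):
        # Serve from the prefetch cache when possible, otherwise retrieve now
        with self.lock:
            if query in self.cache:
                return self.cache.pop(query)
        return self.retrieval_fn(query)
```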
#### 4.3 BatchPrefetcher

Groups queries into batches for efficient batch retrieval.

**Usage**:
```python
from typing import List

from models.prefetching import BatchPrefetcher

def batch_retrieve(queries: List[str]):
    # Batch retrieval function: return one list of documents per query
    return [[] for _ in queries]

prefetcher = BatchPrefetcher(
    batch_retrieval_fn=batch_retrieve,
    batch_size=8,
)

prefetcher.start_prefetching(query_stream)
results = prefetcher.get(query)
```
### 5. Optimized Batch Inference

**Location**: `models/optimized_attention.py`

The `OptimizedInference` class provides batch generation utilities for processing multiple prompts efficiently.

**Features**:
- Batch processing for multiple prompts
- Automatic padding and batching
- Efficient memory usage

**Usage**:
```python
from models import OptimizedInference

optimizer = model.get_optimized_inference()

# Generate for multiple prompts in batches
results = optimizer.batch_generate(
    input_ids_list=[prompt1_ids, prompt2_ids, ...],
    max_length=100,
    batch_size=8,
)
```
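Batching variable-length prompts requires padding them to a common length. How `batch_generate` does this internally is an implementation detail; the sketch below shows one common approach for decoder-only models (left padding plus an attention mask), with `pad_token_id` as an assumed parameter.

```python
import torch

def pad_prompts(input_ids_list, pad_token_id=0):
    """Left-pad variable-length prompts so they can be stacked into one batch."""
    max_len = max(len(ids) for ids in input_ids_list)
    batch = torch.full((len(input_ids_list), max_len), pad_token_id, dtype=torch.long)
    attention_mask = torch.zeros_like(batch)
    for i, ids in enumerate(input_ids_list):
        batch[i, max_len - len(ids):] = torch.as_tensor(ids, dtype=torch.long)
        attention_mask[i, max_len - len(ids):] = 1
    return batch, attention_mask
```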
## Performance Improvements

These optimizations provide the following benefits:

1. **Faster Inference**: KV caching reduces generation time by 2-5x for long sequences
2. **Reduced Latency**: Prefetching reduces end-to-end latency by overlapping computation and I/O
3. **Lower Costs**: Retrieval caching reduces expensive vector database calls
4. **Better Throughput**: Batch processing increases throughput for multiple requests
## Integration

### Using Optimized Inference in Production

1. **Enable optimized attention** (for inference only):

   ```python
   model = TransformerModel(
       ...,
       use_optimized_attention=True,  # Set in TransformerBlock
   )
   ```

2. **Use the optimized inference utility**:

   ```python
   optimizer = model.get_optimized_inference()
   generated = optimizer.generate_with_cache(...)
   ```

3. **Enable prefetching**:

   ```python
   prefetch_loader = PrefetchDataLoader(dataloader, prefetch_factor=2)
   ```
### CLI Usage

Use the `--optimized` flag when running inference:

```bash
python inference.py \
    --checkpoint checkpoints/best_checkpoint.pt \
    --prompt "Your prompt here" \
    --optimized \
    --max-length 100
```
## Example Script

See `example_optimized.py` for complete examples of all optimizations.
## References

Based on optimizations from:
- "Optimizing LLM Inference and Retrieval: Novel Data Structures and Algorithms for Production RAG Systems"
- TeleRAG: Lookahead Retrieval Mechanism
- Flash Attention optimization techniques