Files
sheepOp/models/optimized_attention.py
Carlos Gutierrez 8b604a1925 Adding paper
2025-11-18 23:23:50 -05:00

447 lines
17 KiB
Python

"""
Optimized attention mechanisms for production RAG systems
Implements KV caching, optimized attention computation, and retrieval optimizations
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple, Dict, List
from dataclasses import dataclass
@dataclass
class KVCache:
    """Key-Value cache for efficient autoregressive generation.

    ``keys`` and ``values`` grow along dim 2 (the sequence axis) on every
    ``append``. After ``clear`` both are ``None``; the next ``append`` then
    starts a fresh cache instead of failing inside ``torch.cat``.
    """

    keys: Optional[torch.Tensor]  # [batch_size, num_heads, seq_len, d_k]
    values: Optional[torch.Tensor]  # [batch_size, num_heads, seq_len, d_k]

    def append(self, new_keys: torch.Tensor, new_values: torch.Tensor) -> None:
        """Append new keys and values along the sequence axis (dim 2)."""
        if self.keys is None or self.values is None:
            # Cache was cleared: the new tensors become the whole cache.
            self.keys = new_keys
            self.values = new_values
            return
        self.keys = torch.cat([self.keys, new_keys], dim=2)
        self.values = torch.cat([self.values, new_values], dim=2)

    def clear(self) -> None:
        """Drop the cached tensors; a later ``append`` re-initializes them."""
        self.keys = None
        self.values = None
class OptimizedMultiHeadAttention(nn.Module):
    """
    Optimized Multi-Head Attention with KV caching and efficient computation.

    Features:
    - KV cache for autoregressive generation (see ``init_kv_cache``)
    - Optional fused kernel via ``F.scaled_dot_product_attention``
    - Incremental decoding: feed only the new tokens together with
      ``cache_position``; causal masks are aligned to the END of the key
      sequence so cached prefix positions stay visible to new queries.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dropout: float = 0.1,
        bias: bool = False,
        causal: bool = False,
        use_flash_attention: bool = False,
    ):
        """
        Args:
            d_model: Model dimension
            num_heads: Number of attention heads
            dropout: Dropout probability (applied to attention weights / passed
                as ``dropout_p`` to the fused kernel while training)
            bias: Whether to use bias in linear layers
            causal: Whether to use causal masking
            use_flash_attention: Whether to use ``F.scaled_dot_product_attention``
                (only if the installed PyTorch provides it)
        """
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.causal = causal
        self.use_flash_attention = use_flash_attention

        # Linear projections for Q, K, V and the output.
        self.q_proj = nn.Linear(d_model, d_model, bias=bias)
        self.k_proj = nn.Linear(d_model, d_model, bias=bias)
        self.v_proj = nn.Linear(d_model, d_model, bias=bias)
        self.out_proj = nn.Linear(d_model, d_model, bias=bias)

        self.dropout = nn.Dropout(dropout)
        self.scale = 1.0 / math.sqrt(self.d_k)

        # KV cache for inference; None until init_kv_cache() is called.
        self.kv_cache: Optional[KVCache] = None

    @staticmethod
    def _causal_keep_mask(q_len: int, kv_len: int, device: torch.device) -> torch.Tensor:
        """Boolean [q_len, kv_len] mask, True where a query may attend.

        Queries are aligned to the end of the key sequence: query ``i`` sits at
        absolute position ``kv_len - q_len + i`` and may see keys up to and
        including that position. For ``q_len == kv_len`` this is the usual
        lower-triangular mask.
        """
        return torch.ones(q_len, kv_len, device=device, dtype=torch.bool).tril(
            diagonal=kv_len - q_len
        )

    def forward(
        self,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        value: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        cache_position: Optional[int] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Forward pass with optional KV caching.

        Args:
            query: Query tensor [batch_size, seq_len, d_model]
            key: Key tensor [batch_size, kv_len, d_model] (if None, uses query)
            value: Value tensor [batch_size, kv_len, d_model] (if None, uses query)
            mask: Optional attention mask. A 3-D mask is [batch_size, seq_len, kv_len].
                The fused path attends where ``mask > 0.5``; the standard path
                masks where ``mask == 0``.
            use_cache: Whether to read/extend the KV cache (no effect until
                ``init_kv_cache`` has been called)
            cache_position: Absolute position of the first token in ``query``.
                ``None`` means the input starts at position 0 (the full sequence
                is re-fed), so only positions beyond the cached length are
                appended. Pass the current cache length to feed only new tokens.

        Returns:
            output: Attention output [batch_size, seq_len, d_model]
            attention_weights: [batch_size, num_heads, seq_len, kv_len], or None
                when the fused kernel is used

        Raises:
            ValueError: if ``cache_position`` is negative or leaves a gap after
                the cached positions.
        """
        if key is None:
            key = query
        if value is None:
            value = query

        batch_size, seq_len, _ = query.shape

        # Project and split into heads: [batch_size, num_heads, length, d_k].
        # K/V infer their own length (-1) so cross-attention with kv_len != seq_len works.
        Q = self.q_proj(query).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.k_proj(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.v_proj(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        if use_cache and self.kv_cache is not None:
            cached_len = self.kv_cache.keys.shape[2] if self.kv_cache.keys.numel() > 0 else 0
            start = 0 if cache_position is None else cache_position
            if start < 0 or start > cached_len:
                raise ValueError(
                    f"cache_position={start} is invalid for a cache holding {cached_len} positions"
                )
            # Positions [start, cached_len) of this input are already cached;
            # only the remainder is appended.
            already_cached = cached_len - start
            if already_cached < K.shape[2]:
                self.kv_cache.append(K[:, :, already_cached:, :], V[:, :, already_cached:, :])
            K = self.kv_cache.keys
            V = self.kv_cache.values

        kv_seq_len = K.shape[2]

        if self.use_flash_attention and hasattr(F, 'scaled_dot_product_attention'):
            # scaled_dot_product_attention takes a boolean mask where True = attend.
            attn_mask = None
            if mask is not None:
                keep = mask > 0.5
                if keep.dim() == 2:
                    # [seq_len, kv_len] -> [1, 1, seq_len, kv_len]
                    attn_mask = keep.unsqueeze(0).unsqueeze(0)
                elif keep.dim() == 3:
                    # [batch, seq_len, kv_len] -> [batch, 1, seq_len, kv_len]
                    attn_mask = keep.unsqueeze(1)
                else:
                    attn_mask = keep

            # A custom mask is expected to already include causal masking.
            use_causal = False
            if self.causal and attn_mask is None:
                if seq_len == kv_seq_len:
                    use_causal = True
                else:
                    # is_causal is top-left aligned; cached decoding needs the
                    # queries aligned to the end of the key sequence instead.
                    attn_mask = self._causal_keep_mask(seq_len, kv_seq_len, query.device)

            output = F.scaled_dot_product_attention(
                Q, K, V,
                attn_mask=attn_mask,
                dropout_p=self.dropout.p if self.training else 0.0,
                is_causal=use_causal,
            )
            attention_weights = None  # The fused kernel does not return weights.
        else:
            scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale  # [B, H, seq_len, kv_len]

            if self.causal:
                keep = self._causal_keep_mask(seq_len, kv_seq_len, query.device)
                scores.masked_fill_(~keep, float('-inf'))

            if mask is not None:
                if mask.dim() == 2:
                    # NOTE(review): here a 2-D mask broadcasts as [batch, 1, 1, kv_len]
                    # (key padding), whereas the fused path treats 2-D as
                    # [seq_len, kv_len]. Kept as-is for compatibility; confirm intent.
                    mask = mask.unsqueeze(1).unsqueeze(1)
                elif mask.dim() == 3:
                    # [batch, seq_len, kv_len] -> [batch, 1, seq_len, kv_len] so the
                    # batch axis is not broadcast against the head axis.
                    mask = mask.unsqueeze(1)
                scores.masked_fill_(mask == 0, float('-inf'))

            attention_weights = F.softmax(scores, dim=-1)
            attention_weights = self.dropout(attention_weights)
            output = torch.matmul(attention_weights, V)  # [B, H, seq_len, d_k]

        # Merge heads back: [B, H, seq_len, d_k] -> [B, seq_len, d_model].
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        output = self.out_proj(output)
        return output, attention_weights

    def init_kv_cache(
        self,
        batch_size: int,
        max_length: int,
        device: torch.device,
        dtype: Optional[torch.dtype] = None,
    ):
        """Initialize an empty KV cache for inference.

        Args:
            batch_size: Batch size of the upcoming generation.
            max_length: Accepted for API compatibility; the cache grows by
                concatenation, so nothing is preallocated.
            device: Device for the cache tensors.
            dtype: Cache dtype; defaults to the projection weights' dtype so
                half-precision models do not get promoted/mismatched caches.
        """
        if dtype is None:
            dtype = self.q_proj.weight.dtype
        self.kv_cache = KVCache(
            keys=torch.empty(batch_size, self.num_heads, 0, self.d_k, device=device, dtype=dtype),
            values=torch.empty(batch_size, self.num_heads, 0, self.d_k, device=device, dtype=dtype),
        )

    def clear_cache(self):
        """Discard the KV cache (forward then runs without caching)."""
        self.kv_cache = None
class RetrievalCache:
    """
    Approximate cache for retrieval results.

    Reduces expensive vector database lookups by caching similar queries. A
    lookup hits on an exact hash match, or else on the most cosine-similar
    cached embedding when that similarity reaches ``similarity_threshold``.
    When full, the oldest inserted entry is evicted (FIFO order of insertion).
    """

    def __init__(self, max_size: int = 1000, similarity_threshold: float = 0.9):
        """
        Args:
            max_size: Maximum number of cached entries (<= 0 disables caching)
            similarity_threshold: Minimum cosine similarity to consider a cache hit
        """
        self.max_size = max_size
        self.similarity_threshold = similarity_threshold
        self.cache: Dict[str, List[Dict]] = {}  # query_hash -> retrieved_docs
        self.query_embeddings: Dict[str, torch.Tensor] = {}  # query_hash -> embedding

    def get(self, query_hash: str, query_embedding: torch.Tensor) -> Optional[List[Dict]]:
        """
        Retrieve cached results if an identical or similar query exists.

        Args:
            query_hash: Hash of the query
            query_embedding: Embedding of the query; flattened before comparison,
                so ``[D]`` and ``[1, D]`` shapes behave the same

        Returns:
            Cached results if found, None otherwise
        """
        # Exact match first.
        if query_hash in self.cache:
            return self.cache[query_hash]
        if not self.query_embeddings:
            return None

        # Compare against all cached embeddings in one batched call instead of
        # one cosine_similarity + .item() per entry.
        cached_hashes = list(self.query_embeddings)
        stacked = torch.stack([emb.flatten() for emb in self.query_embeddings.values()])
        probe = query_embedding.flatten().to(device=stacked.device, dtype=stacked.dtype)
        similarities = F.cosine_similarity(probe.unsqueeze(0), stacked, dim=1)  # [num_cached]

        # argmax returns the first maximum, matching "first best" iteration order.
        best_index = int(torch.argmax(similarities))
        if similarities[best_index].item() >= self.similarity_threshold:
            return self.cache[cached_hashes[best_index]]
        return None

    def set(self, query_hash: str, query_embedding: torch.Tensor, results: List[Dict]):
        """
        Store query and results in cache.

        Args:
            query_hash: Hash of the query
            query_embedding: Embedding of the query (stored detached so no
                autograd graph is kept alive by the cache)
            results: Retrieved documents/results
        """
        if self.max_size <= 0:
            return
        # Only a genuinely new key needs room; updating an existing key must
        # not evict an unrelated entry.
        if query_hash not in self.cache:
            while len(self.cache) >= self.max_size:
                oldest_key = next(iter(self.cache))
                del self.cache[oldest_key]
                del self.query_embeddings[oldest_key]
        self.cache[query_hash] = results
        self.query_embeddings[query_hash] = query_embedding.detach()

    def clear(self):
        """Clear the cache."""
        self.cache.clear()
        self.query_embeddings.clear()
class OptimizedInference:
    """
    Inference utilities for production RAG systems.

    Wraps a model (switched to eval mode on construction) and provides
    KV-cached sampling (``generate_with_cache``) and chunked multi-prompt
    generation (``batch_generate``).
    """

    def __init__(self, model: nn.Module, device: torch.device):
        """
        Args:
            model: Model to use for inference
            device: Device to run inference on (stored; inputs are not moved
                automatically)
        """
        self.model = model
        self.device = device
        self.model.eval()

    @staticmethod
    def _select_next_token(
        last_logits: torch.Tensor,
        temperature: float,
        top_k: Optional[int],
        top_p: float,
        do_sample: bool,
    ) -> torch.Tensor:
        """Pick one token id per row of ``last_logits`` [batch, vocab]; returns [batch, 1]."""
        if not do_sample:
            # Top-k/top-p never remove the row maximum, so greedy decoding is a
            # plain argmax and temperature is irrelevant (also avoids x / 0).
            return torch.argmax(last_logits, dim=-1, keepdim=True)

        scores = last_logits / temperature

        if top_k is not None and top_k > 0:
            k = min(top_k, scores.size(-1))  # topk raises if k > vocab size
            kth_best = torch.topk(scores, k)[0][..., -1, None]
            scores = scores.masked_fill(scores < kth_best, float('-inf'))

        if top_p < 1.0:
            sorted_scores, sorted_indices = torch.sort(scores, descending=True)
            cumulative_probs = torch.cumsum(torch.softmax(sorted_scores, dim=-1), dim=-1)
            remove_sorted = cumulative_probs > top_p
            # Shift right so the token that crosses top_p is kept; always keep the best one.
            remove_sorted[..., 1:] = remove_sorted[..., :-1].clone()
            remove_sorted[..., 0] = False
            remove = remove_sorted.scatter(-1, sorted_indices, remove_sorted)
            scores = scores.masked_fill(remove, float('-inf'))

        probs = torch.softmax(scores, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    @torch.no_grad()
    def generate_with_cache(
        self,
        input_ids: torch.Tensor,
        max_length: int = 100,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: float = 1.0,
        do_sample: bool = True,
        eos_token_id: Optional[int] = 3,
        pad_token_id: Optional[int] = 0,
    ) -> torch.Tensor:
        """
        Generate with KV cache for autoregressive generation.

        Args:
            input_ids: Starting token indices [batch_size, seq_len]
            max_length: Maximum total length (prompt + generated tokens)
            temperature: Sampling temperature (must be > 0 when sampling;
                ignored for greedy decoding)
            top_k: Top-k sampling parameter (clamped to the vocabulary size)
            top_p: Nucleus sampling parameter
            do_sample: Whether to sample or use greedy decoding
            eos_token_id: Token that stops generation (batch_size == 1 only); None disables
            pad_token_id: Token that also stops generation (batch_size == 1 only); None disables

        Returns:
            The prompt followed by at most ``max_length - seq_len`` new tokens.
            A stop token is not appended.

        Raises:
            ValueError: if ``do_sample`` is True and ``temperature <= 0``.
        """
        if do_sample and temperature <= 0:
            raise ValueError("temperature must be positive when do_sample=True")

        generated = input_ids.clone()
        num_new_tokens = max_length - input_ids.shape[1]
        if num_new_tokens <= 0:
            return generated

        batch_size = input_ids.shape[0]
        device = input_ids.device
        stop_ids = {tok for tok in (eos_token_id, pad_token_id) if tok is not None}

        attention_layers = [
            module for module in self.model.modules()
            if isinstance(module, OptimizedMultiHeadAttention)
        ]
        for layer in attention_layers:
            layer.init_kv_cache(batch_size, max_length, device)

        try:
            # NOTE(review): assumes model(ids, use_cache=True) returns (logits, _).
            # Each step re-feeds the full sequence, so the cache avoids re-appending
            # K/V but not recomputation; feeding only the new token would need the
            # model to forward cache_position to its attention layers.
            logits, _ = self.model(generated, use_cache=True)
            for step in range(num_new_tokens):
                next_token = self._select_next_token(
                    logits[:, -1, :], temperature, top_k, top_p, do_sample
                )
                # Early stopping only for a single sequence (.item() needs one element).
                if batch_size == 1 and stop_ids and next_token.item() in stop_ids:
                    break
                generated = torch.cat([generated, next_token], dim=1)
                # Skip the final forward pass whose logits would never be read.
                if step + 1 < num_new_tokens:
                    logits, _ = self.model(generated, use_cache=True)
        finally:
            # Always release the caches, even if the model raised mid-generation.
            for layer in attention_layers:
                layer.clear_cache()

        return generated

    @torch.no_grad()
    def batch_generate(
        self,
        input_ids_list: List[torch.Tensor],
        max_length: int = 100,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: float = 1.0,
        batch_size: int = 8,
    ) -> List[torch.Tensor]:
        """
        Generate for multiple prompts in batches.

        Args:
            input_ids_list: List of starting token sequences, each [rows, seq_len]
            max_length: Maximum generation length
            temperature: Sampling temperature
            top_k: Top-k sampling parameter
            top_p: Nucleus sampling parameter
            batch_size: Number of prompts processed per generate call (>= 1)

        Returns:
            List of generated sequences, one 1-D tensor per row

        Raises:
            ValueError: if ``batch_size < 1``.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be >= 1")

        results: List[torch.Tensor] = []
        for start in range(0, len(input_ids_list), batch_size):
            chunk = input_ids_list[start:start + batch_size]

            # NOTE(review): prompts are right-padded with token 0 and no attention
            # mask is passed, so shorter prompts continue after padding tokens.
            max_len = max(seq.shape[1] for seq in chunk)
            padded = [
                torch.cat(
                    [seq, torch.zeros(seq.shape[0], max_len - seq.shape[1],
                                      dtype=seq.dtype, device=seq.device)],
                    dim=1,
                )
                for seq in chunk
            ]
            batch_tensor = torch.cat(padded, dim=0)

            generated = self.generate_with_cache(
                batch_tensor,
                max_length=max_length,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
            )
            results.extend(generated.unbind(0))
        return results