# sheepOp/models/transformer.py
"""
Complete Transformer model for language modeling
Incorporates best practices from multiple research papers
Optimized for production RAG systems with KV caching and efficient inference
"""
import torch
import torch.nn as nn
from typing import Optional, Tuple
from .blocks import TransformerBlock
from .attention import PositionalEncoding
from .optimized_attention import OptimizedMultiHeadAttention, RetrievalCache, OptimizedInference
class TransformerModel(nn.Module):
"""
Full Transformer Language Model.
Features:
- Multi-head self-attention
- Positional encoding
- Layer normalization
- Residual connections
- Causal masking for autoregressive generation
"""
def __init__(
self,
vocab_size: int,
d_model: int = 512,
num_layers: int = 6,
num_heads: int = 8,
d_ff: int = 2048,
max_seq_len: int = 512,
dropout: float = 0.1,
activation: str = 'gelu',
layer_norm_eps: float = 1e-5,
bias: bool = False,
tie_weights: bool = True,
use_optimized_attention: bool = False,
):
"""
Args:
vocab_size: Vocabulary size
d_model: Model dimension
num_layers: Number of transformer layers
num_heads: Number of attention heads
d_ff: Feed-forward dimension
max_seq_len: Maximum sequence length
dropout: Dropout probability
activation: Activation function ('gelu' or 'relu')
layer_norm_eps: Epsilon for layer normalization
bias: Whether to use bias in linear layers
            tie_weights: Whether to tie input and output embeddings
            use_optimized_attention: Whether to use the optimized attention implementation (with KV-cache support)
"""
super().__init__()
self.vocab_size = vocab_size
self.d_model = d_model
self.num_layers = num_layers
self.num_heads = num_heads
self.max_seq_len = max_seq_len
# Token embeddings
self.token_embedding = nn.Embedding(vocab_size, d_model)
# Positional encoding
self.pos_encoding = PositionalEncoding(
d_model=d_model,
max_len=max_seq_len,
dropout=dropout,
)
# Transformer blocks (use optimized attention if available)
# Note: Set use_optimized_attention=True for production inference
self.layers = nn.ModuleList([
TransformerBlock(
d_model=d_model,
num_heads=num_heads,
d_ff=d_ff,
dropout=dropout,
activation=activation,
layer_norm_eps=layer_norm_eps,
bias=bias,
causal=True, # Causal masking for autoregressive generation
                use_optimized_attention=use_optimized_attention,  # optional KV-cache-capable attention
)
for _ in range(num_layers)
])
# Final layer norm
self.final_norm = nn.LayerNorm(d_model, eps=layer_norm_eps)
# Output projection
self.output_proj = nn.Linear(d_model, vocab_size, bias=bias)
# Optionally tie weights
if tie_weights:
self.output_proj.weight = self.token_embedding.weight
self.dropout = nn.Dropout(dropout)
        # Retrieval cache for RAG systems (not used directly in this module's forward pass)
self.retrieval_cache = RetrievalCache(max_size=1000, similarity_threshold=0.9)
# Initialize weights
self._init_weights()
def _init_weights(self):
"""Initialize weights following best practices."""
# Initialize embeddings
nn.init.normal_(self.token_embedding.weight, mean=0.0, std=0.02)
# Initialize output projection
if self.output_proj.weight is not self.token_embedding.weight:
nn.init.normal_(self.output_proj.weight, mean=0.0, std=0.02)
# Initialize linear layers
for module in self.modules():
if isinstance(module, nn.Linear):
nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.LayerNorm):
nn.init.ones_(module.weight)
nn.init.zeros_(module.bias)
def forward(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Forward pass through the transformer model.
Args:
input_ids: Token indices [batch_size, seq_len]
attention_mask: Optional attention mask [batch_size, seq_len]
use_cache: Whether to use KV cache (for optimized attention)
Returns:
logits: Output logits [batch_size, seq_len, vocab_size]
attention_weights: Optional attention weights
"""
batch_size, seq_len = input_ids.shape
# Token embeddings
x = self.token_embedding(input_ids) # [batch_size, seq_len, d_model]
# Add positional encoding
x = self.pos_encoding(x)
x = self.dropout(x)
# Create attention mask if not provided
if attention_mask is None:
attention_mask = torch.ones(
batch_size, seq_len, device=input_ids.device, dtype=torch.bool
)
# Expand mask for attention
# [batch_size, seq_len] -> [batch_size, seq_len, seq_len]
if attention_mask.dim() == 2:
attention_mask = attention_mask.unsqueeze(1).expand(batch_size, seq_len, seq_len)
        # Convert to float: 1.0 marks valid positions, 0.0 marks masked positions
        attention_mask = attention_mask.float()
# Pass through transformer blocks
for layer in self.layers:
x = layer(x, mask=attention_mask, use_cache=use_cache)
# Final layer norm
x = self.final_norm(x)
# Output projection
logits = self.output_proj(x) # [batch_size, seq_len, vocab_size]
return logits, None
def generate(
self,
input_ids: torch.Tensor,
max_length: int = 100,
temperature: float = 1.0,
top_k: Optional[int] = None,
top_p: float = 1.0,
do_sample: bool = True,
pad_token_id: Optional[int] = None,
) -> torch.Tensor:
"""
Autoregressive generation.
Args:
input_ids: Starting token indices [batch_size, seq_len]
            max_length: Maximum total sequence length (prompt plus newly generated tokens)
temperature: Sampling temperature
top_k: Top-k sampling parameter
top_p: Nucleus sampling parameter
do_sample: Whether to sample or use greedy decoding
pad_token_id: Padding token ID
Returns:
Generated token sequences
"""
self.eval()
device = input_ids.device
batch_size = input_ids.shape[0]
generated = input_ids.clone()
with torch.no_grad():
            for _ in range(max_length - input_ids.shape[1]):
                # Stop once the context window is full; positions beyond max_seq_len
                # have no positional encoding available
                if generated.shape[1] >= self.max_seq_len:
                    break
# Forward pass
logits, _ = self.forward(generated)
next_token_logits = logits[:, -1, :] / temperature
# Apply top-k filtering
if top_k is not None and top_k > 0:
indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
next_token_logits[indices_to_remove] = float('-inf')
# Apply top-p (nucleus) filtering
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above threshold
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
next_token_logits[indices_to_remove] = float('-inf')
# Sample or take argmax
if do_sample:
probs = torch.softmax(next_token_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
else:
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
# Early stopping: stop if EOS or padding token is generated (for batch_size=1)
if batch_size == 1:
                    eos_token_id = getattr(self, 'eos_token_id', None)
                    if eos_token_id is None:
                        eos_token_id = 3  # fall back to a default EOS token id
if next_token.item() == eos_token_id:
break
# Early stopping: stop if padding token is generated (prevent generating padding)
if pad_token_id is not None and next_token.item() == pad_token_id:
break
# Append to generated sequence
generated = torch.cat([generated, next_token], dim=1)
return generated
def get_optimized_inference(self) -> OptimizedInference:
"""
Get optimized inference utility with KV caching and batching.
Returns:
OptimizedInference instance
"""
return OptimizedInference(self, next(self.parameters()).device)
def get_num_params(self) -> int:
"""Return the number of trainable parameters."""
return sum(p.numel() for p in self.parameters() if p.requires_grad)
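

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only). It exercises the forward pass and the
    # sampling API defined above on random token ids; the vocabulary size, model
    # dimensions, and sampling settings below are arbitrary assumptions, not project values.
    torch.manual_seed(0)

    model = TransformerModel(
        vocab_size=1000,
        d_model=128,
        num_layers=2,
        num_heads=4,
        d_ff=512,
        max_seq_len=128,
    )
    print(f"Trainable parameters: {model.get_num_params():,}")

    # Forward pass: token ids -> next-token logits
    input_ids = torch.randint(0, 1000, (2, 16))  # [batch_size=2, seq_len=16]
    logits, _ = model(input_ids)
    print(f"Logits shape: {tuple(logits.shape)}")  # (2, 16, 1000)

    # Autoregressive sampling with temperature, top-k, and nucleus (top-p) filtering
    prompt = torch.randint(0, 1000, (1, 8))
    generated = model.generate(prompt, max_length=32, temperature=0.8, top_k=50, top_p=0.9)
    print(f"Generated sequence length: {generated.shape[1]}")

    # For cached/batched serving, a wrapper is available (API defined in optimized_attention.py):
    # inference = model.get_optimized_inference()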