""" Complete Transformer model for language modeling Incorporates best practices from multiple research papers Optimized for production RAG systems with KV caching and efficient inference """ import torch import torch.nn as nn from typing import Optional, Tuple from .blocks import TransformerBlock from .attention import PositionalEncoding from .optimized_attention import OptimizedMultiHeadAttention, RetrievalCache, OptimizedInference class TransformerModel(nn.Module): """ Full Transformer Language Model. Features: - Multi-head self-attention - Positional encoding - Layer normalization - Residual connections - Causal masking for autoregressive generation """ def __init__( self, vocab_size: int, d_model: int = 512, num_layers: int = 6, num_heads: int = 8, d_ff: int = 2048, max_seq_len: int = 512, dropout: float = 0.1, activation: str = 'gelu', layer_norm_eps: float = 1e-5, bias: bool = False, tie_weights: bool = True, use_optimized_attention: bool = False, ): """ Args: vocab_size: Vocabulary size d_model: Model dimension num_layers: Number of transformer layers num_heads: Number of attention heads d_ff: Feed-forward dimension max_seq_len: Maximum sequence length dropout: Dropout probability activation: Activation function ('gelu' or 'relu') layer_norm_eps: Epsilon for layer normalization bias: Whether to use bias in linear layers tie_weights: Whether to tie input and output embeddings """ super().__init__() self.vocab_size = vocab_size self.d_model = d_model self.num_layers = num_layers self.num_heads = num_heads self.max_seq_len = max_seq_len # Token embeddings self.token_embedding = nn.Embedding(vocab_size, d_model) # Positional encoding self.pos_encoding = PositionalEncoding( d_model=d_model, max_len=max_seq_len, dropout=dropout, ) # Transformer blocks (use optimized attention if available) # Note: Set use_optimized_attention=True for production inference self.layers = nn.ModuleList([ TransformerBlock( d_model=d_model, num_heads=num_heads, d_ff=d_ff, dropout=dropout, activation=activation, layer_norm_eps=layer_norm_eps, bias=bias, causal=True, # Causal masking for autoregressive generation use_optimized_attention=use_optimized_attention, # Use parameter value ) for _ in range(num_layers) ]) # Final layer norm self.final_norm = nn.LayerNorm(d_model, eps=layer_norm_eps) # Output projection self.output_proj = nn.Linear(d_model, vocab_size, bias=bias) # Optionally tie weights if tie_weights: self.output_proj.weight = self.token_embedding.weight self.dropout = nn.Dropout(dropout) # Retrieval cache for RAG systems self.retrieval_cache = RetrievalCache(max_size=1000, similarity_threshold=0.9) # Initialize weights self._init_weights() def _init_weights(self): """Initialize weights following best practices.""" # Initialize embeddings nn.init.normal_(self.token_embedding.weight, mean=0.0, std=0.02) # Initialize output projection if self.output_proj.weight is not self.token_embedding.weight: nn.init.normal_(self.output_proj.weight, mean=0.0, std=0.02) # Initialize linear layers for module in self.modules(): if isinstance(module, nn.Linear): nn.init.normal_(module.weight, mean=0.0, std=0.02) if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): nn.init.ones_(module.weight) nn.init.zeros_(module.bias) def forward( self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """ Forward pass through the transformer model. 
        Args:
            input_ids: Token indices [batch_size, seq_len]
            attention_mask: Optional attention mask [batch_size, seq_len]
            use_cache: Whether to use KV cache (for optimized attention)

        Returns:
            logits: Output logits [batch_size, seq_len, vocab_size]
            attention_weights: Optional attention weights
        """
        batch_size, seq_len = input_ids.shape

        # Token embeddings
        x = self.token_embedding(input_ids)  # [batch_size, seq_len, d_model]

        # Add positional encoding
        x = self.pos_encoding(x)
        x = self.dropout(x)

        # Create attention mask if not provided
        if attention_mask is None:
            attention_mask = torch.ones(
                batch_size, seq_len, device=input_ids.device, dtype=torch.bool
            )

        # Expand mask for attention
        # [batch_size, seq_len] -> [batch_size, seq_len, seq_len]
        if attention_mask.dim() == 2:
            attention_mask = attention_mask.unsqueeze(1).expand(batch_size, seq_len, seq_len)

        # Convert mask to float (1 for valid positions, 0 for masked)
        attention_mask = attention_mask.float()

        # Pass through transformer blocks
        for layer in self.layers:
            x = layer(x, mask=attention_mask, use_cache=use_cache)

        # Final layer norm
        x = self.final_norm(x)

        # Output projection
        logits = self.output_proj(x)  # [batch_size, seq_len, vocab_size]

        return logits, None

    def generate(
        self,
        input_ids: torch.Tensor,
        max_length: int = 100,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: float = 1.0,
        do_sample: bool = True,
        pad_token_id: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Autoregressive generation.

        Args:
            input_ids: Starting token indices [batch_size, seq_len]
            max_length: Maximum generation length
            temperature: Sampling temperature
            top_k: Top-k sampling parameter
            top_p: Nucleus sampling parameter
            do_sample: Whether to sample or use greedy decoding
            pad_token_id: Padding token ID

        Returns:
            Generated token sequences
        """
        self.eval()
        device = input_ids.device
        batch_size = input_ids.shape[0]

        generated = input_ids.clone()

        with torch.no_grad():
            for _ in range(max_length - input_ids.shape[1]):
                # Forward pass
                logits, _ = self.forward(generated)
                next_token_logits = logits[:, -1, :] / temperature

                # Apply top-k filtering
                if top_k is not None and top_k > 0:
                    indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
                    next_token_logits[indices_to_remove] = float('-inf')

                # Apply top-p (nucleus) filtering
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

                    # Remove tokens with cumulative probability above threshold
                    sorted_indices_to_remove = cumulative_probs > top_p
                    # Shift right so the first token crossing the threshold is kept
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0

                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    next_token_logits[indices_to_remove] = float('-inf')

                # Sample or take argmax
                if do_sample:
                    probs = torch.softmax(next_token_logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                else:
                    next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

                # Early stopping: stop if EOS token is generated (for batch_size=1)
                if batch_size == 1:
                    eos_token_id = getattr(self, 'eos_token_id', None) or 3  # Default EOS token
                    if next_token.item() == eos_token_id:
                        break

                # Early stopping: stop if padding token is generated (prevent generating padding)
                if pad_token_id is not None and batch_size == 1 and next_token.item() == pad_token_id:
                    break

                # Append to generated sequence
                generated = torch.cat([generated, next_token], dim=1)

        return generated
    def get_optimized_inference(self) -> OptimizedInference:
        """
        Get optimized inference utility with KV caching and batching.

        Returns:
            OptimizedInference instance
        """
        return OptimizedInference(self, next(self.parameters()).device)

    def get_num_params(self) -> int:
        """Return the number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
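
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module's API). It assumes a
# toy vocabulary of 1000 tokens and a random prompt; the small dimensions are
# chosen only so the example is cheap to run. Because this file uses relative
# imports, run it as a module, e.g. `python -m <package>.model` (the package
# name is an assumption about the surrounding repo layout).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = TransformerModel(
        vocab_size=1000,   # assumed toy vocabulary size
        d_model=128,
        num_layers=2,
        num_heads=4,
        d_ff=512,
        max_seq_len=64,
    )

    # Random prompt of 8 tokens for a single sequence
    prompt = torch.randint(0, 1000, (1, 8))

    # Training-style forward pass: per-position logits over the vocabulary
    logits, _ = model(prompt)
    print("logits shape:", logits.shape)  # expected: (1, 8, 1000)

    # Autoregressive sampling with temperature and nucleus (top-p) filtering;
    # the sequence may stop early if the default EOS id is sampled.
    output = model.generate(prompt, max_length=32, temperature=0.8, top_p=0.9)
    print("generated shape:", output.shape)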