From 8b604a19250d278d310e725d791aad5f1c0bbe1f Mon Sep 17 00:00:00 2001
From: Carlos Gutierrez
Date: Tue, 18 Nov 2025 23:23:50 -0500
Subject: [PATCH] Add KV cache support for generation

---
 models/blocks.py              |  8 +++++++-
 models/optimized_attention.py | 22 ++++++++++++++++++----
 models/transformer.py         |  4 +++-
 3 files changed, 28 insertions(+), 6 deletions(-)

diff --git a/models/blocks.py b/models/blocks.py
index 4c113eb..df8f68f 100644
--- a/models/blocks.py
+++ b/models/blocks.py
@@ -125,6 +125,7 @@ class TransformerBlock(nn.Module):
         self,
         x: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
+        use_cache: bool = False,
     ) -> torch.Tensor:
         """
         Forward pass through transformer block.
@@ -132,6 +133,7 @@
         Args:
             x: Input tensor [batch_size, seq_len, d_model]
             mask: Optional attention mask
+            use_cache: Whether to use KV cache (for optimized attention)
 
         Returns:
             Output tensor [batch_size, seq_len, d_model]
@@ -139,7 +141,11 @@
         # Pre-norm self-attention with residual connection
         residual = x
         x = self.norm1(x)
-        attn_out, _ = self.self_attn(x, x, x, mask=mask)
+        # Pass use_cache through only when the attention layer supports it
+        if isinstance(self.self_attn, OptimizedMultiHeadAttention):
+            attn_out, _ = self.self_attn(x, x, x, mask=mask, use_cache=use_cache)
+        else:
+            attn_out, _ = self.self_attn(x, x, x, mask=mask)
         x = residual + self.dropout(attn_out)
 
         # Pre-norm feed-forward with residual connection
diff --git a/models/optimized_attention.py b/models/optimized_attention.py
index c1ddc08..f98b99d 100644
--- a/models/optimized_attention.py
+++ b/models/optimized_attention.py
@@ -119,8 +119,17 @@ class OptimizedMultiHeadAttention(nn.Module):
 
         # Use KV cache if available and enabled
         if use_cache and self.kv_cache is not None:
-            # Append new keys and values to cache
-            self.kv_cache.append(K, V)
+            # Get current cache length
+            cached_len = self.kv_cache.keys.shape[2] if self.kv_cache.keys.numel() > 0 else 0
+
+            # Only append new tokens (those after the cached length)
+            if cached_len < seq_len:
+                # Extract only the new tokens to append
+                new_k = K[:, :, cached_len:, :]  # [batch_size, num_heads, new_seq_len, d_k]
+                new_v = V[:, :, cached_len:, :]
+                self.kv_cache.append(new_k, new_v)
+
+            # Attend over all cached keys and values
             K = self.kv_cache.keys
             V = self.kv_cache.values
             kv_seq_len = K.shape[2]
@@ -328,9 +337,11 @@ class OptimizedInference:
 
         generated = input_ids.clone()
 
+        # Process the initial prompt once to populate the cache
+        logits, _ = self.model(generated, use_cache=True)
+
         for _ in range(max_length - input_ids.shape[1]):
-            # Forward pass
-            logits, _ = self.model(generated)
+            # Get next-token logits from the last position
             next_token_logits = logits[:, -1, :] / temperature
 
             # Apply top-k filtering
@@ -370,6 +381,9 @@
 
             # Append to generated sequence
             generated = torch.cat([generated, next_token], dim=1)
+
+            # Forward pass with cache enabled (processes full sequence but uses cached K/V)
+            logits, _ = self.model(generated, use_cache=True)
 
         # Clear KV cache
         for module in self.model.modules():
diff --git a/models/transformer.py b/models/transformer.py
index 8cae9f2..ff9e099 100644
--- a/models/transformer.py
+++ b/models/transformer.py
@@ -129,6 +129,7 @@ class TransformerModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
+        use_cache: bool = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         """
         Forward pass through the transformer model.
@@ -136,6 +137,7 @@ class TransformerModel(nn.Module):
         Args:
             input_ids: Token indices [batch_size, seq_len]
             attention_mask: Optional attention mask [batch_size, seq_len]
+            use_cache: Whether to use KV cache (for optimized attention)
 
         Returns:
             logits: Output logits [batch_size, seq_len, vocab_size]
@@ -166,7 +168,7 @@
 
         # Pass through transformer blocks
         for layer in self.layers:
-            x = layer(x, mask=attention_mask)
+            x = layer(x, mask=attention_mask, use_cache=use_cache)
 
         # Final layer norm
         x = self.final_norm(x)
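
Note on the blocks.py hunk: the isinstance dispatch relies on OptimizedMultiHeadAttention being importable in models/blocks.py, and the hunk does not add that import, so it must already exist in the file. If that cross-module import is unwanted, probing the forward signature is one dependency-free alternative. This is a sketch of that alternative, not what the patch does:

# Sketch only: dispatch on the forward signature instead of the class,
# avoiding a models.optimized_attention import in blocks.py.
import inspect

def attn_accepts_use_cache(attn) -> bool:
    # True if this attention module's forward takes a use_cache keyword.
    return "use_cache" in inspect.signature(attn.forward).parameters

TransformerBlock.forward could then branch on attn_accepts_use_cache(self.self_attn) in place of the isinstance check, at the cost of a small per-call inspection overhead.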
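
The optimized_attention.py hunk assumes a cache object exposing .keys and .values tensors of shape [batch_size, num_heads, cached_len, d_k] plus an append(k, v) that concatenates along the sequence dimension; the patch never shows that class. Below is a minimal sketch of the interface the hunk appears to rely on. The class name, the empty-tensor initialization, and the clear() method are assumptions for illustration, not the repository's actual implementation:

import torch

class KVCache:
    # Hypothetical cache matching the patched attention code: an empty
    # tensor (numel() == 0) signals "nothing cached yet", which is exactly
    # the cached_len guard in OptimizedMultiHeadAttention.
    def __init__(self):
        self.keys = torch.empty(0)
        self.values = torch.empty(0)

    def append(self, k: torch.Tensor, v: torch.Tensor) -> None:
        # Concatenate new keys/values along the sequence axis (dim=2).
        if self.keys.numel() == 0:
            self.keys, self.values = k, v
        else:
            self.keys = torch.cat([self.keys, k], dim=2)
            self.values = torch.cat([self.values, v], dim=2)

    def clear(self) -> None:
        self.keys = torch.empty(0)
        self.values = torch.empty(0)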
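
One observation on generate(): as its own comment admits, each step still pushes the full sequence through the model, so the cache saves only the K/V projections for old positions while queries and feed-forward activations are recomputed everywhere, and the forward pass after the final sampled token is never consumed. Standard incremental decoding feeds only the newest token per step. The sketch below shows that variant; it assumes the attention layer derives cache positions from cached_len rather than from the current input's seq_len (which the patched bookkeeping does not yet do), and sample() is a hypothetical stand-in for the existing temperature/top-k sampling:

# Sketch of single-token (incremental) decoding; assumes position
# bookkeeping in the cache, which this patch does not yet implement.
logits, _ = self.model(generated, use_cache=True)  # prompt pass fills the cache
for _ in range(max_length - input_ids.shape[1]):
    next_token = sample(logits[:, -1, :] / temperature)  # hypothetical helper
    generated = torch.cat([generated, next_token], dim=1)
    # Only the newest token runs through the model; K/V for earlier
    # positions come from the cache.
    logits, _ = self.model(generated[:, -1:], use_cache=True)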