Adding paper

Carlos Gutierrez
2025-11-18 23:23:50 -05:00
parent 7501839145
commit 8b604a1925
3 changed files with 28 additions and 6 deletions

@@ -125,6 +125,7 @@ class TransformerBlock(nn.Module):
         self,
         x: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
+        use_cache: bool = False,
     ) -> torch.Tensor:
         """
         Forward pass through transformer block.
@@ -132,6 +133,7 @@ class TransformerBlock(nn.Module):
         Args:
             x: Input tensor [batch_size, seq_len, d_model]
             mask: Optional attention mask
+            use_cache: Whether to use KV cache (for optimized attention)
 
         Returns:
             Output tensor [batch_size, seq_len, d_model]
@@ -139,7 +141,11 @@ class TransformerBlock(nn.Module):
         # Pre-norm self-attention with residual connection
         residual = x
         x = self.norm1(x)
-        attn_out, _ = self.self_attn(x, x, x, mask=mask)
+        # Pass use_cache to attention layer if it's OptimizedMultiHeadAttention
+        if isinstance(self.self_attn, OptimizedMultiHeadAttention):
+            attn_out, _ = self.self_attn(x, x, x, mask=mask, use_cache=use_cache)
+        else:
+            attn_out, _ = self.self_attn(x, x, x, mask=mask)
         x = residual + self.dropout(attn_out)
 
         # Pre-norm feed-forward with residual connection
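
Note: for readers unfamiliar with the cache path, below is a minimal, hypothetical sketch of the behaviour the new use_cache flag is meant to enable. The stand-in OptimizedMultiHeadAttention, its reset_cache() helper, and its constructor arguments are illustrative assumptions and may not match the class in this repository; only the call shape self_attn(x, x, x, mask=mask, use_cache=use_cache) comes from the diff above.

# Minimal, hypothetical sketch of KV caching behind the use_cache flag.
# This stand-in OptimizedMultiHeadAttention, reset_cache(), and the constructor
# arguments are illustrative assumptions, not code from this commit.
from typing import Optional

import torch
import torch.nn as nn


class OptimizedMultiHeadAttention(nn.Module):
    """Toy attention wrapper that keeps a running key/value cache."""

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.k_cache: Optional[torch.Tensor] = None
        self.v_cache: Optional[torch.Tensor] = None

    def reset_cache(self) -> None:
        self.k_cache = None
        self.v_cache = None

    def forward(self, q, k, v, mask=None, use_cache: bool = False):
        if use_cache:
            # Append this step's keys/values to the cache and attend over the
            # full history accumulated so far.
            self.k_cache = k if self.k_cache is None else torch.cat([self.k_cache, k], dim=1)
            self.v_cache = v if self.v_cache is None else torch.cat([self.v_cache, v], dim=1)
            k, v = self.k_cache, self.v_cache
        # Returns (output, attention weights), matching the `attn_out, _ = ...` unpacking.
        return self.attn(q, k, v, attn_mask=mask)


if __name__ == "__main__":
    attn = OptimizedMultiHeadAttention(d_model=16, n_heads=4)
    x = torch.randn(2, 5, 16)  # [batch_size, seq_len, d_model]

    # Incremental decoding: feed one position at a time. Each step supplies only
    # its own query while the cached keys/values grow, so per-step attention cost
    # scales with the current length instead of recomputing the whole prefix.
    attn.reset_cache()
    for t in range(x.size(1)):
        step = x[:, t : t + 1, :]
        out, _ = attn(step, step, step, use_cache=True)
    print(out.shape)  # torch.Size([2, 1, 16])

With such a cache in place, the isinstance check in the diff lets TransformerBlock keep working unchanged with a plain attention module while forwarding use_cache only to the optimized one.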