Adding paper
@@ -125,6 +125,7 @@ class TransformerBlock(nn.Module):
         self,
         x: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
+        use_cache: bool = False,
     ) -> torch.Tensor:
         """
         Forward pass through transformer block.
@@ -132,6 +133,7 @@ class TransformerBlock(nn.Module):
         Args:
             x: Input tensor [batch_size, seq_len, d_model]
             mask: Optional attention mask
+            use_cache: Whether to use KV cache (for optimized attention)
 
         Returns:
             Output tensor [batch_size, seq_len, d_model]
@@ -139,7 +141,11 @@ class TransformerBlock(nn.Module):
         # Pre-norm self-attention with residual connection
         residual = x
         x = self.norm1(x)
-        attn_out, _ = self.self_attn(x, x, x, mask=mask)
+        # Pass use_cache to attention layer if it's OptimizedMultiHeadAttention
+        if isinstance(self.self_attn, OptimizedMultiHeadAttention):
+            attn_out, _ = self.self_attn(x, x, x, mask=mask, use_cache=use_cache)
+        else:
+            attn_out, _ = self.self_attn(x, x, x, mask=mask)
         x = residual + self.dropout(attn_out)
 
         # Pre-norm feed-forward with residual connection
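Note: the actual OptimizedMultiHeadAttention cache implementation is not part of this commit. As context for the new use_cache flag, below is a minimal, runnable sketch of the KV-cache pattern such a flag typically enables; the class name KVCacheAttentionSketch and its cache handling are illustrative assumptions, not this repository's code.

# Illustrative sketch only; not the repository's OptimizedMultiHeadAttention.
import torch
import torch.nn as nn

class KVCacheAttentionSketch(nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.k_cache = None  # keys seen in earlier decode steps
        self.v_cache = None  # values seen in earlier decode steps

    def forward(self, q, k, v, mask=None, use_cache=False):
        if use_cache and self.k_cache is not None:
            # Concatenate cached keys/values so the new token's query can
            # attend over the whole sequence seen so far. (A production
            # implementation would cache the projected K/V tensors to avoid
            # re-projecting the history at every step.)
            k = torch.cat([self.k_cache, k], dim=1)
            v = torch.cat([self.v_cache, v], dim=1)
        if use_cache:
            self.k_cache, self.v_cache = k.detach(), v.detach()
        return self.attn(q, k, v, attn_mask=mask)

# One-token-at-a-time decoding: each step feeds a single new token,
# while attention still covers all previously cached positions.
attn = KVCacheAttentionSketch(d_model=64, n_heads=4)
for _ in range(3):
    x = torch.randn(2, 1, 64)  # [batch, 1 new token, d_model]
    out, _ = attn(x, x, x, use_cache=True)

This is why the commit threads use_cache from TransformerBlock.forward down to the attention layer: during autoregressive generation, each step processes only the newly generated token while the cache supplies the history, which keeps per-step cost low.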