Adding paper
This commit is contained in:
@@ -129,6 +129,7 @@ class TransformerModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
use_cache: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""
|
||||
Forward pass through the transformer model.
|
||||
@@ -136,6 +137,7 @@ class TransformerModel(nn.Module):
|
||||
Args:
|
||||
input_ids: Token indices [batch_size, seq_len]
|
||||
attention_mask: Optional attention mask [batch_size, seq_len]
|
||||
use_cache: Whether to use KV cache (for optimized attention)
|
||||
|
||||
Returns:
|
||||
logits: Output logits [batch_size, seq_len, vocab_size]
|
||||
@@ -166,7 +168,7 @@ class TransformerModel(nn.Module):
|
||||
|
||||
# Pass through transformer blocks
|
||||
for layer in self.layers:
|
||||
x = layer(x, mask=attention_mask)
|
||||
x = layer(x, mask=attention_mask, use_cache=use_cache)
|
||||
|
||||
# Final layer norm
|
||||
x = self.final_norm(x)
|
||||
|
||||
Reference in New Issue
Block a user