Adding paper

This commit is contained in:
Carlos Gutierrez
2025-11-18 23:23:50 -05:00
parent 7501839145
commit 8b604a1925
3 changed files with 28 additions and 6 deletions

View File

@@ -129,6 +129,7 @@ class TransformerModel(nn.Module):
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
+use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Forward pass through the transformer model.
@@ -136,6 +137,7 @@ class TransformerModel(nn.Module):
Args:
input_ids: Token indices [batch_size, seq_len]
attention_mask: Optional attention mask [batch_size, seq_len]
+use_cache: Whether to use KV cache (for optimized attention)
Returns:
logits: Output logits [batch_size, seq_len, vocab_size]
@@ -166,7 +168,7 @@ class TransformerModel(nn.Module):
# Pass through transformer blocks
for layer in self.layers:
-x = layer(x, mask=attention_mask)
+x = layer(x, mask=attention_mask, use_cache=use_cache)
# Final layer norm
x = self.final_norm(x)