Fix memory usage: expose use_optimized_attention as a TransformerModel constructor parameter instead of hard-coding it to False

This commit is contained in:
Carlos Gutierrez
2025-11-16 16:39:11 -05:00
parent 49f9e700b4
commit 3fef3e2689
3 changed files with 41 additions and 1 deletion

View File

@@ -36,6 +36,7 @@ class TransformerModel(nn.Module):
layer_norm_eps: float = 1e-5,
bias: bool = False,
tie_weights: bool = True,
use_optimized_attention: bool = False,
):
"""
Args:
@@ -81,7 +82,7 @@ class TransformerModel(nn.Module):
layer_norm_eps=layer_norm_eps,
bias=bias,
causal=True, # Causal masking for autoregressive generation
use_optimized_attention=False, # Set to True for inference optimizations
use_optimized_attention=use_optimized_attention, # Use parameter value
)
for _ in range(num_layers)
])