fixing memory
@@ -36,6 +36,7 @@ class TransformerModel(nn.Module):
         layer_norm_eps: float = 1e-5,
         bias: bool = False,
         tie_weights: bool = True,
+        use_optimized_attention: bool = False,
     ):
         """
         Args:
@@ -81,7 +82,7 @@ class TransformerModel(nn.Module):
                 layer_norm_eps=layer_norm_eps,
                 bias=bias,
                 causal=True, # Causal masking for autoregressive generation
-                use_optimized_attention=False, # Set to True for inference optimizations
+                use_optimized_attention=use_optimized_attention, # Use parameter value
             )
             for _ in range(num_layers)
         ])
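The effect of the change: before this commit the constructor had no use_optimized_attention parameter, and every block in the list comprehension was built with a hard-coded False, so the optimized attention path could never be enabled by a caller. A minimal usage sketch of the behavior after the fix; the import path and the num_layers value are assumptions for illustration, and any other required constructor arguments are omitted since only these two keywords appear in the diff:

    # Hypothetical usage sketch; the module path is an assumption,
    # and other required constructor arguments may exist.
    from model import TransformerModel

    model = TransformerModel(
        num_layers=12,                 # feeds the `for _ in range(num_layers)` loop
        use_optimized_attention=True,  # now forwarded to every block
    )                                  # instead of the hard-coded False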