fixing memory
This commit is contained in:
@@ -21,6 +21,7 @@ class ModelConfig:
|
|||||||
layer_norm_eps: float = 1e-5
|
layer_norm_eps: float = 1e-5
|
||||||
bias: bool = False
|
bias: bool = False
|
||||||
tie_weights: bool = True
|
tie_weights: bool = True
|
||||||
|
use_optimized_attention: bool = False # Enable KV caching optimizations
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
38
config_quick_optimized.json
Normal file
38
config_quick_optimized.json
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
{
|
||||||
|
"model": {
|
||||||
|
"vocab_size": 50257,
|
||||||
|
"d_model": 512,
|
||||||
|
"num_layers": 6,
|
||||||
|
"num_heads": 8,
|
||||||
|
"d_ff": 2048,
|
||||||
|
"max_seq_len": 512,
|
||||||
|
"dropout": 0.1,
|
||||||
|
"activation": "gelu",
|
||||||
|
"layer_norm_eps": 1e-5,
|
||||||
|
"bias": false,
|
||||||
|
"tie_weights": true,
|
||||||
|
"use_optimized_attention": true
|
||||||
|
},
|
||||||
|
"training": {
|
||||||
|
"batch_size": 16,
|
||||||
|
"max_epochs": 5,
|
||||||
|
"learning_rate": 2e-4,
|
||||||
|
"weight_decay": 0.01,
|
||||||
|
"warmup_steps": 500,
|
||||||
|
"max_grad_norm": 1.0,
|
||||||
|
"gradient_accumulation_steps": 8,
|
||||||
|
"use_amp": true,
|
||||||
|
"save_dir": "./checkpoints_optimized",
|
||||||
|
"log_interval": 25,
|
||||||
|
"eval_interval": 200
|
||||||
|
},
|
||||||
|
"data": {
|
||||||
|
"data_dir": "./data",
|
||||||
|
"max_length": 512,
|
||||||
|
"stride": null,
|
||||||
|
"num_workers": 8
|
||||||
|
},
|
||||||
|
"device": "cuda",
|
||||||
|
"seed": 42
|
||||||
|
}
|
||||||
|
|
||||||
@@ -36,6 +36,7 @@ class TransformerModel(nn.Module):
|
|||||||
layer_norm_eps: float = 1e-5,
|
layer_norm_eps: float = 1e-5,
|
||||||
bias: bool = False,
|
bias: bool = False,
|
||||||
tie_weights: bool = True,
|
tie_weights: bool = True,
|
||||||
|
use_optimized_attention: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -81,7 +82,7 @@ class TransformerModel(nn.Module):
|
|||||||
layer_norm_eps=layer_norm_eps,
|
layer_norm_eps=layer_norm_eps,
|
||||||
bias=bias,
|
bias=bias,
|
||||||
causal=True, # Causal masking for autoregressive generation
|
causal=True, # Causal masking for autoregressive generation
|
||||||
use_optimized_attention=False, # Set to True for inference optimizations
|
use_optimized_attention=use_optimized_attention, # Use parameter value
|
||||||
)
|
)
|
||||||
for _ in range(num_layers)
|
for _ in range(num_layers)
|
||||||
])
|
])
|
||||||
|
|||||||
Reference in New Issue
Block a user