diff --git a/config.py b/config.py index 0f4b8a7..595f140 100644 --- a/config.py +++ b/config.py @@ -21,6 +21,7 @@ class ModelConfig: layer_norm_eps: float = 1e-5 bias: bool = False tie_weights: bool = True + use_optimized_attention: bool = False # Enable KV caching optimizations @dataclass diff --git a/config_quick_optimized.json b/config_quick_optimized.json new file mode 100644 index 0000000..ba54287 --- /dev/null +++ b/config_quick_optimized.json @@ -0,0 +1,38 @@ +{ + "model": { + "vocab_size": 50257, + "d_model": 512, + "num_layers": 6, + "num_heads": 8, + "d_ff": 2048, + "max_seq_len": 512, + "dropout": 0.1, + "activation": "gelu", + "layer_norm_eps": 1e-5, + "bias": false, + "tie_weights": true, + "use_optimized_attention": true + }, + "training": { + "batch_size": 16, + "max_epochs": 5, + "learning_rate": 2e-4, + "weight_decay": 0.01, + "warmup_steps": 500, + "max_grad_norm": 1.0, + "gradient_accumulation_steps": 8, + "use_amp": true, + "save_dir": "./checkpoints_optimized", + "log_interval": 25, + "eval_interval": 200 + }, + "data": { + "data_dir": "./data", + "max_length": 512, + "stride": null, + "num_workers": 8 + }, + "device": "cuda", + "seed": 42 +} + diff --git a/models/transformer.py b/models/transformer.py index 0cd3690..8cae9f2 100644 --- a/models/transformer.py +++ b/models/transformer.py @@ -36,6 +36,7 @@ class TransformerModel(nn.Module): layer_norm_eps: float = 1e-5, bias: bool = False, tie_weights: bool = True, + use_optimized_attention: bool = False, ): """ Args: @@ -81,7 +82,7 @@ class TransformerModel(nn.Module): layer_norm_eps=layer_norm_eps, bias=bias, causal=True, # Causal masking for autoregressive generation - use_optimized_attention=False, # Set to True for inference optimizations + use_optimized_attention=use_optimized_attention, # Use parameter value ) for _ in range(num_layers) ])