- Complete transformer implementation from scratch - Training pipeline with gradient accumulation and mixed precision - Optimized inference with KV caching - Multi-format data processing (PDFs, images, code, text) - Comprehensive documentation - Apache 2.0 license - Example training plots included in docs/images/
144 lines
3.5 KiB
Markdown
144 lines
3.5 KiB
Markdown
# Retraining Guide for Better Model Performance
|
|
|
|
## Current Issues
|
|
|
|
Your model only trained for **40 global steps** across 10 epochs, which means:
|
|
- Very little training data (~4 batches per epoch)
|
|
- Model hasn't learned language patterns
|
|
- Model just repeats input and stops
|
|
|
|
## Retraining Recommendations
|
|
|
|
### 1. **Increase Training Data**
|
|
|
|
The model needs much more data. Check your current data:
|
|
|
|
```bash
|
|
# Check how much data you have
|
|
wc -l data/*.txt
|
|
```
|
|
|
|
**Recommendations:**
|
|
- **Minimum**: 10,000+ text samples
|
|
- **Good**: 100,000+ text samples
|
|
- **Better**: 1,000,000+ text samples
|
|
|
|
### 2. **Update Training Configuration**
|
|
|
|
Edit `config.json` for better training:
|
|
|
|
```json
|
|
{
|
|
"training": {
|
|
"batch_size": 32,
|
|
"max_epochs": 50, // Increase from 10 to 50+
|
|
"learning_rate": 1e-4,
|
|
"weight_decay": 0.01,
|
|
"warmup_steps": 1000,
|
|
"max_grad_norm": 1.0,
|
|
"gradient_accumulation_steps": 4, // Increase to simulate larger batches
|
|
"use_amp": true,
|
|
"save_dir": "./checkpoints",
|
|
"log_interval": 10, // More frequent logging
|
|
"eval_interval": 500 // More frequent evaluation
|
|
}
|
|
}
|
|
```
|
|
|
|
### 3. **Add Validation Set**
|
|
|
|
Split your data for validation:
|
|
|
|
```python
|
|
# In train.py, add validation split
|
|
from sklearn.model_selection import train_test_split
|
|
|
|
train_texts, val_texts = train_test_split(texts, test_size=0.1, random_state=42)
|
|
```
|
|
|
|
### 4. **Improve Training Data Quality**
|
|
|
|
Ensure your training data:
|
|
- ✅ Contains complete sentences/paragraphs
|
|
- ✅ Has diverse topics and styles
|
|
- ✅ Doesn't have excessive padding
|
|
- ✅ Uses proper text formatting
|
|
|
|
### 5. **Monitor Training**
|
|
|
|
Watch for:
|
|
- **Loss decreasing**: Should trend downward
|
|
- **Perplexity**: Should decrease (lower is better)
|
|
- **Generation quality**: Test periodically during training
|
|
|
|
### 6. **Training Command**
|
|
|
|
```bash
|
|
# Train with more data
|
|
python3 train.py \
|
|
--data data/your_training_data.txt \
|
|
--config config.json \
|
|
--output ./checkpoints \
|
|
--device cpu # or cuda/mps
|
|
```
|
|
|
|
### 7. **Check Training Progress**
|
|
|
|
During training, you should see:
|
|
```
|
|
Epoch 1: Train Loss = 8.5 → Epoch 10: Train Loss = 6.0 → Epoch 50: Train Loss = 3.5
|
|
```
|
|
|
|
If loss stops decreasing, the model has converged.
|
|
|
|
### 8. **Early Stopping**
|
|
|
|
Consider adding early stopping if validation loss plateaus:
|
|
- Stop if validation loss doesn't improve for 5 epochs
|
|
- Save the best model based on validation loss
|
|
|
|
### 9. **Test During Training**
|
|
|
|
After each epoch, test generation:
|
|
|
|
```bash
|
|
python3 inference.py \
|
|
--checkpoint checkpoints/checkpoint_epoch_X.pt \
|
|
--prompt "The future of" \
|
|
--optimized
|
|
```
|
|
|
|
Good training should show:
|
|
- ✅ Model generates coherent text
|
|
- ✅ Model continues beyond input prompt
|
|
- ✅ Model doesn't immediately generate padding tokens
|
|
|
|
## Quick Start Retraining
|
|
|
|
1. **Get more training data** (most important!)
|
|
2. **Update config.json** with more epochs
|
|
3. **Start training**:
|
|
```bash
|
|
python3 train.py --data data/your_data.txt --config config.json
|
|
```
|
|
4. **Monitor loss** - should decrease over time
|
|
5. **Test periodically** - check if generation improves
|
|
|
|
## Expected Results
|
|
|
|
After proper training:
|
|
- Loss should decrease from ~8-10 to ~2-4
|
|
- Perplexity should decrease from ~3000 to ~10-50
|
|
- Model should generate 50+ tokens before stopping
|
|
- Generated text should be coherent and diverse
|
|
|
|
## Next Steps
|
|
|
|
1. ✅ Early stopping is now fixed (prevents padding tokens)
|
|
2. ⏳ **Retrain with more data and epochs**
|
|
3. ⏳ Monitor training metrics
|
|
4. ⏳ Test generation quality during training
|
|
|
|
Good luck with retraining! 🚀
|
|
|