What is Training? Step-by-Step Explanation
Complete step-by-step explanation of model training: what it is, why we need data, why more data is better, and how the model learns.
Table of Contents
- What is Training?
- Why Do We Need Training?
- What Does the Model Learn?
- Why Do We Need Data?
- Why More Data is Better
- How Training Works: Step-by-Step
- The Training Process
- Loss Function
- Optimization
- Evaluation
- Common Questions
- Training Metrics and Artifacts
5.1 What is Training?
Simple Definition
Training is the process of teaching a neural network to make predictions by showing it examples and adjusting its parameters to minimize errors.
The Learning Analogy
Think of training like teaching a child:
Child Learning:
- You show examples: "This is a cat", "This is a dog"
- Child makes mistakes: Calls a cat "dog"
- You correct: "No, that's a cat"
- Child learns patterns from many examples
- Eventually, child recognizes cats and dogs correctly
Model Training:
- You show examples: "Hello" → next word "World"
- Model makes predictions: "Hello" → predicts "Hi"
- You compute error: Compare prediction to actual
- Model adjusts: Updates parameters to reduce error
- Process repeats: Shows many examples
- Eventually, model learns to predict correctly
What Happens During Training?
The model:
- Sees input data (examples)
- Makes predictions
- Compares predictions to correct answers
- Calculates how wrong it was (loss)
- Adjusts parameters to be less wrong
- Repeats millions of times
Result: A model that can make accurate predictions!
5.2 Why Do We Need Training?
The Problem
Untrained models are random:
Initial State:
Input: "Hello"
Model Prediction: Random guess
→ "apple" (30%)
→ "zebra" (25%)
→ "World" (5%)
→ Other random words...
The model doesn't know anything yet!
The Solution: Training
After Training:
Input: "Hello"
Model Prediction: Learned pattern
→ "World" (85%)
→ "there" (8%)
→ "friend" (3%)
→ Other reasonable words...
The model learned language patterns!
Why Training is Essential
Without Training:
- Random predictions
- No understanding of language
- No useful output
- Model is useless
With Training:
- Learned patterns
- Understanding of language
- Useful predictions
- Model is valuable
5.3 What Does the Model Learn?
1. Language Patterns
Learns:
- Word relationships ("Hello" often followed by "World")
- Grammar rules (subject-verb agreement)
- Sentence structure (nouns, verbs, adjectives)
- Context understanding (same word means different things)
Example:
"You" → "are" (learned: pronoun + verb agreement)
"The cat" → "sat" (learned: noun + verb)
"Machine learning" → "is" (learned: compound noun + verb)
2. Semantic Relationships
Learns:
- Similar words have similar meanings
- Related concepts cluster together
- Word embeddings capture meaning
- Context determines word usage
Example:
"cat" and "dog" → Similar embeddings (both animals)
"king" - "man" + "woman" ≈ "queen" (learned relationships)
3. Sequential Patterns
Learns:
- Predict next token based on context
- Long-range dependencies
- Common phrases and idioms
- Writing style and tone
Example:
"Once upon a time" → "there" (learned: story beginning)
"The quick brown fox" → "jumps" (learned: common phrase)
4. Statistical Patterns
Learns:
- How often words appear together
- Probability distributions over vocabulary
- Language statistics
- Common word sequences
Example:
"The" → very common (appears frequently)
"Antidisestablishmentarianism" → rare (appears rarely)
5.4 Why Do We Need Data?
The Fundamental Need
Models learn from examples, not from rules:
Rule-Based Approach (Old Way):
Programmer writes rules:
IF word == "Hello" THEN next_word = "World"
IF word == "The" THEN next_word = "cat"
...
Problems:
- Need to write millions of rules
- Can't capture all patterns
- Brittle and breaks easily
- Doesn't generalize
Data-Driven Approach (Modern Way):
Model learns from examples:
"Hello World" (example 1)
"The cat sat" (example 2)
...
Benefits:
- Automatically learns patterns
- Captures complex relationships
- Generalizes to new examples
- Handles ambiguity
Why Data is Essential
Data provides:
1. Examples to Learn From
Without data: Model has no examples
With data: Model sees millions of examples
2. Ground Truth
Without data: No correct answers
With data: Knows what correct predictions are
3. Patterns to Discover
Without data: Can't find patterns
With data: Discovers language patterns automatically
4. Evaluation
Without data: Can't measure performance
With data: Can test if model learned correctly
What Data Provides
Training Data:
- Input-output pairs
- Examples to learn from
- Patterns to discover
- Ground truth for comparison
Example:
Input: "Hello"
Output: "World"
Input: "Machine learning"
Output: "is"
Input: "The cat"
Output: "sat"
Each example teaches the model something!
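The input→output pairs above can be generated mechanically from raw text: every token becomes a training example whose target is the token that follows it. A minimal sketch (the function name `make_pairs` is illustrative, not part of the codebase):

```python
def make_pairs(tokens):
    """Build (input, target) pairs: each token predicts the next one."""
    return [(tokens[i], tokens[i + 1]) for i in range(len(tokens) - 1)]

print(make_pairs(["The", "cat", "sat"]))  # [('The', 'cat'), ('cat', 'sat')]
```

Real pipelines work on token IDs and longer context windows, but the principle is the same: the data itself supplies both the questions and the answers.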
5.5 Why More Data is Better
The Relationship Between Data and Performance
General Rule:
More Data → Better Performance
But why?
Reason 1: More Patterns
Little Data:
100 examples:
- See "Hello World" once
- See "Hello there" once
- Model uncertain: Which is more common?
More Data:
1,000,000 examples:
- See "Hello World" 500,000 times
- See "Hello there" 200,000 times
- See "Hello friend" 300,000 times
- Model confident: "Hello World" is most common
More examples = Better pattern recognition
Reason 2: Broader Coverage
Little Data:
Limited vocabulary:
- Only sees common words
- Misses rare words
- Poor generalization
More Data:
Comprehensive vocabulary:
- Sees common words frequently
- Sees rare words occasionally
- Good generalization
More examples = Better coverage
Reason 3: Better Generalization
Little Data:
Sees: "The cat sat on the mat"
Learns: Exact pattern
Test: "The dog sat on the rug"
Fails: Never saw "dog" or "rug"
More Data:
Sees: Many variations
- "The cat sat on the mat"
- "The dog sat on the rug"
- "The bird sat on the branch"
Learns: General pattern
Test: "The dog sat on the rug"
Succeeds: Understands pattern
More examples = Better generalization
Reason 4: Reduces Overfitting
Little Data:
Model memorizes examples:
- Perfect on training data
- Poor on new data
- Overfitting!
More Data:
Model learns patterns:
- Good on training data
- Good on new data
- Generalizes well!
More examples = Less overfitting
Reason 5: Statistical Confidence
Little Data:
10 examples of "Hello World"
→ Statistically uncertain
→ High variance in predictions
More Data:
1,000,000 examples of "Hello World"
→ Statistically confident
→ Low variance in predictions
More examples = More confident predictions
The Data-Performance Curve
```
Performance
100% │                      ●─────●─────
     │               ●─────
     │          ●
     │      ●
     │   ●
  0% └───────────────────────────────── Data
        0    1K   10K  100K   1M   10M  100M
```
Diminishing Returns:
- First 1M examples: Huge improvement
- Next 9M examples: Good improvement
- Next 90M examples: Smaller improvement
- Beyond: Very small improvements
But more data is almost always better!
Real-World Examples
GPT-3:
- Trained on ~300 billion tokens
- Requires massive datasets
- Better performance with more data
Our Model:
- Can train on any amount of data
- More data = better performance
- Scales with dataset size
5.6 How Training Works: Step-by-Step
High-Level Overview
1. Initialize model (random weights)
2. For each epoch:
a. For each batch:
- Forward pass (make predictions)
- Compute loss (measure error)
- Backward pass (compute gradients)
- Update weights (reduce error)
3. Evaluate model
4. Repeat until convergence
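The loop above can be sketched end-to-end on a toy one-parameter model, where the forward pass, loss, gradient, and update are each a single line. This is an illustrative sketch, not the project's actual training loop (which batches data and uses backpropagation through many layers):

```python
def train(pairs, w=0.0, lr=0.1, epochs=20):
    """Fit y = w * x to (x, y) pairs by gradient descent."""
    for _ in range(epochs):              # 2. for each epoch
        for x, y in pairs:               #    a. for each example (a "batch" of 1)
            pred = w * x                 # forward pass: make a prediction
            loss = (pred - y) ** 2       # compute loss: squared error
            grad = 2 * (pred - y) * x    # backward pass: dLoss/dw
            w -= lr * grad               # update weight to reduce the loss
    return w

# Learn w ≈ 2 from data generated by y = 2x
print(round(train([(1, 2), (2, 4), (3, 6)]), 3))  # ≈ 2.0
```

The same structure scales up: replace the scalar `w` with millions of tensor parameters and the hand-derived gradient with automatic differentiation.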
Detailed Step-by-Step
Step 1: Initialize Model
Start with random weights:
Embedding weights: Random values
Attention weights: Random values
FFN weights: Random values
→ Model makes random predictions
Example:
Weight initialization: [-0.1, 0.05, 0.2, ...] (random)
Initial prediction: Random token (meaningless)
Step 2: Forward Pass
Process input through model:
Input: "Hello"
↓
Embedding: [0.1, -0.2, 0.3, ...]
↓
Attention: [0.15, 0.08, 0.22, ...]
↓
FFN: [0.18, -0.12, 0.24, ...]
↓
Output: [logits for all tokens]
↓
Prediction: "apple" (highest logit)
Step 3: Compute Loss
Compare prediction to correct answer:
Expected: "World"
Predicted: "apple"
Loss: High (wrong prediction)
Loss Function (Cross-Entropy):
Loss = -log(P(predicted = correct))
Example:
Correct token: "World" (ID 87)
Predicted probability: 0.05 (5%)
Loss = -log(0.05) ≈ 2.996 (high loss)
Step 4: Backward Pass
Compute gradients:
Loss: 2.996
↓
Compute gradients: ∂Loss/∂weights
↓
Gradients: [0.5, -0.3, 0.8, ...]
↓
Shows direction to reduce loss
Meaning: Gradients tell us how to adjust weights to reduce error
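One way to see what a gradient is: nudge a weight slightly and watch how the loss changes. This finite-difference sketch is illustrative only; real training uses backpropagation, which computes the same quantity analytically and far more cheaply:

```python
def numerical_gradient(loss_fn, w, eps=1e-6):
    """Finite-difference approximation of dLoss/dw for a single weight."""
    return (loss_fn(w + eps) - loss_fn(w - eps)) / (2 * eps)

# Toy model: pred = w * 2.0, target = 1.0, squared-error loss
loss = lambda w: (w * 2.0 - 1.0) ** 2
g = numerical_gradient(loss, w=0.0)
print(round(g, 4))  # analytic answer: 2*(0*2 - 1)*2 = -4.0
```

The negative gradient says "increase w to reduce the loss," which is exactly the direction the optimizer moves.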
Step 5: Update Weights
Adjust weights using optimizer:
Current weights: [0.1, -0.2, 0.3, ...]
Gradients: [0.5, -0.3, 0.8, ...]
Learning rate: 0.0001
New weights = Old weights - Learning rate × Gradients
= [0.1, -0.2, 0.3, ...] - 0.0001 × [0.5, -0.3, 0.8, ...]
= [0.09995, -0.19997, 0.29992, ...]
Result: Weights slightly adjusted to reduce loss
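The arithmetic above can be checked directly. A minimal sketch of the same update rule, using the illustrative weight and gradient values from this section:

```python
lr = 0.0001
weights = [0.1, -0.2, 0.3]
grads = [0.5, -0.3, 0.8]

# New weights = Old weights - Learning rate × Gradients
new_weights = [w - lr * g for w, g in zip(weights, grads)]
print([round(w, 5) for w in new_weights])  # [0.09995, -0.19997, 0.29992]
```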
Step 6: Repeat
Process repeats for millions of examples:
Batch 1: Update weights slightly
Batch 2: Update weights slightly
Batch 3: Update weights slightly
...
Batch 1,000,000: Update weights slightly
Result: Cumulative improvements → Model learns!
5.7 The Training Process
Complete Training Loop
For each epoch:
Epoch 1:
Batch 1: [Input: "Hello", Output: "World"] → Loss: 2.996 → Update
Batch 2: [Input: "Machine", Output: "learning"] → Loss: 3.2 → Update
Batch 3: [Input: "The cat", Output: "sat"] → Loss: 3.1 → Update
...
Average Loss: 3.05
Epoch 2:
Batch 1: [Input: "Hello", Output: "World"] → Loss: 2.5 → Update
Batch 2: [Input: "Machine", Output: "learning"] → Loss: 2.8 → Update
Batch 3: [Input: "The cat", Output: "sat"] → Loss: 2.7 → Update
...
Average Loss: 2.65 (improved!)
Epoch 3:
...
Average Loss: 2.3 (improved!)
...
Epoch 10:
...
Average Loss: 1.2 (much better!)
Key Training Concepts
Epoch:
- One complete pass through the training data
- All examples seen once
Batch:
- Small group of examples processed together
- Enables efficient training
Iteration:
- Processing one batch
- One weight update
Loss:
- Measure of prediction error
- Lower is better
Learning Rate:
- How much to adjust weights
- Controls training speed
Training Metrics
Loss Over Time:
```
Loss
4.0 │ ●
3.0 │      ●
2.0 │           ●
1.0 │                ●
0.0 └───────────────────── Epochs
      0   2   4   6   8   10
```
Decreasing loss = Model learning!
5.8 Loss Function
What is Loss?
Loss measures how wrong the model is:
Low Loss:
Prediction: "World" (95% confidence)
Correct: "World"
Loss: 0.05 (very low, almost perfect!)
High Loss:
Prediction: "apple" (10% confidence)
Correct: "World"
Loss: 2.3 (high, very wrong!)
Cross-Entropy Loss
Formula:
L = -\frac{1}{N} \sum_{i=1}^{N} \log P(y_i | x_i)
Where:
- N = number of tokens
- y_i = the correct token at position i
- P(y_i | x_i) = the predicted probability of the correct token given its context
Example:
Input: "Hello"
Correct: "World"
Predicted probabilities:
"World": 0.05 (5%)
"there": 0.03 (3%)
"Hello": 0.02 (2%)
...
Loss:
L = -log(0.05) ≈ 2.996
Meaning: Model is uncertain, high loss
After Training:
"World": 0.85 (85%)
"there": 0.10 (10%)
"Hello": 0.03 (3%)
...
Loss = -log(0.85) ≈ 0.162
Meaning: Model is confident, low loss!
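The before/after losses above can be reproduced in a few lines. A minimal sketch using the illustrative probabilities from this section (`cross_entropy` is a hypothetical helper, not a project function):

```python
import math

def cross_entropy(probs, correct_token):
    """Cross-entropy loss for one prediction: -log P(correct token)."""
    return -math.log(probs[correct_token])

# Untrained: low probability on the correct token -> high loss
before = {"World": 0.05, "there": 0.03, "Hello": 0.02}
# After training: high probability on the correct token -> low loss
after = {"World": 0.85, "there": 0.10, "Hello": 0.03}

print(round(cross_entropy(before, "World"), 3))  # 2.996
print(round(cross_entropy(after, "World"), 3))   # 0.163
```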
Why Cross-Entropy?
Properties:
- Penalizes confident wrong predictions: High loss for wrong + confident
- Rewards confident correct predictions: Low loss for correct + confident
- Smooth gradient: Easy to optimize
- Probabilistic interpretation: Works with probabilities
5.9 Optimization
What is Optimization?
Optimization = Finding best weights
Goal:
Minimize Loss(weights)
How:
1. Compute gradients
2. Update weights in direction that reduces loss
3. Repeat until convergence
AdamW Optimizer
Our model uses AdamW:
Why AdamW?
- Adaptive learning rate per parameter
- Handles sparse gradients well
- Weight decay for regularization
- Works well for transformers
How it works:
Step 1: Compute Gradients
g_t = ∂Loss/∂weights
Step 2: Update Momentum
m_t = β₁ × m_{t-1} + (1 - β₁) × g_t
Step 3: Update Variance
v_t = β₂ × v_{t-1} + (1 - β₂) × g_t²
Step 4: Update Weights
weights_t = weights_{t-1} - lr × (m_t / (√v_t + ε) + λ × weights_{t-1})
(bias correction of m_t and v_t omitted here for clarity; note the weight-decay term is also scaled by the learning rate)
Where:
- β₁ = 0.9 (momentum decay)
- β₂ = 0.999 (variance decay)
- lr = learning rate
- λ = weight decay
- ε = small constant
Result: Efficient weight updates!
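Putting the four steps together for a single scalar weight. This is a toy sketch with typical AdamW defaults (lr, β₁, β₂, ε, λ), not the project's optimizer, which updates whole tensors at once:

```python
import math

def adamw_step(w, g, m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8, wd=0.01):
    """One AdamW update for a single scalar weight, with bias correction."""
    m = b1 * m + (1 - b1) * g            # Step 2: momentum (1st moment)
    v = b2 * v + (1 - b2) * g * g        # Step 3: variance (2nd moment)
    m_hat = m / (1 - b1 ** t)            # bias correction
    v_hat = v / (1 - b2 ** t)
    # Step 4: update with decoupled weight decay
    w = w - lr * (m_hat / (math.sqrt(v_hat) + eps) + wd * w)
    return w, m, v

w, m, v = 0.5, 0.0, 0.0
w, m, v = adamw_step(w, g=0.2, m=m, v=v, t=1)
print(round(w, 6))  # 0.498995
```

The key AdamW detail is that the weight-decay term `wd * w` is applied directly to the weight rather than folded into the gradient, which is what "decoupled" weight decay means.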
Learning Rate Scheduling
Cosine Annealing:
Start: High learning rate (fast learning)
Middle: Decreasing learning rate
End: Low learning rate (fine-tuning)
Visualization:
```
Learning Rate
0.001 │●─────
      │      \
      │       \
      │        \
      │         \
0.000 │          ●─────
      └──────────────────────── Training steps
```
Benefits:
- Fast initial learning
- Stable convergence
- Better final performance
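The cosine schedule above is one line of math. A sketch, assuming a maximum learning rate of 1e-4 as elsewhere in this document (warmup omitted for simplicity):

```python
import math

def cosine_lr(step, total_steps, lr_max=1e-4, lr_min=0.0):
    """Cosine annealing: lr_max at step 0, decaying smoothly to lr_min."""
    progress = step / total_steps
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * progress))

total = 1000
print(cosine_lr(0, total))     # 1e-4: start (fast learning)
print(cosine_lr(500, total))   # 5e-5: halfway
print(cosine_lr(1000, total))  # ~0:   end (fine-tuning)
```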
5.10 Evaluation
Why Evaluate?
Check if model learned:
Training Loss: 0.5 (low)
→ Model learned training data well
But is it good on new data?
Evaluation Metrics
1. Loss
Lower is better
Measures prediction uncertainty
2. Accuracy
Percentage of correct predictions
Higher is better
3. Perplexity
Perplexity = exp(loss)
Lower is better
Measures "surprise" of model
Example:
Loss: 2.0
Perplexity: exp(2.0) ≈ 7.39
Meaning: Model is "surprised" by about 7.39 choices on average
Lower perplexity = Better predictions
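The perplexity arithmetic above in code (a trivial sketch):

```python
import math

def perplexity(loss):
    """Perplexity = exp(cross-entropy loss): the model's effective branching factor."""
    return math.exp(loss)

print(round(perplexity(2.0), 2))    # 7.39 -> "surprised" by ~7 choices on average
print(round(perplexity(0.162), 2))  # 1.18 -> close to 1, i.e. very confident
```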
Validation Set
Separate data for evaluation:
Training Set: 80% (learn from this)
Validation Set: 20% (test on this)
Train on training set
Evaluate on validation set
→ See if model generalizes!
Why Separate?
- Test on unseen data
- Detect overfitting
- Measure real performance
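A simple 80/20 split can be done by holding out the tail of the dataset. A sketch (`split_data` is a hypothetical helper; real pipelines usually shuffle before splitting so the validation set is representative):

```python
def split_data(examples, val_fraction=0.2):
    """Split examples into a training set and a held-out validation set."""
    n_train = int(len(examples) * (1 - val_fraction))
    return examples[:n_train], examples[n_train:]

train_set, val_set = split_data(list(range(10)))
print(len(train_set), len(val_set))  # 8 2
```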
5.11 Common Questions
Q1: How long does training take?
Answer: Depends on:
- Dataset size
- Model size
- Hardware (GPU/CPU)
- Number of epochs
Example:
Small model (1M params), 1M tokens:
- CPU: Days
- GPU: Hours
Large model (100M params), 100M tokens:
- CPU: Weeks
- GPU: Days
Q2: When should training stop?
Answer:
- When validation loss stops improving
- After fixed number of epochs
- When loss converges
- When overfitting detected
Early Stopping:
If validation loss doesn't improve for N epochs:
→ Stop training
→ Prevent overfitting
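The early-stopping rule above can be sketched as a small check on the validation-loss history (`should_stop` is a hypothetical helper, not part of the training code):

```python
def should_stop(val_losses, patience=3):
    """Stop if the best validation loss is more than `patience` epochs old."""
    best_epoch = val_losses.index(min(val_losses))
    return len(val_losses) - 1 - best_epoch >= patience

# Best loss (2.1) was 3 epochs ago -> stop
print(should_stop([3.0, 2.5, 2.2, 2.1, 2.1, 2.2, 2.3]))  # True
# Still improving -> keep training
print(should_stop([3.0, 2.5, 2.2]))                      # False
```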
Q3: Why does loss sometimes increase?
Answer: Normal! Can happen due to:
- Learning rate too high
- Difficult batch
- Optimization noise
- Normal fluctuations
Long-term trend should decrease:
Loss: 3.0 → 2.8 → 2.9 → 2.7 → 2.8 → 2.6
Small increases are OK as long as the overall trend is decreasing.
Q4: Can I train on different types of data?
Answer: Yes! Model learns from whatever data you provide:
Books → Learns literary style
Code → Learns programming patterns
Scientific papers → Learns technical language
Mixed → Learns diverse patterns
More diverse data = More versatile model
Q5: What if I don't have much data?
Answer:
- Can still train with small datasets
- May need more epochs
- May need smaller model
- Consider data augmentation
However:
- More data almost always better
- Try to collect more if possible
Q6: How do I know if training is working?
Answer: Check:
- Loss decreasing over time ✓
- Validation loss improving ✓
- Predictions getting better ✓
- Model generating reasonable text ✓
Signs of problems:
- Loss not decreasing → Check learning rate
- Loss increasing → Check data or model
- Predictions random → Check training
Q7: What's the difference between training and inference?
Answer:
Training:
- Model learns from data
- Updates weights
- Computes gradients
- Optimizes parameters
Inference:
- Model makes predictions
- Fixed weights (no updates)
- No gradients computed
- Just forward pass
Analogy:
- Training: Student studying (learning)
- Inference: Student taking exam (using knowledge)
5.12 Training Metrics and Artifacts
When you run training locally, the system automatically generates several files to help you monitor and understand the training process. These files are saved in your checkpoint directory (default: ./checkpoints or ./checkpoints_test).
Generated Files
After training completes (or during training), you'll find these files:
- training_metrics.json - Complete training history in JSON format
- training_curve.png - Visual plots of loss and learning rate over time
- loss_by_epoch.png - Average loss per epoch visualization
training_metrics.json
Location: checkpoints/training_metrics.json (or your configured save directory)
Contents:
This JSON file contains the complete training history with the following fields:
```json
{
  "train_loss": [4.19, 3.70, 3.29, ...],   // Training loss at each logging step
  "val_loss": [null, null, null, ...],     // Validation loss (null if not evaluated)
  "learning_rate": [0.0001, 0.0001, ...],  // Learning rate at each step
  "epochs": [0, 0, 0, ...],                // Epoch number for each step
  "steps": [5, 10, 15, ...]                // Global training step number
}
```
What Each Field Means:
- train_loss: Array of training loss values. Lower is better. Shows how well the model fits the training data.
- val_loss: Array of validation loss values (or null if validation wasn't run). Lower is better. Shows generalization to unseen data.
- learning_rate: Array of learning rate values. Shows how the learning rate scheduler adjusted the learning rate over time.
- epochs: Array indicating which epoch each metric was recorded in.
- steps: Array of global step numbers. Each step represents one batch processed.
How to Use:
```python
import json

# Load metrics
with open('checkpoints/training_metrics.json', 'r') as f:
    metrics = json.load(f)

# Get final training loss
final_loss = metrics['train_loss'][-1]
print(f"Final training loss: {final_loss:.4f}")

# Find minimum validation loss
val_losses = [v for v in metrics['val_loss'] if v is not None]
if val_losses:
    min_val_loss = min(val_losses)
    print(f"Best validation loss: {min_val_loss:.4f}")

# Calculate average loss per epoch
epoch_0_losses = [metrics['train_loss'][i]
                  for i, e in enumerate(metrics['epochs']) if e == 0]
avg_epoch_0_loss = sum(epoch_0_losses) / len(epoch_0_losses)
print(f"Average loss for epoch 0: {avg_epoch_0_loss:.4f}")
```
training_curve.png
Location: checkpoints/training_curve.png
What It Shows:
This plot contains two subplots:
1. Top Plot: Training and Validation Loss
   - X-axis: Training steps
   - Y-axis: Loss value
   - Blue line: Training loss over time
   - Red line: Validation loss (if available)
   - Shows how loss decreases during training
2. Bottom Plot: Learning Rate Schedule
   - X-axis: Training steps
   - Y-axis: Learning rate (log scale)
   - Green line: Learning rate over time
   - Shows how the learning rate scheduler adjusted the learning rate
How to Interpret:
Good Training:
Example training curve showing smooth loss decrease and learning rate schedule. Your actual plot will be saved in your checkpoint directory.
Signs of Problems:
- Loss not decreasing: Learning rate too low, or model too small
- Loss increasing: Learning rate too high, or data issues
- Loss oscillating wildly: Learning rate too high
- Training loss much lower than validation loss: Overfitting
Example from Your Training:
Based on your training_metrics.json, your training shows:
- Initial loss: ~4.19 (high, model is random)
- Final loss: ~0.92 (much lower, model learned!)
- Smooth decrease: Training progressed well
- Learning rate decayed from ~0.0001 to near zero: Proper cosine annealing schedule
loss_by_epoch.png
Location: checkpoints/loss_by_epoch.png
What It Shows:
- X-axis: Epoch number
- Y-axis: Average loss for that epoch
- Single data point per epoch
- Shows overall training progress at epoch level
How to Interpret:
This plot gives you a high-level view of training progress:
Good Training:
Example loss by epoch plot showing steady decrease. Your actual plot will be saved in your checkpoint directory.
What to Look For:
- Decreasing trend: Model is learning ✓
- Plateau: Model may have converged
- Increasing: Possible overfitting or learning rate issues
Interpreting Your Training Results
Based on the metrics from your local training run:
Training Progress:
- Started at loss ~4.19 (random initialization)
- Ended at loss ~0.92 (significant improvement!)
- Total steps: ~5,625 steps
- Loss decreased smoothly throughout training
Learning Rate Schedule:
- Started at ~0.0001 (1e-4)
- Followed cosine annealing schedule
- Decayed smoothly to near zero
- Proper warmup and decay phases
What This Means:
- ✅ Training was successful - loss decreased significantly
- ✅ Learning rate schedule worked correctly
- ✅ Model learned patterns from the training data
- ✅ No signs of overfitting (smooth decrease, no sudden spikes)
Using Metrics for Debugging
Problem: Loss Not Decreasing
```python
import json

# Check the learning rate schedule
with open('checkpoints/training_metrics.json') as f:
    metrics = json.load(f)
initial_lr = metrics['learning_rate'][0]
final_lr = metrics['learning_rate'][-1]
print(f"LR: {initial_lr} -> {final_lr}")
# If the LR is too low, increase it in the config
# If the LR is too high, decrease it in the config
```
Problem: Overfitting
```python
# Compare train vs validation loss (using `metrics` loaded as above)
train_losses = metrics['train_loss']
val_losses = [v for v in metrics['val_loss'] if v is not None]
if val_losses:
    final_train = train_losses[-1]
    final_val = val_losses[-1]
    gap = final_val - final_train
    if gap > 0.5:
        print("Warning: Large gap suggests overfitting")
        print("Consider: More data, regularization, or early stopping")
```
Problem: Training Too Slow
```python
# Check the loss decrease rate (using `metrics` loaded as above)
losses = metrics['train_loss']
initial_loss = losses[0]
final_loss = losses[-1]
steps = len(losses)
decrease_rate = (initial_loss - final_loss) / steps
print(f"Loss decrease per step: {decrease_rate:.6f}")
# If too slow, consider:
# - Increasing the learning rate
# - Increasing the batch size
# - Checking data quality
```
Best Practices
- Monitor During Training: Check training_metrics.json periodically to catch issues early
- Save Checkpoints: The metrics file is updated continuously, so you can monitor progress even if training is interrupted
- Compare Runs: Save metrics from different training runs to compare hyperparameters
- Visual Inspection: Always look at the plots - they reveal patterns that numbers alone don't show
- Early Stopping: Use the validation loss from the metrics to implement early stopping if needed
Example: Analyzing Your Training Run
```python
import json

# Load your training metrics
with open('checkpoints_test/training_metrics.json', 'r') as f:
    metrics = json.load(f)

# Quick analysis
print("=== Training Summary ===")
print(f"Total steps: {len(metrics['steps'])}")
print(f"Initial loss: {metrics['train_loss'][0]:.4f}")
print(f"Final loss: {metrics['train_loss'][-1]:.4f}")
print(f"Loss reduction: {metrics['train_loss'][0] - metrics['train_loss'][-1]:.4f}")
print(f"Reduction percentage: {(1 - metrics['train_loss'][-1] / metrics['train_loss'][0]) * 100:.1f}%")

# Check the learning rate schedule
lr_values = [lr for lr in metrics['learning_rate'] if lr is not None]
if lr_values:
    print("\nLearning Rate:")
    print(f"  Initial: {lr_values[0]:.6f}")
    print(f"  Final: {lr_values[-1]:.6f}")
    print(f"  Decay factor: {lr_values[-1] / lr_values[0]:.6f}")

# Find the best checkpoint (lowest loss)
best_step_idx = metrics['train_loss'].index(min(metrics['train_loss']))
best_step = metrics['steps'][best_step_idx]
best_loss = metrics['train_loss'][best_step_idx]
print("\nBest checkpoint:")
print(f"  Step: {best_step}")
print(f"  Loss: {best_loss:.4f}")
```
Summary
What is Training?
Training is teaching the model to make accurate predictions by:
- Showing examples
- Computing errors
- Adjusting parameters
- Repeating millions of times
Why We Need Data
Data provides:
- Examples to learn from
- Patterns to discover
- Ground truth to compare
- Evaluation to measure progress
Why More Data is Better
More data enables:
- Better pattern recognition
- Broader coverage
- Better generalization
- Reduced overfitting
- Statistical confidence
Training Process
1. Initialize model (random weights)
2. Forward pass (make predictions)
3. Compute loss (measure error)
4. Backward pass (compute gradients)
5. Update weights (reduce error)
6. Repeat for many epochs
Key Takeaways
✅ Training teaches models to make predictions
✅ Models learn from data, not rules
✅ More data = Better performance
✅ Loss measures prediction error
✅ Optimization updates weights to reduce loss
✅ Evaluation checks if model learned correctly
This document provides a comprehensive explanation of model training, why we need data, and why more data leads to better performance in transformer models.

