Initial commit: SheepOp LLM - Transformer-based language model implementation
- Complete transformer implementation from scratch - Training pipeline with gradient accumulation and mixed precision - Optimized inference with KV caching - Multi-format data processing (PDFs, images, code, text) - Comprehensive documentation - Apache 2.0 license - Example training plots included in docs/images/
This commit is contained in:
848
docs/OPTIMIZATION_EXPLAINED.md
Normal file
848
docs/OPTIMIZATION_EXPLAINED.md
Normal file
@@ -0,0 +1,848 @@
|
||||
# What is Optimization? Step-by-Step Explanation
|
||||
|
||||
Complete step-by-step explanation of optimization in neural networks: how optimizers update weights to minimize loss.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
1. [What is Optimization?](#71-what-is-optimization)
|
||||
2. [The Optimization Problem](#72-the-optimization-problem)
|
||||
3. [Gradient Descent](#73-gradient-descent)
|
||||
4. [AdamW Optimizer](#74-adamw-optimizer)
|
||||
5. [Why Optimization Matters](#75-why-optimization-matters)
|
||||
6. [Complete Mathematical Formulation](#76-complete-mathematical-formulation)
|
||||
7. [Exercise: Optimizer Step-by-Step](#77-exercise-optimizer-step-by-step)
|
||||
8. [Key Takeaways](#78-key-takeaways)
|
||||
|
||||
---
|
||||
|
||||
## 7.1 What is Optimization?
|
||||
|
||||
### Simple Definition
|
||||
|
||||
**Optimization** is the process of finding the best set of weights (parameters) that minimize the loss function and make the model's predictions as accurate as possible.
|
||||
|
||||
### Visual Analogy
|
||||
|
||||
**Think of optimization like finding the lowest point in a valley:**
|
||||
|
||||
```
|
||||
Loss Landscape:
|
||||
|
||||
High Loss
|
||||
│
|
||||
│ ● (current position)
|
||||
│ ╱│╲
|
||||
│ ╱ │ ╲
|
||||
│ ╱ │ ╲
|
||||
│╱ │ ╲
|
||||
│ ▼ │
|
||||
│ (goal)│
|
||||
│ │
|
||||
Low Loss ─────┘
|
||||
```
|
||||
|
||||
**Goal:** Find the bottom of the valley (minimum loss)
|
||||
|
||||
**Optimizer:** Your guide down the mountain
|
||||
|
||||
### What Optimization Does
|
||||
|
||||
**Optimization:**
|
||||
1. **Measures** how wrong the model is (loss)
|
||||
2. **Calculates** direction to improve (gradients)
|
||||
3. **Updates** weights to reduce loss
|
||||
4. **Repeats** until convergence
|
||||
|
||||
**Result:** Model learns to make accurate predictions!
|
||||
|
||||
### Optimization Process Flow
|
||||
|
||||
```mermaid
|
||||
graph TB
|
||||
Start[Training Start] --> Init[Initialize Weights<br/>Random Values]
|
||||
Init --> Loop[Training Loop]
|
||||
|
||||
Loop --> Forward[Forward Pass<br/>Model Prediction]
|
||||
Forward --> Loss["Compute Loss<br/>L = loss(pred, target)"]
|
||||
Loss --> Check{Converged?}
|
||||
|
||||
Check -->|Yes| End[Training Complete]
|
||||
Check -->|No| Gradient["Compute Gradients<br/>∇L = ∂L/∂θ"]
|
||||
|
||||
Gradient --> Optimize[Optimizer<br/>Update Weights]
|
||||
Optimize --> Update["New Weights<br/>θ = θ - update"]
|
||||
Update --> Loop
|
||||
|
||||
style Start fill:#e1f5ff
|
||||
style End fill:#e1ffe1
|
||||
style Optimize fill:#fff4e1
|
||||
style Check fill:#ffe1f5
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 7.2 The Optimization Problem
|
||||
|
||||
### The Objective
|
||||
|
||||
**We want to minimize:**
|
||||
|
||||
```math
|
||||
L(\theta) = \frac{1}{N} \sum_{i=1}^{N} \ell(y_i, f(x_i; \theta))
|
||||
```
|
||||
|
||||
**Where:**
|
||||
- $\theta$ = all model parameters (weights)
|
||||
- $L$ = total loss
|
||||
- $\ell$ = loss function (e.g., cross-entropy)
|
||||
- $y_i$ = correct answer
|
||||
- $f(x_i; \theta)$ = model prediction
|
||||
- $N$ = number of examples
|
||||
|
||||
### The Challenge
|
||||
|
||||
**Problem:** Loss function is complex and high-dimensional
|
||||
|
||||
**Solution:** Use iterative optimization algorithms
|
||||
|
||||
**Process:**
|
||||
```
|
||||
Initialize weights randomly
|
||||
Repeat:
|
||||
1. Compute loss
|
||||
2. Compute gradients
|
||||
3. Update weights
|
||||
Until convergence
|
||||
```
|
||||
|
||||
### Optimization Problem Flowchart
|
||||
|
||||
```mermaid
|
||||
graph LR
|
||||
subgraph "Optimization Problem"
|
||||
A["Loss Function<br/>L(θ)"] --> B["Find Minimum<br/>min L(θ)"]
|
||||
B --> C["Optimal Weights<br/>θ*"]
|
||||
end
|
||||
|
||||
subgraph "Solution Approach"
|
||||
D["Initialize θ"] --> E[Iterative Updates]
|
||||
E --> F[Compute Loss]
|
||||
F --> G[Compute Gradient]
|
||||
G --> H[Update Weights]
|
||||
H --> I{Converged?}
|
||||
I -->|No| E
|
||||
I -->|Yes| C
|
||||
end
|
||||
|
||||
A -.-> F
|
||||
C -.-> C
|
||||
|
||||
style A fill:#ffcccc
|
||||
style B fill:#ffffcc
|
||||
style C fill:#ccffcc
|
||||
style E fill:#cce5ff
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 7.3 Gradient Descent
|
||||
|
||||
### What is Gradient Descent?
|
||||
|
||||
**Gradient Descent** is a basic optimization algorithm that updates weights by moving in the direction of steepest descent.
|
||||
|
||||
### How It Works
|
||||
|
||||
**Step 1: Compute Gradient**
|
||||
|
||||
```math
|
||||
\nabla_\theta L = \frac{\partial L}{\partial \theta}
|
||||
```
|
||||
|
||||
**Gradient tells us:**
|
||||
- Direction: Which way to go
|
||||
- Magnitude: How steep the slope
|
||||
|
||||
**Step 2: Update Weights**
|
||||
|
||||
```math
|
||||
\theta_{t+1} = \theta_t - \eta \nabla_\theta L
|
||||
```
|
||||
|
||||
**Where:**
|
||||
- $\theta_t$ = current weights
|
||||
- $\eta$ = learning rate (step size)
|
||||
- $\nabla_\theta L$ = gradient
|
||||
|
||||
**Meaning:** Move weights in direction opposite to gradient
|
||||
|
||||
### Visual Example
|
||||
|
||||
```
|
||||
Loss Landscape (2D):
|
||||
|
||||
Gradient
|
||||
Direction
|
||||
↓
|
||||
● ──────┼───── → Lower Loss
|
||||
│
|
||||
│
|
||||
```
|
||||
|
||||
**Move in direction of negative gradient!**
|
||||
|
||||
### Gradient Descent Flowchart
|
||||
|
||||
```mermaid
|
||||
graph TB
|
||||
subgraph "Gradient Descent Algorithm"
|
||||
Start["Start: Initialize θ₀"] --> Loop["For each iteration t"]
|
||||
|
||||
Loop --> Forward[Forward Pass<br/>Compute Predictions]
|
||||
Forward --> Loss["Compute Loss<br/>L(θₜ)"]
|
||||
Loss --> Grad["Compute Gradient<br/>g = ∇L(θₜ)"]
|
||||
|
||||
Grad --> Direction["Determine Direction<br/>-g points to minimum"]
|
||||
Direction --> Step["Take Step<br/>η × g"]
|
||||
Step --> Update["Update Weights<br/>θₜ₊₁ = θₜ - ηg"]
|
||||
|
||||
Update --> Check{"Converged?<br/>|g| < ε"}
|
||||
Check -->|No| Loop
|
||||
Check -->|Yes| End["Found Minimum<br/>θ*"]
|
||||
end
|
||||
|
||||
subgraph "Gradient Information"
|
||||
GradInfo["Gradient g contains:<br/>- Direction: Which way to go<br/>- Magnitude: How steep"]
|
||||
end
|
||||
|
||||
Grad -.-> GradInfo
|
||||
|
||||
style Start fill:#e1f5ff
|
||||
style End fill:#e1ffe1
|
||||
style Grad fill:#fff4e1
|
||||
style Check fill:#ffe1f5
|
||||
style Update fill:#ccffcc
|
||||
```
|
||||
|
||||
### Types of Gradient Descent
|
||||
|
||||
**1. Batch Gradient Descent:**
|
||||
- Uses all training examples
|
||||
- Most accurate gradients
|
||||
- Slow for large datasets
|
||||
|
||||
**2. Stochastic Gradient Descent (SGD):**
|
||||
- Uses one example at a time
|
||||
- Fast but noisy
|
||||
- Can bounce around
|
||||
|
||||
**3. Mini-Batch Gradient Descent:**
|
||||
- Uses small batch of examples
|
||||
- Balance of speed and accuracy
|
||||
- Most commonly used
|
||||
|
||||
### Gradient Descent Types Comparison
|
||||
|
||||
```mermaid
|
||||
graph TB
|
||||
subgraph "Batch Gradient Descent"
|
||||
B1[All Training Data] --> B2[Compute Gradient<br/>on Full Dataset]
|
||||
B2 --> B3[Single Update<br/>Most Accurate]
|
||||
B3 --> B4["Slow: O(N)"]
|
||||
end
|
||||
|
||||
subgraph "Stochastic Gradient Descent"
|
||||
S1[Single Example] --> S2[Compute Gradient<br/>on One Sample]
|
||||
S2 --> S3[Many Updates<br/>Fast but Noisy]
|
||||
S3 --> S4["Fast: O(1)"]
|
||||
end
|
||||
|
||||
subgraph "Mini-Batch Gradient Descent"
|
||||
M1[Small Batch<br/>32-256 samples] --> M2[Compute Gradient<br/>on Batch]
|
||||
M2 --> M3[Balanced Updates<br/>Good Accuracy]
|
||||
M3 --> M4["Fast: O(batch_size)"]
|
||||
end
|
||||
|
||||
style B3 fill:#ccffcc
|
||||
style S3 fill:#ffcccc
|
||||
style M3 fill:#fff4e1
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 7.4 AdamW Optimizer
|
||||
|
||||
### What is AdamW?
|
||||
|
||||
**AdamW** (Adam with Weight Decay) is an advanced optimizer that combines:
|
||||
- **Adaptive learning rates** (like Adam)
|
||||
- **Weight decay** (regularization)
|
||||
|
||||
**Why AdamW?**
|
||||
- Per-parameter learning rates
|
||||
- Handles sparse gradients well
|
||||
- Works great for transformers
|
||||
|
||||
### How AdamW Works
|
||||
|
||||
**Step 1: Compute Gradient**
|
||||
|
||||
```math
|
||||
g_t = \nabla_\theta L(\theta_t)
|
||||
```
|
||||
|
||||
**Step 2: Update Momentum**
|
||||
|
||||
```math
|
||||
m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t
|
||||
```
|
||||
|
||||
**Where:**
|
||||
- $\beta_1 = 0.9$ (momentum decay)
|
||||
- $m_t$ = first moment estimate
|
||||
|
||||
**Meaning:** Moving average of gradients
|
||||
|
||||
**Step 3: Update Variance**
|
||||
|
||||
```math
|
||||
v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2
|
||||
```
|
||||
|
||||
**Where:**
|
||||
- $\beta_2 = 0.999$ (variance decay)
|
||||
- $v_t$ = second moment estimate
|
||||
|
||||
**Meaning:** Moving average of squared gradients
|
||||
|
||||
**Step 4: Bias Correction**
|
||||
|
||||
```math
|
||||
\hat{m}_t = \frac{m_t}{1 - \beta_1^t}
|
||||
```
|
||||
|
||||
```math
|
||||
\hat{v}_t = \frac{v_t}{1 - \beta_2^t}
|
||||
```
|
||||
|
||||
**Why?** Corrects bias in early iterations
|
||||
|
||||
**Step 5: Update Weights**
|
||||
|
||||
```math
|
||||
\theta_{t+1} = \theta_t - \eta \left( \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda \theta_t \right)
|
||||
```
|
||||
|
||||
**Where:**
|
||||
- $\eta$ = learning rate
|
||||
- $\epsilon = 10^{-8}$ (small constant)
|
||||
- $\lambda$ = weight decay coefficient
|
||||
|
||||
**Key Points:**
|
||||
- $\frac{\hat{m}_t}{\sqrt{\hat{v}_t}}$ = adaptive learning rate per parameter
|
||||
- $\lambda \theta_t$ = weight decay (regularization)
|
||||
|
||||
### AdamW Optimizer Flowchart
|
||||
|
||||
```mermaid
|
||||
graph TB
|
||||
subgraph "AdamW Optimization Process"
|
||||
Start["Start: Initialize<br/>θ₀, m₀=0, v₀=0"] --> Loop["For each iteration t"]
|
||||
|
||||
Loop --> Forward["Forward Pass<br/>Compute Loss L(θₜ)"]
|
||||
Forward --> Grad["Step 1: Compute Gradient<br/>gₜ = ∇L(θₜ)"]
|
||||
|
||||
Grad --> Mom["Step 2: Update Momentum<br/>mₜ = β₁mₜ₋₁ + (1-β₁)gₜ"]
|
||||
Mom --> Var["Step 3: Update Variance<br/>vₜ = β₂vₜ₋₁ + (1-β₂)gₜ²"]
|
||||
|
||||
Var --> Bias["Step 4: Bias Correction<br/>m̂ₜ = mₜ/(1-β₁ᵗ)<br/>v̂ₜ = vₜ/(1-β₂ᵗ)"]
|
||||
|
||||
Bias --> Adapt["Step 5: Adaptive LR<br/>LR = η/(√v̂ₜ + ε)"]
|
||||
|
||||
Adapt --> Decay["Step 6: Weight Decay<br/>λθₜ"]
|
||||
|
||||
Decay --> Update["Step 7: Update Weights<br/>θₜ₊₁ = θₜ - LR×m̂ₜ - λθₜ"]
|
||||
|
||||
Update --> Check{Converged?}
|
||||
Check -->|No| Loop
|
||||
Check -->|Yes| End["Optimal Weights θ*"]
|
||||
end
|
||||
|
||||
subgraph "Key Components"
|
||||
C1["Momentum mₜ<br/>Moving avg of gradients"]
|
||||
C2["Variance vₜ<br/>Moving avg of g²"]
|
||||
C3["Adaptive LR<br/>Per-parameter learning rate"]
|
||||
C4["Weight Decay<br/>Regularization"]
|
||||
end
|
||||
|
||||
Mom -.-> C1
|
||||
Var -.-> C2
|
||||
Adapt -.-> C3
|
||||
Decay -.-> C4
|
||||
|
||||
style Start fill:#e1f5ff
|
||||
style End fill:#e1ffe1
|
||||
style Grad fill:#fff4e1
|
||||
style Adapt fill:#ccffcc
|
||||
style Update fill:#ccffcc
|
||||
style Check fill:#ffe1f5
|
||||
```
|
||||
|
||||
### AdamW Detailed Subgraph
|
||||
|
||||
```mermaid
|
||||
graph LR
|
||||
subgraph "Input"
|
||||
I1["Gradient gₜ"]
|
||||
I2["Previous Momentum mₜ₋₁"]
|
||||
I3["Previous Variance vₜ₋₁"]
|
||||
I4["Current Weights θₜ"]
|
||||
end
|
||||
|
||||
subgraph "Momentum Update"
|
||||
M1["Multiply: β₁mₜ₋₁"] --> M2["Combine: β₁mₜ₋₁ + (1-β₁)gₜ"]
|
||||
I2 --> M1
|
||||
I1 --> M2
|
||||
end
|
||||
|
||||
subgraph "Variance Update"
|
||||
V1["Square: gₜ²"] --> V2["Combine: β₂vₜ₋₁ + (1-β₂)gₜ²"]
|
||||
I3 --> V2
|
||||
I1 --> V1
|
||||
end
|
||||
|
||||
subgraph "Bias Correction"
|
||||
M2 --> BC1["m̂ₜ = mₜ/(1-β₁ᵗ)"]
|
||||
V2 --> BC2["v̂ₜ = vₜ/(1-β₂ᵗ)"]
|
||||
end
|
||||
|
||||
subgraph "Adaptive Learning Rate"
|
||||
BC2 --> ALR["LR = η/(√v̂ₜ + ε)"]
|
||||
end
|
||||
|
||||
subgraph "Weight Update"
|
||||
BC1 --> WU1["Adaptive Step: LR × m̂ₜ"]
|
||||
ALR --> WU1
|
||||
I4 --> WU2["Decay Step: λθₜ"]
|
||||
WU1 --> WU3["Update: θₜ₊₁ = θₜ - LR×m̂ₜ - λθₜ"]
|
||||
WU2 --> WU3
|
||||
end
|
||||
|
||||
style M2 fill:#e1f5ff
|
||||
style V2 fill:#e1f5ff
|
||||
style BC1 fill:#fff4e1
|
||||
style BC2 fill:#fff4e1
|
||||
style ALR fill:#ccffcc
|
||||
style WU3 fill:#ccffcc
|
||||
```
|
||||
|
||||
### Why AdamW is Better
|
||||
|
||||
**Compared to SGD:**
|
||||
|
||||
**SGD:**
|
||||
```
|
||||
Same learning rate for all parameters
|
||||
→ Slow convergence
|
||||
→ Manual tuning needed
|
||||
```
|
||||
|
||||
**AdamW:**
|
||||
```
|
||||
Adaptive learning rate per parameter
|
||||
→ Faster convergence
|
||||
→ Less manual tuning
|
||||
```
|
||||
|
||||
**Benefits:**
|
||||
1. **Adaptive:** Each parameter gets its own learning rate
|
||||
2. **Robust:** Works well with noisy gradients
|
||||
3. **Efficient:** Converges faster than SGD
|
||||
4. **Regularized:** Weight decay prevents overfitting
|
||||
|
||||
### SGD vs AdamW Comparison
|
||||
|
||||
```mermaid
|
||||
graph TB
|
||||
subgraph "Stochastic Gradient Descent"
|
||||
SGD1["Gradient gₜ"] --> SGD2["Fixed Learning Rate η"]
|
||||
SGD2 --> SGD3["Update: θₜ₊₁ = θₜ - ηgₜ"]
|
||||
SGD3 --> SGD4["All params same LR"]
|
||||
SGD4 --> SGD5["Slow Convergence<br/>Manual Tuning"]
|
||||
end
|
||||
|
||||
subgraph "AdamW Optimizer"
|
||||
AD1["Gradient gₜ"] --> AD2["Momentum mₜ"]
|
||||
AD1 --> AD3["Variance vₜ"]
|
||||
AD2 --> AD4[Bias Correction]
|
||||
AD3 --> AD4
|
||||
AD4 --> AD5["Adaptive LR per param"]
|
||||
AD5 --> AD6["Update: θₜ₊₁ = θₜ - LR×m̂ₜ - λθₜ"]
|
||||
AD6 --> AD7["Fast Convergence<br/>Less Tuning"]
|
||||
end
|
||||
|
||||
subgraph "Comparison"
|
||||
Comp1["Same Model<br/>Same Data"]
|
||||
Comp1 --> Comp2["SGD: Loss = 2.5<br/>After 100 epochs"]
|
||||
Comp1 --> Comp3["AdamW: Loss = 1.8<br/>After 100 epochs"]
|
||||
Comp3 --> Comp4[AdamW is Better!]
|
||||
end
|
||||
|
||||
SGD5 -.-> Comp2
|
||||
AD7 -.-> Comp3
|
||||
|
||||
style SGD5 fill:#ffcccc
|
||||
style AD7 fill:#ccffcc
|
||||
style Comp4 fill:#e1ffe1
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 7.5 Why Optimization Matters
|
||||
|
||||
### Reason 1: Without Optimization
|
||||
|
||||
**Random weights:**
|
||||
```
|
||||
Weights: Random values
|
||||
Loss: Very high
|
||||
Predictions: Random
|
||||
Model: Useless
|
||||
```
|
||||
|
||||
### Reason 2: With Optimization
|
||||
|
||||
**Learned weights:**
|
||||
```
|
||||
Weights: Optimized values
|
||||
Loss: Low
|
||||
Predictions: Accurate
|
||||
Model: Useful
|
||||
```
|
||||
|
||||
### Reason 3: Determines Learning Speed
|
||||
|
||||
**Good optimizer:**
|
||||
- Fast convergence
|
||||
- Stable training
|
||||
- Good final performance
|
||||
|
||||
**Poor optimizer:**
|
||||
- Slow convergence
|
||||
- Unstable training
|
||||
- Poor final performance
|
||||
|
||||
### Reason 4: Affects Final Performance
|
||||
|
||||
**Same model, different optimizers:**
|
||||
|
||||
```
|
||||
SGD: Loss = 2.5 (after 100 epochs)
|
||||
AdamW: Loss = 1.8 (after 100 epochs)
|
||||
```
|
||||
|
||||
**Better optimizer = Better model!**
|
||||
|
||||
### Optimization Impact Visualization
|
||||
|
||||
```mermaid
|
||||
graph LR
|
||||
subgraph "Without Optimization"
|
||||
WO1[Random Weights] --> WO2["High Loss<br/>L ≈ 8-10"]
|
||||
WO2 --> WO3[Random Predictions]
|
||||
WO3 --> WO4[Model Useless]
|
||||
end
|
||||
|
||||
subgraph "With Optimization"
|
||||
W1[Random Weights] --> W2[Optimization Loop]
|
||||
W2 --> W3[Update Weights]
|
||||
W3 --> W4["Low Loss<br/>L ≈ 1-2"]
|
||||
W4 --> W5[Accurate Predictions]
|
||||
W5 --> W6[Model Useful]
|
||||
end
|
||||
|
||||
subgraph "Optimizer Quality"
|
||||
O1["Poor Optimizer<br/>SGD, Bad LR"] --> O2["Slow Convergence<br/>Loss = 2.5"]
|
||||
O3["Good Optimizer<br/>AdamW, Proper LR"] --> O4["Fast Convergence<br/>Loss = 1.8"]
|
||||
end
|
||||
|
||||
W2 -.-> O1
|
||||
W2 -.-> O3
|
||||
|
||||
style WO4 fill:#ffcccc
|
||||
style W6 fill:#ccffcc
|
||||
style O2 fill:#ffcccc
|
||||
style O4 fill:#ccffcc
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 7.6 Complete Mathematical Formulation
|
||||
|
||||
### Optimization Problem
|
||||
|
||||
```math
|
||||
\theta^* = \arg\min_{\theta} L(\theta)
|
||||
```
|
||||
|
||||
**Where $\theta^*$ is the optimal set of weights**
|
||||
|
||||
### Gradient Descent Update
|
||||
|
||||
```math
|
||||
\theta_{t+1} = \theta_t - \eta \nabla_\theta L(\theta_t)
|
||||
```
|
||||
|
||||
### AdamW Update (Complete)
|
||||
|
||||
**For each parameter $\theta_i$:**
|
||||
|
||||
**Gradient:**
|
||||
```math
|
||||
g_{t,i} = \frac{\partial L}{\partial \theta_{t,i}}
|
||||
```
|
||||
|
||||
**Momentum:**
|
||||
```math
|
||||
m_{t,i} = \beta_1 m_{t-1,i} + (1 - \beta_1) g_{t,i}
|
||||
```
|
||||
|
||||
**Variance:**
|
||||
```math
|
||||
v_{t,i} = \beta_2 v_{t-1,i} + (1 - \beta_2) g_{t,i}^2
|
||||
```
|
||||
|
||||
**Bias Correction:**
|
||||
```math
|
||||
\hat{m}_{t,i} = \frac{m_{t,i}}{1 - \beta_1^t}
|
||||
```
|
||||
|
||||
```math
|
||||
\hat{v}_{t,i} = \frac{v_{t,i}}{1 - \beta_2^t}
|
||||
```
|
||||
|
||||
**Update:**
|
||||
```math
|
||||
\theta_{t+1,i} = \theta_{t,i} - \eta \left( \frac{\hat{m}_{t,i}}{\sqrt{\hat{v}_{t,i}} + \epsilon} + \lambda \theta_{t,i} \right)
|
||||
```
|
||||
|
||||
**Where:**
|
||||
- $\beta_1 = 0.9$
|
||||
- $\beta_2 = 0.999$
|
||||
- $\epsilon = 10^{-8}$
|
||||
- $\lambda$ = weight decay (e.g., 0.01)
|
||||
|
||||
### Complete Mathematical Flow
|
||||
|
||||
```mermaid
|
||||
graph TB
|
||||
subgraph "Optimization Problem"
|
||||
OP1["Loss Function L(θ)"] --> OP2["Find: θ* = argmin L(θ)"]
|
||||
end
|
||||
|
||||
subgraph "Gradient Computation"
|
||||
GC1[Forward Pass] --> GC2[Compute Loss L]
|
||||
GC2 --> GC3[Backpropagation]
|
||||
GC3 --> GC4["Gradient gᵢ = ∂L/∂θᵢ"]
|
||||
end
|
||||
|
||||
subgraph "AdamW Update Steps"
|
||||
GC4 --> AU1["Momentum: mᵢ = β₁m + (1-β₁)gᵢ"]
|
||||
AU1 --> AU2["Variance: vᵢ = β₂v + (1-β₂)gᵢ²"]
|
||||
AU2 --> AU3["Bias Correction:<br/>m̂ = m/(1-β₁ᵗ), v̂ = v/(1-β₂ᵗ)"]
|
||||
AU3 --> AU4["Adaptive LR: η/(√v̂ + ε)"]
|
||||
AU4 --> AU5["Update: θᵢ = θᵢ - LR×m̂ - λθᵢ"]
|
||||
end
|
||||
|
||||
subgraph "Convergence Check"
|
||||
AU5 --> CC1{Converged?}
|
||||
CC1 -->|No| GC1
|
||||
CC1 -->|Yes| CC2["Optimal Weights θ*"]
|
||||
end
|
||||
|
||||
OP2 -.-> GC1
|
||||
CC2 -.-> OP2
|
||||
|
||||
style OP2 fill:#ffffcc
|
||||
style GC4 fill:#fff4e1
|
||||
style AU5 fill:#ccffcc
|
||||
style CC2 fill:#e1ffe1
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 7.7 Exercise: Optimizer Step-by-Step
|
||||
|
||||
### Problem
|
||||
|
||||
**Given:**
|
||||
- Current weight: $\theta_0 = 2.0$
|
||||
- Loss function: $L(\theta) = (\theta - 1)^2$
|
||||
- Learning rate: $\eta = 0.1$
|
||||
- Use AdamW with $\beta_1 = 0.9$, $\beta_2 = 0.999$, $\lambda = 0.01$
|
||||
- Initial moments: $m_0 = 0$, $v_0 = 0$
|
||||
|
||||
**Calculate the weight update for step 1.**
|
||||
|
||||
### Step-by-Step Solution
|
||||
|
||||
#### Step 1: Compute Gradient
|
||||
|
||||
**Loss function:**
|
||||
```math
|
||||
L(\theta) = (\theta - 1)^2
|
||||
```
|
||||
|
||||
**Gradient:**
|
||||
```math
|
||||
g_1 = \frac{\partial L}{\partial \theta} = 2(\theta - 1)
|
||||
```
|
||||
|
||||
**At $\theta_0 = 2.0$:**
|
||||
```math
|
||||
g_1 = 2(2.0 - 1) = 2(1.0) = 2.0
|
||||
```
|
||||
|
||||
#### Step 2: Update Momentum
|
||||
|
||||
```math
|
||||
m_1 = \beta_1 m_0 + (1 - \beta_1) g_1
|
||||
```
|
||||
|
||||
```math
|
||||
m_1 = 0.9 \times 0 + (1 - 0.9) \times 2.0 = 0 + 0.1 \times 2.0 = 0.2
|
||||
```
|
||||
|
||||
#### Step 3: Update Variance
|
||||
|
||||
```math
|
||||
v_1 = \beta_2 v_0 + (1 - \beta_2) g_1^2
|
||||
```
|
||||
|
||||
```math
|
||||
v_1 = 0.999 \times 0 + (1 - 0.999) \times (2.0)^2 = 0 + 0.001 \times 4.0 = 0.004
|
||||
```
|
||||
|
||||
#### Step 4: Bias Correction
|
||||
|
||||
```math
|
||||
\hat{m}_1 = \frac{m_1}{1 - \beta_1^1} = \frac{0.2}{1 - 0.9} = \frac{0.2}{0.1} = 2.0
|
||||
```
|
||||
|
||||
```math
|
||||
\hat{v}_1 = \frac{v_1}{1 - \beta_2^1} = \frac{0.004}{1 - 0.999} = \frac{0.004}{0.001} = 4.0
|
||||
```
|
||||
|
||||
#### Step 5: Compute Update
|
||||
|
||||
```math
|
||||
\Delta \theta_1 = \eta \left( \frac{\hat{m}_1}{\sqrt{\hat{v}_1} + \epsilon} + \lambda \theta_0 \right)
|
||||
```
|
||||
|
||||
```math
|
||||
\Delta \theta_1 = 0.1 \left( \frac{2.0}{\sqrt{4.0} + 10^{-8}} + 0.01 \times 2.0 \right)
|
||||
```
|
||||
|
||||
```math
|
||||
\Delta \theta_1 = 0.1 \left( \frac{2.0}{2.0 + 0.00000001} + 0.02 \right)
|
||||
```
|
||||
|
||||
```math
|
||||
\Delta \theta_1 = 0.1 \left( \frac{2.0}{2.0} + 0.02 \right) = 0.1 (1.0 + 0.02) = 0.1 \times 1.02 = 0.102
|
||||
```
|
||||
|
||||
#### Step 6: Update Weight
|
||||
|
||||
```math
|
||||
\theta_1 = \theta_0 - \Delta \theta_1 = 2.0 - 0.102 = 1.898
|
||||
```
|
||||
|
||||
### Answer
|
||||
|
||||
**After one step:**
|
||||
- Old weight: $\theta_0 = 2.0$
|
||||
- New weight: $\theta_1 = 1.898$
|
||||
- Update: $\Delta \theta_1 = -0.102$
|
||||
|
||||
**The weight moved closer to the optimal value (1.0)!**
|
||||
|
||||
### Verification
|
||||
|
||||
**Check loss:**
|
||||
- Old loss: $L(2.0) = (2.0 - 1)^2 = 1.0$
|
||||
- New loss: $L(1.898) = (1.898 - 1)^2 = 0.806$
|
||||
|
||||
**Loss decreased! ✓**
|
||||
|
||||
### Exercise Solution Flowchart
|
||||
|
||||
```mermaid
|
||||
graph TB
|
||||
subgraph "Given Values"
|
||||
G1["θ₀ = 2.0"] --> Start
|
||||
G2["m₀ = 0"] --> Start
|
||||
G3["v₀ = 0"] --> Start
|
||||
G4["η = 0.1, β₁ = 0.9<br/>β₂ = 0.999, λ = 0.01"] --> Start
|
||||
end
|
||||
|
||||
Start[Start] --> Step1["Step 1: Compute Gradient<br/>L(θ) = (θ-1)²<br/>g₁ = 2(θ₀-1) = 2.0"]
|
||||
|
||||
Step1 --> Step2["Step 2: Update Momentum<br/>m₁ = 0.9×0 + 0.1×2.0<br/>m₁ = 0.2"]
|
||||
|
||||
Step2 --> Step3["Step 3: Update Variance<br/>v₁ = 0.999×0 + 0.001×4.0<br/>v₁ = 0.004"]
|
||||
|
||||
Step3 --> Step4["Step 4: Bias Correction<br/>m̂₁ = 0.2/(1-0.9) = 2.0<br/>v̂₁ = 0.004/(1-0.999) = 4.0"]
|
||||
|
||||
Step4 --> Step5["Step 5: Compute Update<br/>Δθ₁ = 0.1×(2.0/√4.0 + 0.01×2.0)<br/>Δθ₁ = 0.102"]
|
||||
|
||||
Step5 --> Step6["Step 6: Update Weight<br/>θ₁ = 2.0 - 0.102<br/>θ₁ = 1.898"]
|
||||
|
||||
Step6 --> Verify["Verification:<br/>L(2.0) = 1.0 → L(1.898) = 0.806<br/>Loss Decreased!"]
|
||||
|
||||
Verify --> End["Result: θ₁ = 1.898<br/>Closer to optimum θ* = 1.0"]
|
||||
|
||||
style Start fill:#e1f5ff
|
||||
style Step1 fill:#fff4e1
|
||||
style Step2 fill:#fff4e1
|
||||
style Step3 fill:#fff4e1
|
||||
style Step4 fill:#fff4e1
|
||||
style Step5 fill:#fff4e1
|
||||
style Step6 fill:#ccffcc
|
||||
style Verify fill:#e1ffe1
|
||||
style End fill:#e1ffe1
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 7.8 Key Takeaways
|
||||
|
||||
### Optimization
|
||||
|
||||
✅ **Optimization finds best weights to minimize loss**
|
||||
✅ **Uses gradients to determine update direction**
|
||||
✅ **Iterative process: compute → update → repeat**
|
||||
|
||||
### Gradient Descent
|
||||
|
||||
✅ **Basic algorithm: move opposite to gradient**
|
||||
✅ **Learning rate controls step size**
|
||||
✅ **Can be slow for complex problems**
|
||||
|
||||
### AdamW
|
||||
|
||||
✅ **Advanced optimizer with adaptive learning rates**
|
||||
✅ **Each parameter gets its own learning rate**
|
||||
✅ **Combines momentum and variance estimates**
|
||||
✅ **Includes weight decay for regularization**
|
||||
✅ **Works great for transformers**
|
||||
|
||||
### Why Important
|
||||
|
||||
✅ **Determines how fast model learns**
|
||||
✅ **Affects final model performance**
|
||||
✅ **Essential for training neural networks**
|
||||
|
||||
---
|
||||
|
||||
*This document provides a comprehensive explanation of optimization in neural networks, including gradient descent and AdamW optimizer with mathematical formulations and solved exercises.*
|
||||
|
||||
Reference in New Issue
Block a user