Initial commit: SheepOp LLM - Transformer-based language model implementation
- Complete transformer implementation from scratch - Training pipeline with gradient accumulation and mixed precision - Optimized inference with KV caching - Multi-format data processing (PDFs, images, code, text) - Comprehensive documentation - Apache 2.0 license - Example training plots included in docs/images/
This commit is contained in:
729
docs/ARCHITECTURE.md
Normal file
729
docs/ARCHITECTURE.md
Normal file
@@ -0,0 +1,729 @@
|
||||
# SheepOp LLM - Complete Architecture Documentation
|
||||
|
||||
Complete documentation of the SheepOp Language Model project architecture, data flow, training pipeline, and inference system.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
1. [System Overview](#system-overview)
|
||||
2. [Data Ingestion Pipeline](#data-ingestion-pipeline)
|
||||
3. [Training Pipeline](#training-pipeline)
|
||||
4. [Model Architecture](#model-architecture)
|
||||
5. [Inference Pipeline](#inference-pipeline)
|
||||
6. [Complete Workflow](#complete-workflow)
|
||||
|
||||
---
|
||||
|
||||
## System Overview
|
||||
|
||||
```mermaid
|
||||
graph TB
|
||||
subgraph "Data Sources"
|
||||
A[PDF Files] --> DataProcessor
|
||||
B[Images - PNG/JPG/etc] --> DataProcessor
|
||||
C[Code Files - .py/.js/etc] --> DataProcessor
|
||||
D[Text Files - .txt/.md/etc] --> DataProcessor
|
||||
end
|
||||
|
||||
DataProcessor[DataProcessor<br/>Multi-Format Extractor] --> TextList[Text Lines]
|
||||
|
||||
TextList --> Tokenizer[SimpleTokenizer<br/>Character-Level]
|
||||
Tokenizer --> DataLoader[PyTorch DataLoader<br/>Batched Sequences]
|
||||
|
||||
DataLoader --> Trainer[Trainer<br/>Training Loop]
|
||||
|
||||
subgraph "Training Components"
|
||||
Trainer --> Model[TransformerModel]
|
||||
Trainer --> Optimizer[AdamW Optimizer]
|
||||
Trainer --> Scheduler[CosineAnnealingLR]
|
||||
Trainer --> Loss[CrossEntropyLoss]
|
||||
end
|
||||
|
||||
Model --> Checkpoint[Model Checkpoints<br/>checkpoints/*.pt]
|
||||
|
||||
Checkpoint --> Inference[Inference Script]
|
||||
Inference --> GeneratedText[Generated Text]
|
||||
|
||||
style DataProcessor fill:#e1f5ff
|
||||
style Model fill:#fff4e1
|
||||
style Trainer fill:#ffe1f5
|
||||
style Checkpoint fill:#e1ffe1
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Data Ingestion Pipeline
|
||||
|
||||
### Multi-Format Data Processing Flow
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
Start([Start:<br/>train.py --data path]) --> CheckPath{Path Type?}
|
||||
|
||||
CheckPath -->|File| SingleFile[Process Single File]
|
||||
CheckPath -->|Directory| Directory[Process Directory]
|
||||
|
||||
SingleFile --> DataProcessor[DataProcessor.process_file]
|
||||
Directory --> RecursiveScan[Recursive Directory Scan<br/>Find all files]
|
||||
|
||||
RecursiveScan --> FileType{File Extension?}
|
||||
|
||||
FileType -->|.txt/.md/.json/etc| TextExtract[Read as Text File<br/>Line by line]
|
||||
FileType -->|.py/.js/.java/etc| CodeExtract[Read as Code File<br/>Line by line]
|
||||
FileType -->|.pdf| PDFExtract[PDF Extraction<br/>PyPDF2/pdfplumber]
|
||||
FileType -->|.png/.jpg/.tiff/etc| ImageExtract[OCR Extraction<br/>pytesseract]
|
||||
FileType -->|Unknown| Fallback[Try Text Fallback]
|
||||
|
||||
PDFExtract --> PDFPages[Extract Each Page]
|
||||
PDFPages --> PDFLines[Split into Lines]
|
||||
|
||||
ImageExtract --> OCR[Perform OCR<br/>pytesseract]
|
||||
OCR --> OCRLines[Split OCR Text into Lines]
|
||||
|
||||
TextExtract --> FilterLines[Filter Lines<br/>min_length=10]
|
||||
CodeExtract --> FilterLines
|
||||
PDFLines --> FilterLines
|
||||
OCRLines --> FilterLines
|
||||
Fallback --> FilterLines
|
||||
|
||||
FilterLines --> Combine[List of Text Lines]
|
||||
Combine --> Validate{Texts Empty?}
|
||||
|
||||
Validate -->|Yes| Error[Raise Error:<br/>No data extracted]
|
||||
Validate -->|No| Success[✅ Success<br/>N text samples loaded]
|
||||
|
||||
Success --> TokenizerStep[Next: Tokenization]
|
||||
|
||||
style DataProcessor fill:#e1f5ff
|
||||
style PDFExtract fill:#ffe1f5
|
||||
style ImageExtract fill:#fff4e1
|
||||
style Success fill:#e1ffe1
|
||||
style Error fill:#ffe1e1
|
||||
```
|
||||
|
||||
### Data Processing Components
|
||||
|
||||
```mermaid
|
||||
classDiagram
|
||||
class DataProcessor {
|
||||
+process_file(file_path) Iterator[str]
|
||||
+process_directory(directory) Iterator[str]
|
||||
+process_to_list(...) List[str]
|
||||
-_process_text_file() Iterator[str]
|
||||
-_process_code_file() Iterator[str]
|
||||
-_process_pdf() Iterator[str]
|
||||
-_process_image() Iterator[str]
|
||||
-_check_dependencies()
|
||||
}
|
||||
|
||||
class SimpleTokenizer {
|
||||
+vocab: Dict[str, int]
|
||||
+inv_vocab: Dict[int, str]
|
||||
+vocab_size: int
|
||||
+encode(text: str) List[int]
|
||||
+decode(token_ids: List[int]) str
|
||||
+save_vocab(path: str)
|
||||
}
|
||||
|
||||
class TextDataset {
|
||||
+texts: List[str]
|
||||
+tokenizer: SimpleTokenizer
|
||||
+max_length: int
|
||||
+sequences: List[torch.Tensor]
|
||||
+__getitem__(idx) Dict
|
||||
+_prepare_sequences() List[Tensor]
|
||||
}
|
||||
|
||||
class DataLoader {
|
||||
+batch_size: int
|
||||
+shuffle: bool
|
||||
+num_workers: int
|
||||
+collate_fn: Callable
|
||||
}
|
||||
|
||||
DataProcessor --> TextDataset : extracts text
|
||||
SimpleTokenizer --> TextDataset : tokenizes
|
||||
TextDataset --> DataLoader : creates dataset
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Training Pipeline
|
||||
|
||||
### Complete Training Flow
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
Start([python train.py<br/>--data path]) --> Args[Parse Arguments<br/>--data, --config, --resume, --device]
|
||||
|
||||
Args --> ConfigLoad{Config File<br/>Provided?}
|
||||
ConfigLoad -->|Yes| LoadConfig[Load config.json]
|
||||
ConfigLoad -->|No| DefaultConfig[Use Default Config]
|
||||
|
||||
LoadConfig --> Config[Config Object<br/>ModelConfig<br/>TrainingConfig<br/>DataConfig<br/>seed=42]
|
||||
DefaultConfig --> Config
|
||||
|
||||
Config --> SetSeed[Set Random Seed<br/>torch.manual_seed<br/>torch.cuda.manual_seed_all<br/>CUDNN deterministic]
|
||||
|
||||
SetSeed --> Device[Detect Device<br/>CUDA/MPS/CPU]
|
||||
|
||||
Device --> DataIngestion[Data Ingestion Pipeline<br/>Extract text from all files]
|
||||
|
||||
DataIngestion --> TextList[List of Text Lines<br/>N samples]
|
||||
|
||||
TextList --> CreateTokenizer[Create SimpleTokenizer<br/>Character-level vocab]
|
||||
|
||||
CreateTokenizer --> Tokenizer[Tokenizer Ready<br/>vocab_size calculated]
|
||||
|
||||
Tokenizer --> CreateDataLoader[Create DataLoader<br/>Batch size<br/>Max length<br/>Shuffle]
|
||||
|
||||
CreateDataLoader --> TrainLoader[PyTorch DataLoader<br/>Batched sequences]
|
||||
|
||||
TrainLoader --> CheckResume{Resume<br/>Checkpoint?}
|
||||
|
||||
CheckResume -->|Yes| LoadCheckpoint[Load Checkpoint<br/>Model state<br/>Optimizer state<br/>Scheduler state<br/>Epoch/Step]
|
||||
CheckResume -->|No| CreateModel[Create New Model<br/>TransformerModel]
|
||||
|
||||
LoadCheckpoint --> CreateModel
|
||||
|
||||
CreateModel --> Model[Model Ready<br/>N parameters]
|
||||
|
||||
Model --> CreateOptimizer[Create Optimizer<br/>AdamW<br/>lr, weight_decay]
|
||||
|
||||
CreateOptimizer --> CreateScheduler[Create Scheduler<br/>CosineAnnealingLR<br/>T_max=total_steps]
|
||||
|
||||
CreateScheduler --> CreateTrainer[Create Trainer<br/>Model<br/>DataLoader<br/>Optimizer<br/>Scheduler<br/>Device]
|
||||
|
||||
CreateTrainer --> Trainer[Trainer Ready]
|
||||
|
||||
Trainer --> TrainingLoop[Training Loop<br/>For each epoch]
|
||||
|
||||
TrainingLoop --> EpochLoop[For each batch]
|
||||
|
||||
EpochLoop --> Forward[Forward Pass<br/>Model prediction]
|
||||
|
||||
Forward --> Loss[Compute Loss<br/>CrossEntropyLoss]
|
||||
|
||||
Loss --> Backward[Backward Pass<br/>Compute gradients]
|
||||
|
||||
Backward --> GradientAccum{Gradient<br/>Accumulation?}
|
||||
|
||||
GradientAccum -->|Not yet| Accumulate[Accumulate gradients]
|
||||
Accumulate --> EpochLoop
|
||||
|
||||
GradientAccum -->|Ready| ClipGrad[Gradient Clipping<br/>max_grad_norm]
|
||||
|
||||
ClipGrad --> Update[Update Weights<br/>Optimizer.step]
|
||||
|
||||
Update --> UpdateLR[Update Learning Rate<br/>Scheduler.step]
|
||||
|
||||
UpdateLR --> ZeroGrad[Zero Gradients]
|
||||
|
||||
ZeroGrad --> Log{Log Interval?}
|
||||
|
||||
Log -->|Yes| LogMetrics[Log Metrics<br/>Loss, LR<br/>Save to metrics.json]
|
||||
Log -->|No| EvalCheck{Evaluation<br/>Interval?}
|
||||
|
||||
LogMetrics --> EvalCheck
|
||||
|
||||
EvalCheck -->|Yes| Evaluate[Evaluate on<br/>Validation Set]
|
||||
EvalCheck -->|No| SaveCheck{End of<br/>Epoch?}
|
||||
|
||||
Evaluate --> SaveCheck
|
||||
|
||||
SaveCheck -->|No| EpochLoop
|
||||
SaveCheck -->|Yes| SaveCheckpoint[Save Checkpoint<br/>Model state<br/>Optimizer state<br/>Scheduler state<br/>Epoch/Step]
|
||||
|
||||
SaveCheckpoint --> MoreEpochs{More<br/>Epochs?}
|
||||
|
||||
MoreEpochs -->|Yes| TrainingLoop
|
||||
MoreEpochs -->|No| GeneratePlots[Generate Training Plots<br/>loss_by_epoch.png<br/>training_curve.png]
|
||||
|
||||
GeneratePlots --> End([Training Complete!<br/>Checkpoints saved])
|
||||
|
||||
style SetSeed fill:#ffe1f5
|
||||
style DataIngestion fill:#e1f5ff
|
||||
style Model fill:#fff4e1
|
||||
style TrainingLoop fill:#ffe1f5
|
||||
style End fill:#e1ffe1
|
||||
```
|
||||
|
||||
### Seed Initialization Details
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant TrainScript as train.py
|
||||
participant Config as Config
|
||||
participant PyTorch as PyTorch
|
||||
participant CUDA as CUDA Backend
|
||||
|
||||
TrainScript->>Config: Load config (seed=42)
|
||||
TrainScript->>PyTorch: torch.manual_seed(42)
|
||||
TrainScript->>CUDA: torch.cuda.manual_seed_all(42)
|
||||
TrainScript->>PyTorch: torch.backends.cudnn.deterministic = True
|
||||
TrainScript->>PyTorch: torch.backends.cudnn.benchmark = False
|
||||
|
||||
Note over TrainScript,CUDA: Seed ensures reproducibility<br/>across runs and devices
|
||||
```
|
||||
|
||||
### Training Loop Details
|
||||
|
||||
```mermaid
|
||||
graph LR
|
||||
subgraph "Single Training Step"
|
||||
A[Batch Input<br/>input_ids, labels] --> B[Forward Pass<br/>Model forward]
|
||||
B --> C[Logits<br/>batch_size × seq_len × vocab_size]
|
||||
C --> D[Compute Loss<br/>CrossEntropyLoss]
|
||||
D --> E[Backward Pass<br/>Compute gradients]
|
||||
E --> F{Gradient<br/>Accumulation<br/>Steps reached?}
|
||||
F -->|No| G[Accumulate Gradients]
|
||||
F -->|Yes| H[Gradient Clipping]
|
||||
H --> I[Optimizer Step<br/>Update weights]
|
||||
I --> J[Scheduler Step<br/>Update LR]
|
||||
J --> K[Zero Gradients]
|
||||
K --> L[Log Metrics]
|
||||
end
|
||||
|
||||
G --> A
|
||||
|
||||
style B fill:#e1f5ff
|
||||
style D fill:#ffe1f5
|
||||
style I fill:#fff4e1
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Model Architecture
|
||||
|
||||
### Transformer Model Structure
|
||||
|
||||
```mermaid
|
||||
graph TB
|
||||
Input[Input Tokens<br/>Token IDs] --> Embed[Token Embedding<br/>vocab_size → d_model]
|
||||
|
||||
Embed --> PosEnc[Positional Encoding<br/>Sinusoidal/Cosine]
|
||||
|
||||
PosEnc --> Dropout1[Dropout]
|
||||
|
||||
Dropout1 --> Layer1[Transformer Block 1]
|
||||
Layer1 --> Layer2[Transformer Block 2]
|
||||
Layer2 --> Layer3[Transformer Block 3]
|
||||
Layer3 --> LayerN[Transformer Block N<br/>num_layers]
|
||||
|
||||
LayerN --> LayerNorm[Final Layer Norm]
|
||||
|
||||
LayerNorm --> OutputProj[Output Projection<br/>d_model → vocab_size]
|
||||
|
||||
OutputProj --> Logits[Logits<br/>batch × seq_len × vocab_size]
|
||||
|
||||
subgraph "Transformer Block Details"
|
||||
TBInput[Input x] --> Attention[Multi-Head<br/>Self-Attention]
|
||||
Attention --> AddNorm1[Add & Norm<br/>Residual + LayerNorm]
|
||||
AddNorm1 --> FFN[Feed-Forward<br/>Network]
|
||||
FFN --> AddNorm2[Add & Norm<br/>Residual + LayerNorm]
|
||||
AddNorm2 --> TBOutput[Output]
|
||||
end
|
||||
|
||||
style Embed fill:#e1f5ff
|
||||
style Attention fill:#ffe1f5
|
||||
style FFN fill:#fff4e1
|
||||
style Logits fill:#e1ffe1
|
||||
```
|
||||
|
||||
### Multi-Head Attention Mechanism
|
||||
|
||||
```mermaid
|
||||
graph LR
|
||||
Input[Input<br/>batch × seq_len × d_model] --> Q[Query<br/>Linear Layer]
|
||||
Input --> K[Key<br/>Linear Layer]
|
||||
Input --> V[Value<br/>Linear Layer]
|
||||
|
||||
Q --> SplitQ[Split into<br/>num_heads heads]
|
||||
K --> SplitK[Split into<br/>num_heads heads]
|
||||
V --> SplitV[Split into<br/>num_heads heads]
|
||||
|
||||
SplitQ --> ScaledDot[Scaled Dot-Product<br/>Attention]
|
||||
SplitK --> ScaledDot
|
||||
SplitV --> ScaledDot
|
||||
|
||||
ScaledDot --> Mask[Causal Mask<br/>Lower triangular]
|
||||
|
||||
Mask --> Softmax[Softmax]
|
||||
|
||||
Softmax --> AttentionOutput[Attention Output<br/>per head]
|
||||
|
||||
AttentionOutput --> Concat[Concat Heads]
|
||||
|
||||
Concat --> OutputProj[Output Projection<br/>Linear Layer]
|
||||
|
||||
OutputProj --> Output[Output<br/>batch × seq_len × d_model]
|
||||
|
||||
style ScaledDot fill:#ffe1f5
|
||||
style Mask fill:#fff4e1
|
||||
style Output fill:#e1ffe1
|
||||
```
|
||||
|
||||
### Complete Model Component Diagram
|
||||
|
||||
```mermaid
|
||||
classDiagram
|
||||
class TransformerModel {
|
||||
+vocab_size: int
|
||||
+d_model: int
|
||||
+num_layers: int
|
||||
+num_heads: int
|
||||
+token_embedding: Embedding
|
||||
+pos_encoding: PositionalEncoding
|
||||
+layers: ModuleList[TransformerBlock]
|
||||
+final_norm: LayerNorm
|
||||
+output_proj: Linear
|
||||
+forward(input_ids) Tuple[Tensor, Tensor]
|
||||
+generate(...) Tensor
|
||||
+get_num_params() int
|
||||
}
|
||||
|
||||
class TransformerBlock {
|
||||
+attention: MultiHeadAttention
|
||||
+ffn: FeedForward
|
||||
+norm1: LayerNorm
|
||||
+norm2: LayerNorm
|
||||
+dropout: Dropout
|
||||
+forward(x, mask) Tensor
|
||||
}
|
||||
|
||||
class MultiHeadAttention {
|
||||
+num_heads: int
|
||||
+d_model: int
|
||||
+d_k: int
|
||||
+q_proj: Linear
|
||||
+k_proj: Linear
|
||||
+v_proj: Linear
|
||||
+out_proj: Linear
|
||||
+forward(q, k, v, mask) Tensor
|
||||
}
|
||||
|
||||
class FeedForward {
|
||||
+linear1: Linear
|
||||
+linear2: Linear
|
||||
+activation: GELU/ReLU
|
||||
+dropout: Dropout
|
||||
+forward(x) Tensor
|
||||
}
|
||||
|
||||
class PositionalEncoding {
|
||||
+d_model: int
|
||||
+max_len: int
|
||||
+pe: Tensor
|
||||
+forward(x) Tensor
|
||||
}
|
||||
|
||||
TransformerModel --> TransformerBlock : contains N layers
|
||||
TransformerModel --> PositionalEncoding : adds positional info
|
||||
TransformerBlock --> MultiHeadAttention : self-attention
|
||||
TransformerBlock --> FeedForward : feed-forward network
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Inference Pipeline
|
||||
|
||||
### Text Generation Flow
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
Start([python inference.py<br/>--checkpoint path<br/>--prompt text]) --> LoadModel[Load Model from Checkpoint<br/>Load state dict<br/>Set to eval mode]
|
||||
|
||||
LoadModel --> CreateTokenizer[Create Tokenizer<br/>SimpleTokenizer]
|
||||
|
||||
CreateTokenizer --> EncodePrompt[Encode Prompt<br/>Text → Token IDs]
|
||||
|
||||
EncodePrompt --> CheckOptimized{Use Optimized<br/>Inference?}
|
||||
|
||||
CheckOptimized -->|Yes| OptimizedGen[OptimizedInference<br/>with KV Caching]
|
||||
CheckOptimized -->|No| StandardGen[Standard Generation]
|
||||
|
||||
StandardGen --> InitGen[Initialize Generation<br/>generated = input_ids]
|
||||
|
||||
InitGen --> LoopStart[Generation Loop<br/>For max_length steps]
|
||||
|
||||
LoopStart --> Forward[Forward Pass<br/>Model prediction]
|
||||
|
||||
Forward --> NextToken[Get Next Token Logits<br/>Last position]
|
||||
|
||||
NextToken --> Temperature[Apply Temperature<br/>Scale logits]
|
||||
|
||||
Temperature --> TopK{Top-K<br/>Filtering?}
|
||||
|
||||
TopK -->|Yes| FilterK[Filter Top-K Tokens]
|
||||
TopK -->|No| TopP{Top-P<br/>Nucleus Sampling?}
|
||||
|
||||
FilterK --> TopP
|
||||
|
||||
TopP -->|Yes| FilterP[Filter by Cumulative Prob]
|
||||
TopP -->|No| Sample[Sample Token<br/>Multinomial]
|
||||
|
||||
FilterP --> Sample
|
||||
|
||||
Sample --> Append[Append Token<br/>to Generated]
|
||||
|
||||
Append --> CheckStop{Stop<br/>Condition?}
|
||||
|
||||
CheckStop -->|No| LoopStart
|
||||
CheckStop -->|Yes| Decode[Decode Tokens<br/>Token IDs → Text]
|
||||
|
||||
OptimizedGen --> KVCache[Use KV Cache<br/>Cache previous KV]
|
||||
KVCache --> LoopStart
|
||||
|
||||
Decode --> Output[Generated Text<br/>Output]
|
||||
|
||||
Output --> End([End])
|
||||
|
||||
style OptimizedGen fill:#e1f5ff
|
||||
style Forward fill:#ffe1f5
|
||||
style Sample fill:#fff4e1
|
||||
style Output fill:#e1ffe1
|
||||
```
|
||||
|
||||
### Optimized Inference with KV Caching
|
||||
|
||||
```mermaid
|
||||
graph TB
|
||||
subgraph "Standard Generation"
|
||||
A1[Input Token] --> B1[Forward Pass<br/>Compute Q, K, V]
|
||||
B1 --> C1[Attention<br/>Full Sequence]
|
||||
C1 --> D1[Next Token]
|
||||
D1 --> E1[Append Token]
|
||||
E1 --> A1
|
||||
end
|
||||
|
||||
subgraph "Optimized Generation with KV Cache"
|
||||
A2[Input Token] --> B2{First<br/>Token?}
|
||||
B2 -->|Yes| C2[Forward Pass<br/>Compute Q, K, V]
|
||||
B2 -->|No| C2Cache[Use Cached K, V<br/>Only compute Q]
|
||||
C2 --> D2[Cache K, V]
|
||||
D2 --> E2[Attention<br/>Only with New Token]
|
||||
C2Cache --> E2
|
||||
E2 --> F2[Next Token]
|
||||
F2 --> G2[Append Token]
|
||||
G2 --> A2
|
||||
end
|
||||
|
||||
style C2 fill:#e1f5ff
|
||||
style C2Cache fill:#ffe1f5
|
||||
style E2 fill:#e1ffe1
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Complete Workflow
|
||||
|
||||
### End-to-End System Flow
|
||||
|
||||
```mermaid
|
||||
flowchart TB
|
||||
subgraph "Phase 1: Data Preparation"
|
||||
A1[Raw Data Files<br/>PDFs, Images, Code, Text] --> A2[DataProcessor<br/>Extract Text]
|
||||
A2 --> A3[Text Lines<br/>List of Strings]
|
||||
A3 --> A4[SimpleTokenizer<br/>Build Vocabulary]
|
||||
A4 --> A5[Tokenize & Chunk<br/>Create Sequences]
|
||||
A5 --> A6[DataLoader<br/>Batched Data]
|
||||
end
|
||||
|
||||
subgraph "Phase 2: Model Initialization"
|
||||
B1[Load Config<br/>ModelConfig] --> B2[Set Random Seed<br/>seed=42]
|
||||
B2 --> B3[Create Model<br/>TransformerModel]
|
||||
B3 --> B4[Initialize Weights<br/>Normal Distribution]
|
||||
B4 --> B5[Create Optimizer<br/>AdamW]
|
||||
B5 --> B6[Create Scheduler<br/>CosineAnnealingLR]
|
||||
end
|
||||
|
||||
subgraph "Phase 3: Training"
|
||||
C1[Trainer Setup] --> C2[Training Loop<br/>Epochs]
|
||||
C2 --> C3[Batch Loop]
|
||||
C3 --> C4[Forward Pass]
|
||||
C4 --> C5[Compute Loss]
|
||||
C5 --> C6[Backward Pass]
|
||||
C6 --> C7[Gradient Clipping]
|
||||
C7 --> C8[Update Weights]
|
||||
C8 --> C9[Save Checkpoint]
|
||||
C9 --> C10{More Epochs?}
|
||||
C10 -->|Yes| C2
|
||||
C10 -->|No| C11[Generate Plots<br/>Training Metrics]
|
||||
end
|
||||
|
||||
subgraph "Phase 4: Inference"
|
||||
D1[Load Checkpoint] --> D2[Load Model State]
|
||||
D2 --> D3[Encode Prompt]
|
||||
D3 --> D4[Generate Text<br/>Autoregressive]
|
||||
D4 --> D5[Decode Tokens]
|
||||
D5 --> D6[Output Text]
|
||||
end
|
||||
|
||||
A6 --> B1
|
||||
B6 --> C1
|
||||
C11 --> D1
|
||||
|
||||
style A2 fill:#e1f5ff
|
||||
style B3 fill:#fff4e1
|
||||
style C4 fill:#ffe1f5
|
||||
style D4 fill:#e1ffe1
|
||||
```
|
||||
|
||||
### Checkpoint Structure
|
||||
|
||||
```mermaid
|
||||
graph TB
|
||||
Checkpoint[Checkpoint File<br/>checkpoint_epoch_N.pt] --> ModelState[model_state_dict<br/>Model weights]
|
||||
Checkpoint --> OptimizerState[optimizer_state_dict<br/>AdamW state]
|
||||
Checkpoint --> SchedulerState[scheduler_state_dict<br/>LR scheduler state]
|
||||
Checkpoint --> ModelConfig[model_config<br/>Model hyperparameters]
|
||||
Checkpoint --> Epoch[epoch<br/>Current epoch number]
|
||||
Checkpoint --> GlobalStep[global_step<br/>Training step count]
|
||||
Checkpoint --> BestValLoss[best_val_loss<br/>Best validation loss]
|
||||
|
||||
ModelState --> Resume[Resume Training<br/>Restore model state]
|
||||
OptimizerState --> Resume
|
||||
SchedulerState --> Resume
|
||||
ModelConfig --> Resume
|
||||
Epoch --> Resume
|
||||
GlobalStep --> Resume
|
||||
|
||||
style Checkpoint fill:#e1f5ff
|
||||
style Resume fill:#e1ffe1
|
||||
```
|
||||
|
||||
### Configuration Hierarchy
|
||||
|
||||
```mermaid
|
||||
graph TB
|
||||
Config[Config<br/>Root Configuration] --> ModelConfig[ModelConfig<br/>vocab_size<br/>d_model<br/>num_layers<br/>num_heads<br/>d_ff<br/>max_seq_len<br/>dropout<br/>activation]
|
||||
|
||||
Config --> TrainingConfig[TrainingConfig<br/>batch_size<br/>max_epochs<br/>learning_rate<br/>weight_decay<br/>warmup_steps<br/>max_grad_norm<br/>gradient_accumulation_steps<br/>use_amp]
|
||||
|
||||
Config --> DataConfig[DataConfig<br/>data_dir<br/>max_length<br/>stride<br/>num_workers]
|
||||
|
||||
Config --> Global[Global Settings<br/>device<br/>seed]
|
||||
|
||||
ModelConfig --> Model[TransformerModel<br/>Model Architecture]
|
||||
TrainingConfig --> Trainer[Trainer<br/>Training Parameters]
|
||||
DataConfig --> DataLoader[DataLoader<br/>Data Parameters]
|
||||
|
||||
style Config fill:#e1f5ff
|
||||
style Model fill:#fff4e1
|
||||
style Trainer fill:#ffe1f5
|
||||
style DataLoader fill:#e1ffe1
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Key Components Summary
|
||||
|
||||
### 1. **Data Processing**
|
||||
- **DataProcessor**: Multi-format text extraction (PDFs, images, code, text)
|
||||
- **SimpleTokenizer**: Character-level tokenization
|
||||
- **TextDataset**: PyTorch dataset for training
|
||||
- **DataLoader**: Batched data loading
|
||||
|
||||
### 2. **Model Architecture**
|
||||
- **TransformerModel**: Complete transformer language model
|
||||
- **TransformerBlock**: Multi-head attention + feed-forward
|
||||
- **MultiHeadAttention**: Scaled dot-product attention with causal masking
|
||||
- **FeedForward**: Position-wise feed-forward network
|
||||
- **PositionalEncoding**: Sinusoidal position embeddings
|
||||
|
||||
### 3. **Training**
|
||||
- **Trainer**: Complete training loop with:
|
||||
- Gradient accumulation
|
||||
- Mixed precision training (AMP)
|
||||
- Gradient clipping
|
||||
- Learning rate scheduling
|
||||
- Checkpointing
|
||||
- Metrics tracking
|
||||
|
||||
### 4. **Inference**
|
||||
- **Standard Generation**: Autoregressive text generation
|
||||
- **OptimizedInference**: KV caching for faster generation
|
||||
- **RetrievalCache**: Caching for RAG systems
|
||||
|
||||
### 5. **Configuration**
|
||||
- **Config System**: Hierarchical configuration (Model, Training, Data)
|
||||
- **JSON Support**: Save/load configurations
|
||||
- **Default Values**: Sensible defaults for all parameters
|
||||
|
||||
---
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Training
|
||||
```bash
|
||||
# Basic training
|
||||
python train.py --data /path/to/data
|
||||
|
||||
# With custom config
|
||||
python train.py --data /path/to/data --config config.json
|
||||
|
||||
# Resume from checkpoint
|
||||
python train.py --data /path/to/data --resume checkpoints/checkpoint_epoch_5.pt
|
||||
|
||||
# Specify device
|
||||
python train.py --data /path/to/data --device cuda
|
||||
```
|
||||
|
||||
### Inference
|
||||
```bash
|
||||
# Basic inference
|
||||
python inference.py --checkpoint checkpoints/best_checkpoint.pt --prompt "Hello world"
|
||||
|
||||
# With sampling parameters
|
||||
python inference.py \
|
||||
--checkpoint checkpoints/best_checkpoint.pt \
|
||||
--prompt "The future of AI" \
|
||||
--max-length 200 \
|
||||
--temperature 0.8 \
|
||||
--top-k 50 \
|
||||
--top-p 0.95
|
||||
|
||||
# Optimized inference
|
||||
python inference.py \
|
||||
--checkpoint checkpoints/best_checkpoint.pt \
|
||||
--prompt "Hello" \
|
||||
--optimized
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## File Structure
|
||||
|
||||
```
|
||||
sheepOp/
|
||||
├── train.py # Main training script
|
||||
├── inference.py # Inference script
|
||||
├── config.py # Configuration management
|
||||
├── config.json # Configuration file
|
||||
├── data/ # Data module (symlink)
|
||||
│ └── __init__.py # Tokenizer, DataLoader, DataProcessor
|
||||
├── models/ # Model definitions
|
||||
│ ├── transformer.py # Main transformer model
|
||||
│ ├── blocks.py # Transformer blocks
|
||||
│ ├── attention.py # Attention mechanisms
|
||||
│ └── optimized_attention.py # Optimized inference
|
||||
├── training/ # Training utilities
|
||||
│ ├── __init__.py # Trainer class
|
||||
│ └── metrics.py # Training metrics
|
||||
├── checkpoints/ # Saved model checkpoints
|
||||
└── requirements.txt # Dependencies
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Flow Summary
|
||||
|
||||
1. **Data Ingestion**: Raw files → Text extraction → Text lines
|
||||
2. **Tokenization**: Text lines → Token sequences → Batched data
|
||||
3. **Model Setup**: Config → Model → Optimizer → Scheduler
|
||||
4. **Training**: Batches → Forward → Loss → Backward → Update → Checkpoint
|
||||
5. **Inference**: Checkpoint → Model → Prompt → Generate → Output
|
||||
|
||||
---
|
||||
|
||||
*This documentation provides a complete view of the SheepOp LLM project architecture and workflow.*
|
||||
|
||||
410
docs/ATTENTION_EXPLAINED.md
Normal file
410
docs/ATTENTION_EXPLAINED.md
Normal file
@@ -0,0 +1,410 @@
|
||||
# What is Attention? Step-by-Step Explanation
|
||||
|
||||
Complete step-by-step explanation of attention mechanisms in transformer models: how models understand relationships between words.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
1. [The Problem Attention Solves](#21-the-problem-attention-solves)
|
||||
2. [What is Attention?](#22-what-is-attention)
|
||||
3. [How Attention Works: Step-by-Step](#23-how-attention-works-step-by-step)
|
||||
4. [Complete Example: Attention in "Hello World"](#24-complete-example-attention-in-hello-world)
|
||||
5. [Why Attention Matters](#25-why-attention-matters)
|
||||
6. [Multi-Head Attention](#26-multi-head-attention)
|
||||
7. [Visual Representation of Attention](#27-visual-representation-of-attention)
|
||||
8. [Key Takeaways](#28-key-takeaways)
|
||||
|
||||
---
|
||||
|
||||
## 2.1 The Problem Attention Solves
|
||||
|
||||
### The Challenge
|
||||
|
||||
**In a sentence, words depend on each other:**
|
||||
|
||||
```
|
||||
"He saw the cat with binoculars"
|
||||
```
|
||||
|
||||
Two possible meanings:
|
||||
1. He used binoculars to see the cat
|
||||
2. The cat has binoculars
|
||||
|
||||
**Context matters!** The model needs to understand which words relate to each other.
|
||||
|
||||
### The Solution: Attention
|
||||
|
||||
**Attention allows the model to "look" at other words when processing each word.**
|
||||
|
||||
---
|
||||
|
||||
## 2.2 What is Attention?
|
||||
|
||||
### Simple Definition
|
||||
|
||||
**Attention** is a mechanism that determines **how much each word should consider other words** when processing information.
|
||||
|
||||
### Intuitive Analogy
|
||||
|
||||
**Think of reading a sentence:**
|
||||
|
||||
When you read "cat" in:
|
||||
```
|
||||
"The cat sat on the mat"
|
||||
```
|
||||
|
||||
You might:
|
||||
- Pay attention to "sat" (what the cat did)
|
||||
- Pay attention to "mat" (where the cat is)
|
||||
- Pay less attention to "the" (just a word)
|
||||
|
||||
**Attention does the same thing mathematically!**
|
||||
|
||||
---
|
||||
|
||||
## 2.3 How Attention Works: Step-by-Step
|
||||
|
||||
### High-Level Overview
|
||||
|
||||
```
|
||||
Step 1: Create Query, Key, Value for each word
|
||||
Step 2: Compare queries and keys (find similarities)
|
||||
Step 3: Calculate attention weights (how much to attend)
|
||||
Step 4: Combine values weighted by attention
|
||||
```
|
||||
|
||||
### Detailed Step-by-Step
|
||||
|
||||
#### Step 1: Create Query, Key, Value (Q, K, V)
|
||||
|
||||
**For each word, create three representations:**
|
||||
|
||||
**Query (Q):** "What am I looking for?"
|
||||
**Key (K):** "What am I offering?"
|
||||
**Value (V):** "What information do I contain?"
|
||||
|
||||
**Example with "Hello World":**
|
||||
|
||||
```
|
||||
Word: "Hello"
|
||||
Query: [0.2, -0.1, 0.3, ...] ← What should I look for?
|
||||
Key: [0.1, 0.2, -0.1, ...] ← What do I represent?
|
||||
Value: [0.15, 0.1, 0.2, ...] ← What information do I have?
|
||||
|
||||
Word: "World"
|
||||
Query: [0.18, 0.15, 0.25, ...]
|
||||
Key: [0.12, 0.19, -0.08, ...]
|
||||
Value: [0.14, 0.12, 0.18, ...]
|
||||
```
|
||||
|
||||
**How Q, K, V are created:**
|
||||
```
|
||||
Q = Word × W_Q (learned matrix)
|
||||
K = Word × W_K (learned matrix)
|
||||
V = Word × W_V (learned matrix)
|
||||
```
|
||||
|
||||
#### Step 2: Compute Similarity Scores
|
||||
|
||||
**Compare each query with all keys:**
|
||||
|
||||
```
|
||||
Score[i, j] = How much should word i attend to word j?
|
||||
```
|
||||
|
||||
**Mathematical Formula:**
|
||||
```
|
||||
Score[i, j] = (Query[i] · Key[j]) / √d_k
|
||||
```
|
||||
|
||||
**Example:**
|
||||
|
||||
**Query for "Hello":** `[0.2, -0.1, 0.3]`
|
||||
**Key for "Hello":** `[0.1, 0.2, -0.1]`
|
||||
**Key for "World":** `[0.12, 0.19, -0.08]`
|
||||
|
||||
**Calculate similarity:**
|
||||
|
||||
```
|
||||
Score["Hello", "Hello"] = (0.2×0.1 + (-0.1)×0.2 + 0.3×(-0.1)) / √3
|
||||
= (0.02 - 0.02 - 0.03) / 1.732
|
||||
= -0.03 / 1.732
|
||||
≈ -0.017
|
||||
|
||||
Score["Hello", "World"] = (0.2×0.12 + (-0.1)×0.19 + 0.3×(-0.08)) / √3
|
||||
= (0.024 - 0.019 - 0.024) / 1.732
|
||||
= -0.019 / 1.732
|
||||
≈ -0.011
|
||||
```
|
||||
|
||||
**Result:** Similarity scores tell us how related words are
|
||||
|
||||
#### Step 3: Convert Scores to Attention Weights
|
||||
|
||||
**Use softmax to convert scores to probabilities:**
|
||||
|
||||
```
|
||||
Attention[i, j] = exp(Score[i, j]) / Σ exp(Score[i, k])
|
||||
```
|
||||
|
||||
**Example:**
|
||||
|
||||
**Raw Scores:**
|
||||
```
|
||||
Score["Hello", "Hello"] = -0.017
|
||||
Score["Hello", "World"] = -0.011
|
||||
```
|
||||
|
||||
**Compute exponentials:**
|
||||
```
|
||||
exp(-0.017) ≈ 0.983
|
||||
exp(-0.011) ≈ 0.989
|
||||
Sum = 0.983 + 0.989 = 1.972
|
||||
```
|
||||
|
||||
**Compute attention weights:**
|
||||
```
|
||||
Attention["Hello", "Hello"] = 0.983 / 1.972 ≈ 0.499 (49.9%)
|
||||
Attention["Hello", "World"] = 0.989 / 1.972 ≈ 0.501 (50.1%)
|
||||
```
|
||||
|
||||
**Meaning:** "Hello" attends 49.9% to itself and 50.1% to "World"
|
||||
|
||||
#### Step 4: Weighted Combination
|
||||
|
||||
**Combine values using attention weights:**
|
||||
|
||||
```
|
||||
Output["Hello"] = Attention["Hello", "Hello"] × Value["Hello"]
|
||||
+ Attention["Hello", "World"] × Value["World"]
|
||||
```
|
||||
|
||||
**Example:**
|
||||
|
||||
```
|
||||
Value["Hello"] = [0.15, 0.1, 0.2]
|
||||
Value["World"] = [0.14, 0.12, 0.18]
|
||||
|
||||
Output["Hello"] = 0.499 × [0.15, 0.1, 0.2] + 0.501 × [0.14, 0.12, 0.18]
|
||||
= [0.075, 0.050, 0.100] + [0.070, 0.060, 0.090]
|
||||
= [0.145, 0.110, 0.190]
|
||||
```
|
||||
|
||||
**Result:** New representation that combines information from both words!
|
||||
|
||||
---
|
||||
|
||||
## 2.4 Complete Example: Attention in "Hello World"
|
||||
|
||||
### Input
|
||||
|
||||
```
|
||||
Words: ["Hello", "World"]
|
||||
Position 0: "Hello"
|
||||
Position 1: "World"
|
||||
```
|
||||
|
||||
### Step-by-Step Processing
|
||||
|
||||
#### Step 1: Embeddings
|
||||
|
||||
```
|
||||
E["Hello"] = [0.10, -0.20, 0.30, ..., 0.05]
|
||||
E["World"] = [0.15, -0.18, 0.28, ..., 0.10]
|
||||
```
|
||||
|
||||
#### Step 2: Create Q, K, V
|
||||
|
||||
```
|
||||
Q["Hello"] = E["Hello"] × W_Q = [0.2, -0.1, 0.3, ...]
|
||||
K["Hello"] = E["Hello"] × W_K = [0.1, 0.2, -0.1, ...]
|
||||
V["Hello"] = E["Hello"] × W_V = [0.15, 0.1, 0.2, ...]
|
||||
|
||||
Q["World"] = E["World"] × W_Q = [0.18, 0.15, 0.25, ...]
|
||||
K["World"] = E["World"] × W_K = [0.12, 0.19, -0.08, ...]
|
||||
V["World"] = E["World"] × W_V = [0.14, 0.12, 0.18, ...]
|
||||
```
|
||||
|
||||
#### Step 3: Compute Attention Scores
|
||||
|
||||
```
|
||||
Score Matrix (2×2):
|
||||
|
||||
"Hello" "World"
|
||||
"Hello" 0.5 0.3
|
||||
"World" 0.4 0.6
|
||||
```
|
||||
|
||||
**Interpretation:**
|
||||
- "Hello" attends to itself (0.5) more than "World" (0.3)
|
||||
- "World" attends to itself (0.6) more than "Hello" (0.4)
|
||||
|
||||
#### Step 4: Apply Softmax
|
||||
|
||||
```
|
||||
Attention Matrix:
|
||||
|
||||
"Hello" "World"
|
||||
"Hello" 0.62 0.38
|
||||
"World" 0.40 0.60
|
||||
```
|
||||
|
||||
**Interpretation:**
|
||||
- "Hello" gives 62% attention to itself, 38% to "World"
|
||||
- "World" gives 40% attention to "Hello", 60% to itself
|
||||
|
||||
#### Step 5: Weighted Combination
|
||||
|
||||
```
|
||||
Output["Hello"] = 0.62 × V["Hello"] + 0.38 × V["World"]
|
||||
= 0.62 × [0.15, 0.1, 0.2] + 0.38 × [0.14, 0.12, 0.18]
|
||||
= [0.093, 0.062, 0.124] + [0.053, 0.046, 0.068]
|
||||
= [0.146, 0.108, 0.192]
|
||||
|
||||
Output["World"] = 0.40 × V["Hello"] + 0.60 × V["World"]
|
||||
= 0.40 × [0.15, 0.1, 0.2] + 0.60 × [0.14, 0.12, 0.18]
|
||||
= [0.060, 0.040, 0.080] + [0.084, 0.072, 0.108]
|
||||
= [0.144, 0.112, 0.188]
|
||||
```
|
||||
|
||||
**Result:** Each word now contains information from both words!
|
||||
|
||||
---
|
||||
|
||||
## 2.5 Why Attention Matters
|
||||
|
||||
### Benefit 1: Context Understanding
|
||||
|
||||
**Without Attention:**
|
||||
```
|
||||
"Hello" is processed in isolation
|
||||
"World" is processed in isolation
|
||||
Result: No understanding of relationship
|
||||
```
|
||||
|
||||
**With Attention:**
|
||||
```
|
||||
"Hello" considers "World" (38% attention)
|
||||
"World" considers "Hello" (40% attention)
|
||||
Result: Understands they're related
|
||||
```
|
||||
|
||||
### Benefit 2: Long-Range Dependencies
|
||||
|
||||
**Attention can connect distant words:**
|
||||
|
||||
```
|
||||
"The cat that I saw yesterday sat on the mat"
|
||||
```
|
||||
|
||||
- "cat" can attend to "yesterday" (even though far apart)
|
||||
- Model understands the cat from yesterday
|
||||
|
||||
### Benefit 3: Selective Focus
|
||||
|
||||
**Attention focuses on relevant information:**
|
||||
|
||||
```
|
||||
"He saw the cat with binoculars"
|
||||
```
|
||||
|
||||
- "saw" attends strongly to "binoculars" (how he saw)
|
||||
- "cat" attends strongly to "sat" (what it did)
|
||||
- Each word focuses on what's relevant to it
|
||||
|
||||
---
|
||||
|
||||
## 2.6 Multi-Head Attention
|
||||
|
||||
### What is Multi-Head Attention?
|
||||
|
||||
**Multiple attention "heads" look at different aspects:**
|
||||
|
||||
```
|
||||
Head 1: Focuses on syntax (grammar relationships)
|
||||
Head 2: Focuses on semantics (meaning relationships)
|
||||
Head 3: Focuses on position (spatial relationships)
|
||||
...
|
||||
Head 8: Focuses on another aspect
|
||||
```
|
||||
|
||||
### Visual Representation
|
||||
|
||||
```
|
||||
Input: "Hello World"
|
||||
|
||||
Head 1 (Syntax):
|
||||
"Hello" → attends to "World" (subject-object relationship)
|
||||
|
||||
Head 2 (Semantics):
|
||||
"Hello" → attends to "World" (greeting relationship)
|
||||
|
||||
Head 3 (Position):
|
||||
"Hello" → attends more to itself (being first)
|
||||
|
||||
... (other heads)
|
||||
|
||||
Final: Combine all heads → Richer representation
|
||||
```
|
||||
|
||||
### Why Multiple Heads?
|
||||
|
||||
**Different heads capture different relationships:**
|
||||
|
||||
- **Head 1:** Grammatical relationships
|
||||
- **Head 2:** Semantic relationships
|
||||
- **Head 3:** Positional relationships
|
||||
- **Head 4:** Other patterns...
|
||||
|
||||
**Together:** Comprehensive understanding!
|
||||
|
||||
---
|
||||
|
||||
## 2.7 Visual Representation of Attention
|
||||
|
||||
### Attention Heatmap
|
||||
|
||||
```
|
||||
Attention Weights for "Hello World"
|
||||
|
||||
Position 0 Position 1
|
||||
("Hello") ("World")
|
||||
┌─────────┐ ┌─────────┐
|
||||
Position 0 │ 0.62 │ │ 0.38 │
|
||||
("Hello") └─────────┘ └─────────┘
|
||||
┌─────────┐ ┌─────────┐
|
||||
Position 1 │ 0.40 │ │ 0.60 │
|
||||
("World") └─────────┘ └─────────┘
|
||||
```
|
||||
|
||||
**Reading:**
|
||||
- Row 0: "Hello" attends 62% to itself, 38% to "World"
|
||||
- Row 1: "World" attends 40% to "Hello", 60% to itself
|
||||
|
||||
### Attention Flow Diagram
|
||||
|
||||
```
|
||||
"Hello" ──── 0.38 ────→ "World"
|
||||
↑ ↑
|
||||
│ │
|
||||
0.62 0.60
|
||||
│ │
|
||||
└──────────────────────┘
|
||||
(self-attention)
|
||||
```
|
||||
|
||||
**Meaning:** Information flows between words based on attention weights.
|
||||
|
||||
---
|
||||
|
||||
## 2.8 Key Takeaways: Attention
|
||||
|
||||
✅ **Attention determines which words to focus on**
|
||||
✅ **Calculates similarity between words**
|
||||
✅ **Creates weighted combinations of information**
|
||||
✅ **Enables understanding of relationships**
|
||||
✅ **Multiple heads capture different aspects**
|
||||
|
||||
---
|
||||
|
||||
*This document provides a step-by-step explanation of attention mechanisms, the core component that enables transformers to understand relationships between words.*
|
||||
|
||||
757
docs/BENCHMARKING_GUIDE.md
Normal file
757
docs/BENCHMARKING_GUIDE.md
Normal file
@@ -0,0 +1,757 @@
|
||||
# Inference Benchmarking Guide
|
||||
|
||||
This guide explains how to use the benchmarking feature to compare optimized vs non-optimized inference performance for research purposes.
|
||||
|
||||
## Overview
|
||||
|
||||
The benchmarking feature runs inference both with and without optimizations (KV caching, optimized attention) and generates:
|
||||
|
||||
- **Performance metrics** (tokens/sec, latency, memory usage)
|
||||
- **Comparison plots** (visual charts showing improvements)
|
||||
- **CSV export** (data for further analysis)
|
||||
|
||||
## Data Storage Location
|
||||
|
||||
**All benchmark data is saved to:** `./inference_benchmarks/` (default)
|
||||
|
||||
**You can customize the location:**
|
||||
|
||||
```bash
|
||||
python inference.py --benchmark --benchmark-dir ./research/results
|
||||
```
|
||||
|
||||
**Data files created:**
|
||||
|
||||
- `inference_metrics.json` - All raw metrics (JSON format)
|
||||
- `inference_metrics.csv` - Spreadsheet-friendly data (CSV format)
|
||||
- `optimization_comparison.png` - Visual comparison charts
|
||||
- `performance_over_time.png` - Trend analysis over multiple runs
|
||||
|
||||
**Note:** All runs accumulate in the same files, so you can run multiple benchmarks and build trends over time.
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Basic Benchmark
|
||||
|
||||
```bash
|
||||
python inference.py \
|
||||
--checkpoint checkpoints/best_checkpoint.pt \
|
||||
--prompt "The future of artificial intelligence" \
|
||||
--max-length 100 \
|
||||
--benchmark
|
||||
```
|
||||
|
||||
This will:
|
||||
|
||||
1. Run inference **without** optimizations
|
||||
2. Run inference **with** optimizations (KV cache)
|
||||
3. Collect metrics for both runs
|
||||
4. Generate comparison plots
|
||||
5. Save all data to `./inference_benchmarks/`
|
||||
|
||||
### Custom Benchmark Directory
|
||||
|
||||
```bash
|
||||
python inference.py \
|
||||
--checkpoint checkpoints/best_checkpoint.pt \
|
||||
--prompt "Your prompt here" \
|
||||
--max-length 100 \
|
||||
--benchmark \
|
||||
--benchmark-dir ./research/results
|
||||
```
|
||||
|
||||
### Running Multiple Prompts for Trends
|
||||
|
||||
**Use the batch benchmark script** to run multiple prompts and create trends:
|
||||
|
||||
```bash
|
||||
# Create a prompts file
|
||||
cat > prompts.txt << EOF
|
||||
The future of artificial intelligence
|
||||
Machine learning is transforming
|
||||
Deep neural networks enable
|
||||
Natural language processing requires
|
||||
EOF
|
||||
|
||||
# Run batch benchmarks
|
||||
python benchmark_batch.py \
|
||||
--checkpoint checkpoints/best_checkpoint.pt \
|
||||
--prompt-file prompts.txt \
|
||||
--max-length 100 \
|
||||
--benchmark-dir ./research/results
|
||||
```
|
||||
|
||||
**Or use command-line prompts:**
|
||||
|
||||
```bash
|
||||
python benchmark_batch.py \
|
||||
--checkpoint checkpoints/best_checkpoint.pt \
|
||||
--prompts "Prompt 1" "Prompt 2" "Prompt 3" \
|
||||
--max-length 100
|
||||
```
|
||||
|
||||
**Results accumulate** in the same files, allowing you to:
|
||||
|
||||
- Build trends across multiple prompts
|
||||
- Analyze performance consistency
|
||||
- Create comprehensive research reports
|
||||
|
||||
## Output Files
|
||||
|
||||
After running a benchmark, you'll get:
|
||||
|
||||
### 1. JSON Metrics File
|
||||
|
||||
**Location:** `inference_benchmarks/inference_metrics.json`
|
||||
|
||||
Contains all raw metrics data:
|
||||
|
||||
```json
|
||||
{
|
||||
"runs": [
|
||||
{
|
||||
"run_name": "run_1234567890_optimized",
|
||||
"optimized": true,
|
||||
"tokens_per_second": 150.5,
|
||||
"time_per_token": 6.64,
|
||||
"memory_used_mb": 245.3,
|
||||
...
|
||||
},
|
||||
...
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### 2. CSV Export
|
||||
|
||||
**Location:** `inference_benchmarks/inference_metrics.csv`
|
||||
|
||||
For spreadsheet analysis:
|
||||
|
||||
```csv
|
||||
run_name,timestamp,optimized,prompt_length,generated_length,total_time,tokens_per_second,time_per_token,memory_used_mb,device
|
||||
run_1234567890_optimized,1234567890.5,true,20,100,0.663,150.8,6.63,245.3,cuda
|
||||
...
|
||||
```
|
||||
|
||||
### 3. Comparison Plot
|
||||
|
||||
**Location:** `inference_benchmarks/optimization_comparison.png`
|
||||
|
||||
Shows 4 charts:
|
||||
|
||||
- **Tokens per Second** (speed comparison)
|
||||
- **Time per Token** (latency comparison)
|
||||
- **Total Generation Time** (overall speed)
|
||||
- **Memory Usage** (memory efficiency)
|
||||
|
||||
### 4. Performance Over Time Plot
|
||||
|
||||
**Location:** `inference_benchmarks/performance_over_time.png`
|
||||
|
||||
Shows how performance varies across multiple benchmark runs.
|
||||
|
||||
## Metrics Collected
|
||||
|
||||
### Performance Metrics
|
||||
|
||||
- **Tokens per Second**: Generation speed
|
||||
- **Time per Token**: Latency per token (milliseconds)
|
||||
- **Total Time**: Complete generation time
|
||||
|
||||
### Resource Metrics
|
||||
|
||||
- **Memory Usage**: GPU memory consumption (MB)
|
||||
- **Device**: Device used (cuda/cpu/mps)
|
||||
|
||||
### Derived Metrics
|
||||
|
||||
- **Speedup**: Ratio of optimized vs non-optimized speed
|
||||
- **Memory Reduction**: Percentage reduction in memory usage
|
||||
|
||||
## Example Output
|
||||
|
||||
```
|
||||
🔬 BENCHMARK MODE: Comparing optimized vs non-optimized inference
|
||||
======================================================================
|
||||
|
||||
BENCHMARK RUN: run_1234567890
|
||||
======================================================================
|
||||
|
||||
🔴 Running NON-OPTIMIZED inference...
|
||||
⏱️ Total Time: 1.234 s
|
||||
📊 Tokens/Second: 81.0
|
||||
⚡ Time/Token: 12.35 ms
|
||||
💾 Memory Used: 512.3 MB
|
||||
📝 Generated: The future of artificial intelligence is bright...
|
||||
|
||||
🟢 Running OPTIMIZED inference...
|
||||
⏱️ Total Time: 0.663 s
|
||||
📊 Tokens/Second: 150.8
|
||||
⚡ Time/Token: 6.63 ms
|
||||
💾 Memory Used: 245.3 MB
|
||||
📝 Generated: The future of artificial intelligence is bright...
|
||||
|
||||
🚀 SPEEDUP: 1.86x faster with optimizations
|
||||
💾 MEMORY REDUCTION: 52.1%
|
||||
|
||||
📊 Generating comparison plots and data...
|
||||
📊 Comparison plot saved to: ./inference_benchmarks/optimization_comparison.png
|
||||
📊 Performance over time plot saved to: ./inference_benchmarks/performance_over_time.png
|
||||
📊 Metrics exported to CSV: ./inference_benchmarks/inference_metrics.csv
|
||||
|
||||
✅ Benchmark complete! Results saved to: ./inference_benchmarks
|
||||
```
|
||||
|
||||
## Running Multiple Benchmarks for Trends
|
||||
|
||||
### Method 1: Individual Runs (Manual)
|
||||
|
||||
```bash
|
||||
# Run 1
|
||||
python inference.py --checkpoint checkpoints/best.pt --prompt "Prompt 1" --benchmark
|
||||
|
||||
# Run 2
|
||||
python inference.py --checkpoint checkpoints/best.pt --prompt "Prompt 2" --benchmark
|
||||
|
||||
# Run 3
|
||||
python inference.py --checkpoint checkpoints/best.pt --prompt "Prompt 3" --max-length 200 --benchmark
|
||||
```
|
||||
|
||||
All runs accumulate in the same files:
|
||||
|
||||
- `inference_metrics.json` - All runs appended
|
||||
- `inference_metrics.csv` - All runs in CSV format
|
||||
- Plots update automatically with new data
|
||||
|
||||
### Method 2: Batch Script (Recommended)
|
||||
|
||||
**Create a prompts file:**
|
||||
|
||||
```bash
|
||||
cat > research_prompts.txt << EOF
|
||||
The future of artificial intelligence is bright.
|
||||
Machine learning models are becoming more efficient.
|
||||
Deep neural networks can process complex patterns.
|
||||
Natural language processing enables human-computer interaction.
|
||||
Transformer architectures revolutionized NLP.
|
||||
EOF
|
||||
```
|
||||
|
||||
**Run batch benchmarks:**
|
||||
|
||||
```bash
|
||||
python benchmark_batch.py \
|
||||
--checkpoint checkpoints/best_checkpoint.pt \
|
||||
--prompt-file research_prompts.txt \
|
||||
--max-length 100 \
|
||||
--benchmark-dir ./research/results \
|
||||
--delay 2.0
|
||||
```
|
||||
|
||||
**Benefits:**
|
||||
|
||||
- ✅ Runs all prompts automatically
|
||||
- ✅ Accumulates data for trend analysis
|
||||
- ✅ Creates comprehensive performance reports
|
||||
- ✅ Handles errors gracefully
|
||||
|
||||
**After running multiple benchmarks:**
|
||||
|
||||
- Check `performance_over_time.png` for trends
|
||||
- Analyze `inference_metrics.csv` in Excel/Python
|
||||
- Review aggregated statistics in console output
|
||||
|
||||
## Research Use Cases
|
||||
|
||||
### 1. Performance Analysis
|
||||
|
||||
Compare how optimizations affect inference speed:
|
||||
|
||||
```bash
|
||||
python inference.py \
|
||||
--checkpoint checkpoints/best.pt \
|
||||
--prompt "Your research prompt" \
|
||||
--benchmark
|
||||
```
|
||||
|
||||
### 2. Memory Efficiency Study
|
||||
|
||||
Analyze memory usage improvements:
|
||||
|
||||
```bash
|
||||
# Check memory reduction
|
||||
python inference.py --checkpoint checkpoints/best.pt --prompt "Long prompt" --max-length 500 --benchmark
|
||||
```
|
||||
|
||||
### 3. Scalability Testing
|
||||
|
||||
Test with different generation lengths:
|
||||
|
||||
```bash
|
||||
# Short sequences
|
||||
python inference.py --checkpoint checkpoints/best.pt --prompt "Test" --max-length 50 --benchmark
|
||||
|
||||
# Medium sequences
|
||||
python inference.py --checkpoint checkpoints/best.pt --prompt "Test" --max-length 200 --benchmark
|
||||
|
||||
# Long sequences
|
||||
python inference.py --checkpoint checkpoints/best.pt --prompt "Test" --max-length 1000 --benchmark
|
||||
```
|
||||
|
||||
## Plot Interpretation
|
||||
|
||||
### Comparison Plot (`optimization_comparison.png`)
|
||||
|
||||
**Top Left - Tokens per Second:**
|
||||
|
||||
- Higher is better
|
||||
- Shows generation speed
|
||||
- Speedup annotation shows improvement factor
|
||||
|
||||
**Top Right - Time per Token:**
|
||||
|
||||
- Lower is better
|
||||
- Shows latency per token
|
||||
- Important for real-time applications
|
||||
|
||||
**Bottom Left - Total Generation Time:**
|
||||
|
||||
- Lower is better
|
||||
- Overall generation time
|
||||
- Most user-visible metric
|
||||
|
||||
**Bottom Right - Memory Usage:**
|
||||
|
||||
- Lower is better
|
||||
- GPU memory consumption
|
||||
- Memory reduction annotation shows savings
|
||||
|
||||
### Performance Over Time Plot (`performance_over_time.png`)
|
||||
|
||||
Shows performance trends across multiple benchmark runs:
|
||||
|
||||
- **Green line**: Optimized performance
|
||||
- **Red line**: Non-optimized performance
|
||||
- Useful for finding performance regressions or improvements
|
||||
|
||||
## Reporting Results
|
||||
|
||||
### Speedup Calculation
|
||||
|
||||
```
|
||||
Speedup = Optimized Tokens/Second / Non-Optimized Tokens/Second
|
||||
```
|
||||
|
||||
**Example:**
|
||||
|
||||
- Optimized: 150 tokens/sec
|
||||
- Non-Optimized: 81 tokens/sec
|
||||
- Speedup: 150/81 = 1.85x faster
|
||||
|
||||
### Memory Reduction Calculation
|
||||
|
||||
```
|
||||
Memory Reduction % = (1 - Optimized Memory / Non-Optimized Memory) × 100
|
||||
```
|
||||
|
||||
**Example:**
|
||||
|
||||
- Optimized: 245 MB
|
||||
- Non-Optimized: 512 MB
|
||||
- Reduction: (1 - 245/512) × 100 = 52.1%
|
||||
|
||||
## Tips for Best Results
|
||||
|
||||
1. **Warm Up GPU**: Run a few inference calls before benchmarking to warm up the GPU
|
||||
2. **Clear Cache**: The benchmark automatically clears CUDA cache between runs
|
||||
3. **Multiple Runs**: Run multiple benchmarks for statistical significance
|
||||
4. **Consistent Prompts**: Use the same prompt for fair comparison
|
||||
5. **Device Consistency**: Use the same device for all runs
|
||||
|
||||
## Command Line Options
|
||||
|
||||
```bash
|
||||
python inference.py \
|
||||
--checkpoint PATH # Path to model checkpoint (required)
|
||||
--prompt TEXT # Prompt text (required)
|
||||
--max-length INT # Maximum generation length (default: 100)
|
||||
--temperature FLOAT # Sampling temperature (default: 1.0)
|
||||
--top-k INT # Top-k sampling (default: 50)
|
||||
--top-p FLOAT # Top-p sampling (default: 0.95)
|
||||
--device DEVICE # Device: cuda/cpu/mps (default: cuda)
|
||||
--benchmark # Enable benchmarking mode
|
||||
--benchmark-dir DIR # Benchmark output directory (default: ./inference_benchmarks)
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### No GPU Memory Stats
|
||||
|
||||
If memory stats show as `None`:
|
||||
|
||||
- CUDA: Memory tracking should work automatically
|
||||
- MPS (Apple Silicon): Memory tracking not available
|
||||
- CPU: Memory tracking not available
|
||||
|
||||
### Plots Not Generated
|
||||
|
||||
If plots fail to generate:
|
||||
|
||||
- Ensure `matplotlib` is installed: `pip install matplotlib`
|
||||
- Check file permissions for output directory
|
||||
|
||||
### Inconsistent Results
|
||||
|
||||
For consistent results:
|
||||
|
||||
- Use same device for all runs
|
||||
- Use same prompt length
|
||||
- Allow GPU to warm up
|
||||
- Close other GPU applications
|
||||
|
||||
## Example Research Workflow
|
||||
|
||||
```bash
|
||||
# 1. Run initial benchmark
|
||||
python inference.py --checkpoint checkpoints/best.pt --prompt "Test prompt" --benchmark
|
||||
|
||||
# 2. Review results
|
||||
ls inference_benchmarks/
|
||||
cat inference_benchmarks/inference_metrics.json
|
||||
|
||||
# 3. Generate plots (already done automatically)
|
||||
# View: inference_benchmarks/optimization_comparison.png
|
||||
|
||||
# 4. Analyze CSV data
|
||||
# Open: inference_benchmarks/inference_metrics.csv in Excel/Python
|
||||
|
||||
# 5. Run additional benchmarks
|
||||
python inference.py --checkpoint checkpoints/best.pt --prompt "Different prompt" --max-length 200 --benchmark
|
||||
|
||||
# 6. Compare results
|
||||
python inference.py --checkpoint checkpoints/best.pt --prompt "Same prompt" --benchmark
|
||||
```
|
||||
|
||||
## Optimization Architecture & Code Injection
|
||||
|
||||
### Overview: Optimization Layers
|
||||
|
||||
The optimizations are implemented as layers that wrap the standard inference pipeline:
|
||||
|
||||
```mermaid
|
||||
flowchart TB
|
||||
subgraph subGraph0["Standard Inference (Non-Optimized)"]
|
||||
B["Tokenize"]
|
||||
A["Input Prompt"]
|
||||
C["Embedding Layer"]
|
||||
D["Transformer Blocks"]
|
||||
E["Attention: Recompute All"]
|
||||
F["Forward Pass: O(n²)"]
|
||||
G["Output Tokens"]
|
||||
H["Detokenize"]
|
||||
I["Generated Text"]
|
||||
end
|
||||
subgraph subGraph1["Optimized Inference (With KV Cache)"]
|
||||
B2["Tokenize"]
|
||||
A2["Input Prompt"]
|
||||
C2["Embedding Layer"]
|
||||
D2["Transformer Blocks"]
|
||||
E2["Optimized Attention"]
|
||||
F2["KV Cache Layer"]
|
||||
G2["Forward Pass: O(n)"]
|
||||
H2["Output Tokens"]
|
||||
I2["Detokenize"]
|
||||
J2["Generated Text"]
|
||||
end
|
||||
A --> B
|
||||
B --> C
|
||||
C --> D
|
||||
D --> E
|
||||
E --> F
|
||||
F --> G
|
||||
G --> H
|
||||
H --> I
|
||||
A2 --> B2
|
||||
B2 --> C2
|
||||
C2 --> D2
|
||||
D2 --> E2
|
||||
E2 --> F2
|
||||
F2 --> G2
|
||||
G2 --> H2
|
||||
H2 --> I2
|
||||
I2 --> J2
|
||||
style E fill:#ffcccc
|
||||
style F fill:#ffcccc
|
||||
style E2 fill:#ccffcc
|
||||
style F2 fill:#ccffcc
|
||||
```
|
||||
|
||||
### Detailed Optimization Flow
|
||||
|
||||
```mermaid
|
||||
|
||||
flowchart LR
|
||||
subgraph subGraph0["Request Flow"]
|
||||
Mode{"Optimized?"}
|
||||
Start["Benchmark Request"]
|
||||
Standard["Standard Path"]
|
||||
Optimized["Optimized Path"]
|
||||
end
|
||||
subgraph subGraph1["Standard Path"]
|
||||
S1["Model.generate"]
|
||||
S2["Transformer Forward"]
|
||||
S3["MultiHeadAttention"]
|
||||
S4["Compute Q, K, V"]
|
||||
S5["Recompute All KVs"]
|
||||
S6["Attention Scores: O(n²)"]
|
||||
S7["Generate Token"]
|
||||
end
|
||||
subgraph subGraph2["Optimized Path"]
|
||||
O1["OptimizedInference"]
|
||||
O2["Init KV Cache"]
|
||||
O3["Transformer Forward"]
|
||||
O4["OptimizedMultiHeadAttention"]
|
||||
O5["Compute Q, K, V"]
|
||||
O6["KV Cache Layer"]
|
||||
O7["Append to Cache"]
|
||||
O8["Reuse Cached KVs"]
|
||||
O9["Attention Scores: O(n)"]
|
||||
O10["Generate Token"]
|
||||
end
|
||||
Start --> Mode
|
||||
Mode -- No --> Standard
|
||||
Mode -- Yes --> Optimized
|
||||
Standard --> S1
|
||||
S1 --> S2
|
||||
S2 --> S3
|
||||
S3 --> S4
|
||||
S4 --> S5
|
||||
S5 --> S6
|
||||
S6 --> S7
|
||||
Optimized --> O1
|
||||
O1 --> O2
|
||||
O2 --> O3
|
||||
O3 --> O4
|
||||
O4 --> O5
|
||||
O5 --> O6
|
||||
O6 --> O7
|
||||
O7 --> O8
|
||||
O8 --> O9
|
||||
O9 --> O10
|
||||
S7 --> Metrics["Collect Metrics"]
|
||||
O10 --> Metrics
|
||||
style Standard fill:#ffcccc
|
||||
style Optimized fill:#ccffcc
|
||||
style S5 fill:#ffcccc
|
||||
style O8 fill:#ccffcc
|
||||
|
||||
```
|
||||
|
||||
### Code Injection Points
|
||||
|
||||
```mermaid
|
||||
graph TB
|
||||
subgraph "Standard Model Architecture"
|
||||
A[TransformerModel] --> B[TransformerBlock]
|
||||
B --> C[MultiHeadAttention]
|
||||
C --> D[Q, K, V Projections]
|
||||
D --> E[Attention Computation]
|
||||
E --> F[Output Projection]
|
||||
F --> G[Feed Forward]
|
||||
end
|
||||
|
||||
subgraph "Optimization Injection Points"
|
||||
H[OptimizedInference Wrapper] --> A
|
||||
A --> B2[TransformerBlock]
|
||||
B2 --> C2[OptimizedMultiHeadAttention]
|
||||
C2 --> D2[Q, K, V Projections]
|
||||
D2 --> I[KV Cache Injection]
|
||||
I --> E2[Optimized Attention]
|
||||
E2 --> F2[Output Projection]
|
||||
F2 --> G2[Feed Forward]
|
||||
end
|
||||
|
||||
subgraph "KV Cache Layer Details"
|
||||
I --> J[Cache Check]
|
||||
J --> K{Cache Exists?}
|
||||
K -->|No| L[Compute K, V]
|
||||
K -->|Yes| M[Retrieve from Cache]
|
||||
L --> N[Store in Cache]
|
||||
M --> O[Append New K, V]
|
||||
N --> O
|
||||
O --> P[Use Cached KVs]
|
||||
end
|
||||
|
||||
style H fill:#90EE90
|
||||
style I fill:#90EE90
|
||||
style K fill:#FFD700
|
||||
style P fill:#90EE90
|
||||
```
|
||||
|
||||
### Benchmark Execution Flow
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant User
|
||||
participant InferenceScript
|
||||
participant BenchmarkModule
|
||||
participant OptimizedInference
|
||||
participant StandardModel
|
||||
participant MetricsCollector
|
||||
|
||||
User->>InferenceScript: python inference.py --benchmark
|
||||
InferenceScript->>BenchmarkModule: Initialize Metrics
|
||||
BenchmarkModule->>MetricsCollector: Create InferenceMetrics
|
||||
|
||||
Note over InferenceScript: Run 1: Non-Optimized
|
||||
InferenceScript->>StandardModel: model.generate()
|
||||
StandardModel->>StandardModel: Forward Pass (O(n²))
|
||||
StandardModel-->>InferenceScript: Generated Tokens
|
||||
InferenceScript->>MetricsCollector: Log Run (optimized=false)
|
||||
|
||||
Note over InferenceScript: Run 2: Optimized
|
||||
InferenceScript->>OptimizedInference: get_optimized_inference()
|
||||
OptimizedInference->>OptimizedInference: Init KV Cache
|
||||
OptimizedInference->>OptimizedInference: generate_with_cache()
|
||||
|
||||
loop For each token
|
||||
OptimizedInference->>OptimizedInference: Forward Pass (O(n))
|
||||
OptimizedInference->>OptimizedInference: Update KV Cache
|
||||
end
|
||||
|
||||
OptimizedInference-->>InferenceScript: Generated Tokens
|
||||
InferenceScript->>MetricsCollector: Log Run (optimized=true)
|
||||
|
||||
MetricsCollector->>MetricsCollector: Calculate Speedup
|
||||
MetricsCollector->>MetricsCollector: Generate Plots
|
||||
MetricsCollector->>MetricsCollector: Export CSV
|
||||
MetricsCollector-->>User: Results & Plots
|
||||
```
|
||||
|
||||
### Optimization Components Stack
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
subgraph "Application Layer"
|
||||
A[inference.py] --> B[benchmark_inference]
|
||||
B --> C[Generate Text]
|
||||
end
|
||||
|
||||
subgraph "Optimization Layer"
|
||||
C --> D{Optimized?}
|
||||
D -->|Yes| E[OptimizedInference]
|
||||
D -->|No| F[Standard Model]
|
||||
E --> G[KV Cache Manager]
|
||||
E --> H[Optimized Attention]
|
||||
end
|
||||
|
||||
subgraph "Core Model Layer"
|
||||
F --> I[TransformerModel]
|
||||
E --> I
|
||||
I --> J[TransformerBlock]
|
||||
J --> K[MultiHeadAttention]
|
||||
H --> K
|
||||
K --> L[Attention Computation]
|
||||
end
|
||||
|
||||
subgraph "Cache Layer"
|
||||
G --> M[KVCache Data Structure]
|
||||
M --> N[Keys Cache]
|
||||
M --> O[Values Cache]
|
||||
N --> P[Retrieve Previous K]
|
||||
O --> Q[Retrieve Previous V]
|
||||
end
|
||||
|
||||
subgraph "Compute Layer"
|
||||
L --> R[Q × K^T]
|
||||
P --> R
|
||||
Q --> R
|
||||
R --> S[Softmax]
|
||||
S --> T[Attention Weights]
|
||||
T --> U[Output]
|
||||
end
|
||||
|
||||
style E fill:#90EE90
|
||||
style G fill:#90EE90
|
||||
style H fill:#90EE90
|
||||
style M fill:#FFD700
|
||||
```
|
||||
|
||||
### Performance Comparison Schema
|
||||
|
||||
```mermaid
|
||||
|
||||
flowchart LR
|
||||
subgraph subGraph0["Metrics Collection"]
|
||||
B["Non-Optimized Metrics"]
|
||||
A["Benchmark Run"]
|
||||
C["Optimized Metrics"]
|
||||
D["Time: T1<br>Memory: M1<br>Speed: S1"]
|
||||
E["Time: T2<br>Memory: M2<br>Speed: S2"]
|
||||
end
|
||||
subgraph Analysis["Analysis"]
|
||||
F["Calculate Speedup"]
|
||||
G["Speedup = S2/S1"]
|
||||
H["Calculate Memory Reduction"]
|
||||
I["Reduction = (M1-M2)/M1 × 100%"]
|
||||
end
|
||||
subgraph Visualization["Visualization"]
|
||||
J["Comparison Plot"]
|
||||
K["Trend Analysis"]
|
||||
L["Performance Over Time"]
|
||||
end
|
||||
subgraph subGraph3["Data Export"]
|
||||
M["JSON Metrics"]
|
||||
N["CSV Export"]
|
||||
end
|
||||
A --> B & C
|
||||
B --> D
|
||||
C --> E
|
||||
D --> F & H
|
||||
E --> F & H
|
||||
F --> G & K
|
||||
H --> I
|
||||
G --> J
|
||||
I --> J
|
||||
K --> L
|
||||
J --> M & N
|
||||
L --> M & N
|
||||
style F fill:#FFD700
|
||||
style G fill:#90EE90
|
||||
style I fill:#90EE90
|
||||
|
||||
```
|
||||
|
||||
## Data File Locations Summary
|
||||
|
||||
**All benchmark data is saved to:**
|
||||
|
||||
```
|
||||
./inference_benchmarks/
|
||||
├── inference_metrics.json # All raw metrics (JSON)
|
||||
├── inference_metrics.csv # Spreadsheet data (CSV)
|
||||
├── optimization_comparison.png # Comparison charts
|
||||
└── performance_over_time.png # Trend analysis
|
||||
```
|
||||
|
||||
**Custom location:**
|
||||
|
||||
```bash
|
||||
--benchmark-dir ./research/results
|
||||
```
|
||||
|
||||
**Data accumulates:** Each benchmark run appends to the same files, building trends over time.
|
||||
|
||||
## Next Steps
|
||||
|
||||
1. ✅ Run your first benchmark
|
||||
2. ✅ Review the comparison plots
|
||||
3. ✅ Analyze CSV data for deeper insights
|
||||
4. ✅ Run multiple benchmarks for statistical analysis
|
||||
5. ✅ Use batch script for trend analysis
|
||||
6. ✅ Include results in your research paper/presentation
|
||||
|
||||
---
|
||||
|
||||
**Happy Benchmarking!** 📊🔬
|
||||
854
docs/COMPLETE_GUIDE.md
Normal file
854
docs/COMPLETE_GUIDE.md
Normal file
@@ -0,0 +1,854 @@
|
||||
# SheepOp LLM 🐑➡️🤖
|
||||
|
||||
A modern language model implementation from scratch, incorporating insights from recent research papers.
|
||||
|
||||
## License
|
||||
|
||||
This project is licensed under the **Apache License 2.0**.
|
||||
|
||||
See [LICENSE](../LICENSE) or [LICENSE.txt](../LICENSE.txt) for the full license text.
|
||||
|
||||
**Summary:**
|
||||
- ✅ Free to use, modify, and distribute
|
||||
- ✅ Commercial use allowed
|
||||
- ✅ Patent grant included
|
||||
- ✅ Private use allowed
|
||||
- ⚠️ Must include license and copyright notice
|
||||
- ⚠️ Must state changes if modifying
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [What is This?](#what-is-this)
|
||||
- [Features](#features)
|
||||
- [Quick Start](#quick-start)
|
||||
- [Mathematical Foundations](#mathematical-foundations)
|
||||
- [Architecture Explained](#architecture-explained)
|
||||
- [Project Structure](#project-structure)
|
||||
- [Installation](#installation)
|
||||
- [Usage](#usage)
|
||||
- [Configuration](#configuration)
|
||||
- [Diagrams](#diagrams)
|
||||
- [References](#references)
|
||||
|
||||
## What is This?
|
||||
|
||||
A Transformer-based language model implementing autoregressive next-token prediction using multi-head self-attention, positional encoding, and modern training optimizations (mixed precision, gradient accumulation, KV caching).
|
||||
|
||||
The model learns to write by reading large amounts of text, discovering patterns like "after 'the cat' usually comes 'sat' or 'ran'", enabling it to generate coherent sentences. It processes text sequentially, predicting each next word based on the context provided by previous words.
|
||||
|
||||
## Features
|
||||
|
||||
- **Transformer Architecture**: Multi-head self-attention mechanism from "Attention Is All You Need"
|
||||
- **Long Context Support**: Efficient handling of long sequences with Rotary Positional Encoding (RoPE)
|
||||
- **Training Optimizations**: Mixed precision training, gradient accumulation, gradient clipping
|
||||
- **Modern Best Practices**: Pre-norm architecture, GELU activation, weight tying
|
||||
- **Comprehensive Evaluation**: Perplexity, accuracy metrics, and generation utilities
|
||||
|
||||
## Quick Start
|
||||
|
||||
**Option 1: Automated Setup (Recommended)**
|
||||
|
||||
```bash
|
||||
# Run the setup script - it handles everything automatically!
|
||||
./setup.sh
|
||||
|
||||
# Then activate the virtual environment
|
||||
source venv/bin/activate
|
||||
|
||||
# Download data and train
|
||||
python3 download_large_data.py wiki --version 103
|
||||
python3 train.py --data data/wikitext_103.txt --config config.json --device cuda
|
||||
```
|
||||
|
||||
**Option 2: Manual Setup**
|
||||
|
||||
```bash
|
||||
# 1. Create and activate virtual environment (REQUIRED on modern Linux systems)
|
||||
python3 -m venv venv
|
||||
source venv/bin/activate
|
||||
|
||||
# 2. Upgrade pip
|
||||
pip install --upgrade pip
|
||||
|
||||
# 3. Install dependencies
|
||||
pip install -r requirements.txt
|
||||
|
||||
# 4. Download data
|
||||
python3 download_large_data.py wiki --version 103
|
||||
|
||||
# 5. Train
|
||||
python3 train.py --data data/wikitext_103.txt --config config.json --device cuda
|
||||
|
||||
# 6. Generate
|
||||
python3 inference.py --checkpoint checkpoints/checkpoint_epoch_10.pt \
|
||||
--prompt "The future of artificial intelligence" --device cuda
|
||||
```
|
||||
|
||||
**Note:** On modern Debian/Ubuntu systems (especially Python 3.12+), pip prevents system-wide installations to protect system packages. Always use a virtual environment (`python3 -m venv venv`) before installing dependencies.
|
||||
|
||||
See [GETTING_STARTED.md](GETTING_STARTED.md) for detailed instructions.
|
||||
|
||||
## Mathematical Foundations
|
||||
|
||||
### 1. Token Embedding
|
||||
|
||||
Words are converted into numerical representations that the model can process. Each word (token) is assigned a unique ID, which is then mapped to a dense vector representation - think of it as converting words into a format the computer can understand and manipulate mathematically.
|
||||
|
||||
**Mathematical Formulation**:
|
||||
|
||||
Given a vocabulary of size $V$ and token ID $t \in \{0, 1, \ldots, V-1\}$:
|
||||
|
||||
$$
|
||||
\mathbf{E}_t = \text{EmbeddingTable}[t] \in \mathbb{R}^{d_{\text{model}}}
|
||||
$$
|
||||
|
||||
where $\mathbf{E} \in \mathbb{R}^{V \times d_{\text{model}}}$ is the learnable embedding matrix.
|
||||
|
||||
### 2. Positional Encoding
|
||||
|
||||
Word order is crucial in language - "Cat bites dog" is very different from "Dog bites cat". Since transformers process all tokens simultaneously, we need to inject positional information so the model knows which word comes first, second, third, etc.
|
||||
|
||||
**Mathematical Formulation**:
|
||||
|
||||
**Sinusoidal Positional Encoding** (for position $pos$ and dimension $i$):
|
||||
|
||||
$$
|
||||
PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)
|
||||
$$
|
||||
|
||||
$$
|
||||
PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)
|
||||
$$
|
||||
|
||||
The final embedding combines token and position:
|
||||
|
||||
$$
|
||||
\mathbf{h}_i = \mathbf{E}_{t_i} + PE(i)
|
||||
$$
|
||||
|
||||
where $t_i$ is the token at position $i$.
|
||||
|
||||
### 3. Multi-Head Self-Attention
|
||||
|
||||
Attention mechanisms allow the model to understand relationships between words in a sentence. When reading "The cat sat on the mat", the model learns to connect "cat" with "sat" because they're related. Multi-head attention enables the model to focus on different types of relationships simultaneously - syntax, semantics, and context.
|
||||
|
||||
**Mathematical Formulation**:
|
||||
|
||||
Given input $\mathbf{X} \in \mathbb{R}^{n \times d_{\text{model}}}$ with $n$ tokens:
|
||||
|
||||
1. **Project to Query, Key, Value**:
|
||||
|
||||
```math
|
||||
\mathbf{Q} = \mathbf{X}\mathbf{W}_Q, \quad \mathbf{K} = \mathbf{X}\mathbf{W}_K, \quad \mathbf{V} = \mathbf{X}\mathbf{W}_V
|
||||
|
||||
where $\mathbf{W}_Q, \mathbf{W}_K, \mathbf{W}_V \in \mathbb{R}^{d_{\text{model}} \times d_{\text{model}}}$
|
||||
```
|
||||
|
||||
2. **Scaled Dot-Product Attention**:
|
||||
|
||||
```math
|
||||
\text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{d_k}}\right)\mathbf{V}
|
||||
```
|
||||
|
||||
3. **Multi-Head Attention** splits into $h$ heads:
|
||||
|
||||
```math
|
||||
\text{head}_i = \text{Attention}(\mathbf{Q}_i, \mathbf{K}_i, \mathbf{V}_i)
|
||||
|
||||
\text{MultiHead}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\mathbf{W}_O
|
||||
|
||||
where $d_k = d_{\text{model}} / h$ (each head has dimension $d_k$).
|
||||
```
|
||||
|
||||
4. **Causal Masking** (for autoregressive generation):
|
||||
|
||||
```math
|
||||
M_{ij} = \begin{cases}
|
||||
0 & \text{if } i \geq j \\
|
||||
-\infty & \text{if } i < j
|
||||
\end{cases}
|
||||
\text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}, M) = \text{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{d_k}} + M\right)\mathbf{V}
|
||||
```
|
||||
|
||||
### 4. Feed-Forward Network
|
||||
|
||||
After attention identifies which words relate to each other, the feed-forward network performs non-linear transformations on the information. This step allows the model to synthesize and combine the attended information, similar to mixing ingredients to create a final output.
|
||||
|
||||
**Mathematical Formulation**:
|
||||
|
||||
```math
|
||||
\text{FFN}(\mathbf{x}) = \text{GELU}(\mathbf{x}\mathbf{W}_1 + \mathbf{b}_1)\mathbf{W}_2 + \mathbf{b}_2
|
||||
|
||||
where $\mathbf{W}_1 \in \mathbb{R}^{d_{\text{model}} \times d_{ff}}$, $\mathbf{W}_2 \in \mathbb{R}^{d_{ff} \times d_{\text{model}}}$, and typically $d_{ff} = 4 \times d_{\text{model}}$.
|
||||
```
|
||||
|
||||
**GELU Activation**:
|
||||
|
||||
```math
|
||||
\text{GELU}(x) = x \cdot \Phi(x) = x \cdot \frac{1}{2}\left(1 + \text{erf}\left(\frac{x}{\sqrt{2}}\right)\right)
|
||||
|
||||
where $\Phi(x)$ is the cumulative distribution function of the standard normal distribution.
|
||||
```
|
||||
|
||||
### 5. Transformer Block (Pre-Norm Architecture)
|
||||
|
||||
**For Kids**: Imagine building with LEGO blocks. Each block does something special (attention or thinking), and you can stack many blocks on top of each other to build something amazing!
|
||||
|
||||
**Mathematical Formulation**:
|
||||
|
||||
For a transformer block with input $\mathbf{x}$:
|
||||
|
||||
1. **Self-Attention Sublayer**:
|
||||
|
||||
```math
|
||||
\mathbf{x}' = \mathbf{x} + \text{Dropout}\left(\text{MultiHead}(\text{LayerNorm}(\mathbf{x}))\right)
|
||||
```
|
||||
|
||||
2. **Feed-Forward Sublayer**:
|
||||
|
||||
```math
|
||||
\mathbf{x}'' = \mathbf{x}' + \text{Dropout}\left(\text{FFN}(\text{LayerNorm}(\mathbf{x}'))\right)
|
||||
```
|
||||
|
||||
**Layer Normalization**:
|
||||
|
||||
```math
|
||||
\text{LayerNorm}(\mathbf{x}) = \gamma \odot \frac{\mathbf{x} - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta
|
||||
|
||||
where $\mu = \frac{1}{d}\sum_{i=1}^d x_i$, $\sigma^2 = \frac{1}{d}\sum_{i=1}^d (x_i - \mu)^2$, and $\gamma, \beta$ are learnable parameters.
|
||||
```
|
||||
|
||||
### 6. Complete Forward Pass
|
||||
|
||||
**For Kids**: Think of it like an assembly line! The words go through many stations (layers), each doing something special, until finally we get a prediction of what word comes next.
|
||||
|
||||
**Mathematical Formulation**:
|
||||
|
||||
Given input token sequence $\mathbf{t} = [t_1, t_2, \ldots, t_n]$:
|
||||
|
||||
$$
|
||||
\mathbf{h}_0 = \mathbf{E}[\mathbf{t}] + PE(\mathbf{t}) \quad \text{(token embeddings + positional encoding)}
|
||||
$$
|
||||
|
||||
$$
|
||||
\mathbf{h}_1 = \text{TransformerBlock}_1(\mathbf{h}_0)
|
||||
$$
|
||||
|
||||
$$
|
||||
\mathbf{h}_2 = \text{TransformerBlock}_2(\mathbf{h}_1)
|
||||
$$
|
||||
|
||||
$$
|
||||
\vdots
|
||||
$$
|
||||
|
||||
$$
|
||||
\mathbf{h}_L = \text{TransformerBlock}_L(\mathbf{h}_{L-1})
|
||||
$$
|
||||
|
||||
$$
|
||||
\mathbf{o} = \text{LayerNorm}(\mathbf{h}_L) \quad \text{(final normalization)}
|
||||
$$
|
||||
|
||||
$$
|
||||
\mathbf{y} = \mathbf{o}\mathbf{W}_{\text{out}} \quad \text{(output projection)}
|
||||
$$
|
||||
|
||||
$$
|
||||
p(t_{n+1} | t_1, \ldots, t_n) = \text{softmax}(\mathbf{y}_n) \quad \text{(next token probability)}
|
||||
$$
|
||||
|
||||
### 7. Training Objective (Cross-Entropy Loss)
|
||||
|
||||
**For Kids**: When we train, we show the computer a sentence and ask "what word comes next?" If it guesses wrong, we say "try again!" and it learns from its mistake.
|
||||
|
||||
**Mathematical Formulation**:
|
||||
|
||||
For a sequence of tokens $\mathbf{t} = [t_1, t_2, \ldots, t_n]$:
|
||||
|
||||
$$
|
||||
\mathcal{L} = -\frac{1}{n-1}\sum_{i=1}^{n-1} \log p(t_{i+1} | t_1, \ldots, t_i)
|
||||
$$
|
||||
|
||||
**Perplexity** (measure of model confidence):
|
||||
|
||||
$$
|
||||
\text{Perplexity} = \exp(\mathcal{L}) = \exp\left(-\frac{1}{n-1}\sum_{i=1}^{n-1} \log p(t_{i+1} | t_1, \ldots, t_i)\right)
|
||||
$$
|
||||
|
||||
Lower perplexity = better model!
|
||||
|
||||
### 8. Text Generation (Autoregressive Sampling)
|
||||
|
||||
**For Kids**: The computer writes one word at a time. After each word, it thinks "what word makes sense next?" and picks one.
|
||||
|
||||
**Mathematical Formulation**:
|
||||
|
||||
Given prompt $\mathbf{p} = [p_1, \ldots, p_k]$:
|
||||
|
||||
1. Initialize: $\mathbf{s} = \mathbf{p}$
|
||||
2. For $i = k+1, \ldots, k+n$:
|
||||
|
||||
$$
|
||||
\mathbf{h}_i = \text{Transformer}(\mathbf{s})
|
||||
$$
|
||||
|
||||
$$
|
||||
\mathbf{p}_i = \text{softmax}(\mathbf{h}_i / \tau) \quad \text{(with temperature $\tau \geq 1$)}
|
||||
$$
|
||||
|
||||
$$
|
||||
t_i \sim \text{Sample}(\mathbf{p}_i) \quad \text{(sample next token)}
|
||||
$$
|
||||
|
||||
$$
|
||||
\mathbf{s} = \mathbf{s} \cup \{t_i\} \quad \text{(append to sequence)}
|
||||
$$
|
||||
|
||||
**Top-k Sampling**:
|
||||
|
||||
$$
|
||||
\mathbf{p}_i' = \begin{cases}
|
||||
\mathbf{p}_i[j] & \text{if } j \in \text{top}_k(\mathbf{p}_i) \\
|
||||
0 & \text{otherwise}
|
||||
\end{cases}
|
||||
$$
|
||||
|
||||
**Top-p (Nucleus) Sampling**:
|
||||
Find smallest set $S$ such that $\sum_{j \in S} \mathbf{p}_i[j] \geq p$, then:
|
||||
|
||||
$$
|
||||
\mathbf{p}_i'[j] = \begin{cases}
|
||||
\mathbf{p}_i[j] & \text{if } j \in S \\
|
||||
0 & \text{otherwise}
|
||||
\end{cases}
|
||||
$$
|
||||
|
||||
### 9. Optimization (AdamW)
|
||||
|
||||
**For Kids**: Learning is like climbing a mountain. You take small steps in the right direction. AdamW is like having a smart guide that knows which direction is best and adjusts your step size automatically!
|
||||
|
||||
**Mathematical Formulation**:
|
||||
|
||||
For parameter $\theta$ with gradient $\mathbf{g}_t$ at step $t$:
|
||||
|
||||
**Momentum**:
|
||||
|
||||
$$
|
||||
\mathbf{m}_t = \beta_1 \mathbf{m}_{t-1} + (1 - \beta_1) \mathbf{g}_t
|
||||
$$
|
||||
|
||||
**RMSprop**:
|
||||
|
||||
$$
|
||||
\mathbf{v}_t = \beta_2 \mathbf{v}_{t-1} + (1 - \beta_2) \mathbf{g}_t^2
|
||||
$$
|
||||
|
||||
**Bias Correction**:
|
||||
|
||||
$$
|
||||
\hat{\mathbf{m}}_t = \frac{\mathbf{m}_t}{1 - \beta_1^t}, \quad \hat{\mathbf{v}}_t = \frac{\mathbf{v}_t}{1 - \beta_2^t}
|
||||
$$
|
||||
|
||||
**Parameter Update**:
|
||||
|
||||
$$
|
||||
\theta_t = \theta_{t-1} - \eta \left(\frac{\hat{\mathbf{m}}_t}{\sqrt{\hat{\mathbf{v}}_t} + \epsilon} + \lambda \theta_{t-1}\right)
|
||||
$$
|
||||
|
||||
where $\eta$ is learning rate, $\lambda$ is weight decay, and typically $\beta_1 = 0.9$, $\beta_2 = 0.999$.
|
||||
|
||||
### 10. Gradient Clipping
|
||||
|
||||
**For Kids**: Sometimes the computer gets too excited and tries to learn too fast (like running too fast down a hill). We slow it down so it doesn't fall!
|
||||
|
||||
**Mathematical Formulation**:
|
||||
|
||||
$$
|
||||
\mathbf{g}_{\text{clipped}} = \begin{cases}
|
||||
\mathbf{g} & \text{if } \|\mathbf{g}\| \leq \theta_{\max} \\
|
||||
\mathbf{g} \cdot \frac{\theta_{\max}}{\|\mathbf{g}\|} & \text{if } \|\mathbf{g}\| > \theta_{\max}
|
||||
\end{cases}
|
||||
$$
|
||||
|
||||
where $\theta_{\max}$ is the maximum gradient norm (typically 1.0).
|
||||
|
||||
## Architecture Explained
|
||||
|
||||
### High-Level Overview
|
||||
|
||||
**For Kids**:
|
||||
|
||||
```
|
||||
📚 Text Input → 🔤 Turn into Numbers → 🧠 Thinking Layers → ✨ Predict Next Word → 📝 Generate Text
|
||||
```
|
||||
|
||||
**For Scientists**:
|
||||
|
||||
```
|
||||
Token IDs → Embeddings → Positional Encoding → N× Transformer Blocks → Output Projection → Logits → Sampling → Text
|
||||
```
|
||||
|
||||
### Detailed Architecture
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ INPUT TEXT │
|
||||
│ "The cat sat on the mat" │
|
||||
└──────────────────────────┬──────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ TOKENIZATION │
|
||||
│ [1, 45, 123, 67, 45, 234] (each word → number) │
|
||||
└──────────────────────────┬──────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ TOKEN EMBEDDING │
|
||||
│ E ∈ ℝ^(V×d_model) | Each token → vector of size d_model │
|
||||
│ [1, 45, ...] → [[0.1, 0.3, ...], [0.2, -0.1, ...], ...] │
|
||||
└──────────────────────────┬──────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ POSITIONAL ENCODING │
|
||||
│ PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) │
|
||||
│ h_i = E[t_i] + PE(i) (add position info) │
|
||||
└──────────────────────────┬──────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ TRANSFORMER BLOCK 1 │
|
||||
│ ┌──────────────────────────────────────────────┐ │
|
||||
│ │ Pre-Norm Multi-Head Attention │ │
|
||||
│ │ Attention(Q, K, V) = softmax(QK^T/√d_k)V │ │
|
||||
│ │ + Residual Connection │ │
|
||||
│ └──────────────────────────────────────────────┘ │
|
||||
│ ┌──────────────────────────────────────────────┐ │
|
||||
│ │ Pre-Norm Feed-Forward Network │ │
|
||||
│ │ FFN(x) = GELU(xW₁ + b₁)W₂ + b₂ │ │
|
||||
│ │ + Residual Connection │ │
|
||||
│ └──────────────────────────────────────────────┘ │
|
||||
└──────────────────────────┬──────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ TRANSFORMER BLOCK 2 │
|
||||
│ (Same structure as Block 1) │
|
||||
└──────────────────────────┬──────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
... (N-2 more blocks) ...
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ TRANSFORMER BLOCK N │
|
||||
│ (Same structure) │
|
||||
└──────────────────────────┬──────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ FINAL LAYER NORM │
|
||||
│ LayerNorm(h_L) │
|
||||
└──────────────────────────┬──────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ OUTPUT PROJECTION │
|
||||
│ y = h_L · W_out (W_out ∈ ℝ^(d_model × V)) │
|
||||
│ Output: logits for each token in vocabulary │
|
||||
└──────────────────────────┬──────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ SOFTMAX │
|
||||
│ p(t | context) = softmax(y) │
|
||||
│ Probability distribution over vocabulary │
|
||||
└──────────────────────────┬──────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ SAMPLING │
|
||||
│ t_{next} ~ Sample(p) (with temperature, top-k, top-p) │
|
||||
└──────────────────────────┬──────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ OUTPUT TEXT │
|
||||
│ "The cat sat on the mat and..." │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
sheepOp/
|
||||
├── models/ # Model architectures
|
||||
│ ├── __init__.py # Module exports
|
||||
│ ├── transformer.py # Main transformer model
|
||||
│ ├── attention.py # Attention mechanisms
|
||||
│ ├── blocks.py # Building blocks (FFN, TransformerBlock)
|
||||
│ ├── optimized_attention.py # KV caching, optimized inference
|
||||
│ └── prefetching.py # Data prefetching utilities
|
||||
├── data/ # Data loading utilities
|
||||
│ └── __init__.py # Dataset and tokenizer
|
||||
├── training/ # Training utilities
|
||||
│ ├── __init__.py # Trainer class
|
||||
│ └── metrics.py # Training metrics and plotting
|
||||
├── config.py # Configuration management
|
||||
├── config.json # Example configuration file
|
||||
├── train.py # Training script
|
||||
├── inference.py # Inference script
|
||||
├── example.py # Usage examples
|
||||
├── utils.py # Evaluation utilities
|
||||
└── requirements.txt # Dependencies
|
||||
```
|
||||
|
||||
## Installation
|
||||
|
||||
**Important:** On modern Debian/Ubuntu systems (Python 3.12+), you must use a virtual environment. The system prevents system-wide pip installations to protect system packages.
|
||||
|
||||
```bash
|
||||
# 1. Create virtual environment
|
||||
python3 -m venv venv
|
||||
|
||||
# 2. Activate virtual environment
|
||||
source venv/bin/activate
|
||||
|
||||
# 3. Upgrade pip
|
||||
pip install --upgrade pip
|
||||
|
||||
# 4. Install dependencies
|
||||
pip install -r requirements.txt
|
||||
|
||||
# 5. For large dataset downloads (optional)
|
||||
pip install datasets
|
||||
```
|
||||
|
||||
**Alternative:** Use the automated setup script which handles everything:
|
||||
```bash
|
||||
./setup.sh
|
||||
source venv/bin/activate
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Training
|
||||
|
||||
```bash
|
||||
python3 train.py \
|
||||
--data data/amazon_reviews.txt \
|
||||
--config config.json \
|
||||
--device cuda
|
||||
```
|
||||
|
||||
### Inference
|
||||
|
||||
```bash
|
||||
python3 inference.py \
|
||||
--checkpoint checkpoints/checkpoint_epoch_10.pt \
|
||||
--prompt "The future of artificial intelligence" \
|
||||
--optimized \
|
||||
--device cuda
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
See `config.json` for all available settings. Key parameters:
|
||||
|
||||
- **Model**: `vocab_size`, `d_model`, `num_layers`, `num_heads`
|
||||
- **Training**: `batch_size`, `learning_rate`, `max_epochs`
|
||||
- **Data**: `max_length`, `data_dir`
|
||||
|
||||
## Diagrams
|
||||
|
||||
### Complete Model Architecture Flow
|
||||
|
||||
```mermaid
|
||||
graph TB
|
||||
subgraph "Input Processing"
|
||||
A[Text Input] --> B[Tokenization]
|
||||
B --> C[Token Embedding<br/>E ∈ ℝ^V×d]
|
||||
C --> D[Positional Encoding<br/>PE pos,2i = sin pos/10000^2i/d]
|
||||
D --> E[Embedding + Position<br/>h = E + PE]
|
||||
end
|
||||
|
||||
subgraph "Transformer Stack"
|
||||
E --> F[Transformer Block 1]
|
||||
F --> G[Transformer Block 2]
|
||||
G --> H[...]
|
||||
H --> I[Transformer Block N]
|
||||
end
|
||||
|
||||
subgraph "Transformer Block Detail"
|
||||
F --> F1[Layer Norm 1]
|
||||
F1 --> F2[Multi-Head Attention<br/>QK^T/√d_k → softmax → V]
|
||||
F2 --> F3[Residual + Dropout]
|
||||
F --> F3
|
||||
F3 --> F4[Layer Norm 2]
|
||||
F4 --> F5[Feed-Forward<br/>GELU xW₁ + b₁W₂ + b₂]
|
||||
F5 --> F6[Residual + Dropout]
|
||||
F3 --> F6
|
||||
end
|
||||
|
||||
subgraph "Output Processing"
|
||||
I --> J[Final Layer Norm]
|
||||
J --> K[Output Projection<br/>y = hW_out]
|
||||
K --> L[Softmax<br/>p = softmax y]
|
||||
L --> M[Sampling<br/>t ~ p with temp, top-k, top-p]
|
||||
M --> N[Generated Text]
|
||||
end
|
||||
|
||||
style A fill:#e1f5ff
|
||||
style N fill:#fff4e1
|
||||
style F fill:#e8f5e9
|
||||
style F2 fill:#ffe1f5
|
||||
style F5 fill:#ffe1f5
|
||||
```
|
||||
|
||||
### Attention Mechanism Visualization
|
||||
|
||||
```mermaid
|
||||
graph LR
|
||||
subgraph "Input"
|
||||
X[Input Tokens<br/>x₁, x₂, ..., xₙ]
|
||||
end
|
||||
|
||||
subgraph "Q, K, V Projections"
|
||||
X --> Q[Query Q<br/>Q = XW_Q]
|
||||
X --> K[Key K<br/>K = XW_K]
|
||||
X --> V[Value V<br/>V = XW_V]
|
||||
end
|
||||
|
||||
subgraph "Attention Computation"
|
||||
Q --> M[Matrix Multiply<br/>QK^T]
|
||||
K --> M
|
||||
M --> S[Scale<br/>÷√d_k]
|
||||
S --> Mask{Causal Mask?}
|
||||
Mask -->|Yes| CM[Mask<br/>M_ij = -∞ if i < j]
|
||||
Mask -->|No| SM[Skip Mask]
|
||||
CM --> Soft[Softmax<br/>exp scores / Σexp]
|
||||
SM --> Soft
|
||||
Soft --> WV[Weighted Sum<br/>Attention = softmax scores × V]
|
||||
V --> WV
|
||||
end
|
||||
|
||||
subgraph "Multi-Head"
|
||||
WV --> Split[Split into h heads]
|
||||
Split --> H1[Head 1]
|
||||
Split --> H2[Head 2]
|
||||
Split --> H3[...]
|
||||
Split --> Hh[Head h]
|
||||
H1 --> Concat[Concatenate]
|
||||
H2 --> Concat
|
||||
H3 --> Concat
|
||||
Hh --> Concat
|
||||
Concat --> Out[Output Projection<br/>WO]
|
||||
end
|
||||
|
||||
style X fill:#e1f5ff
|
||||
style Out fill:#fff4e1
|
||||
style Soft fill:#ffe1f5
|
||||
style WV fill:#e8f5e9
|
||||
```
|
||||
|
||||
### Training Loop Flow
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
A[Start Training] --> B[Load Data]
|
||||
B --> C[Create Tokenizer]
|
||||
C --> D[Build Vocabulary]
|
||||
D --> E[Create DataLoader]
|
||||
E --> F[Initialize Model]
|
||||
F --> G[Setup Optimizer<br/>AdamW]
|
||||
G --> H[Setup LR Scheduler<br/>Cosine Annealing]
|
||||
|
||||
H --> I[For Each Epoch]
|
||||
I --> J[For Each Batch]
|
||||
|
||||
J --> K[Forward Pass<br/>y = Model x]
|
||||
K --> L[Compute Loss<br/>L = -log p t_next]
|
||||
L --> M[Backward Pass<br/>∇L]
|
||||
M --> N[Gradient Accumulation]
|
||||
N --> O{Accumulation<br/>Complete?}
|
||||
O -->|No| J
|
||||
O -->|Yes| P[Gradient Clipping<br/>Clip gradient norm]
|
||||
P --> Q[Optimizer Step<br/>θ = θ - η∇L]
|
||||
Q --> R[LR Scheduler Step]
|
||||
R --> S{End of<br/>Epoch?}
|
||||
S -->|No| J
|
||||
S -->|Yes| T[Evaluate on<br/>Validation Set]
|
||||
T --> U[Compute Perplexity<br/>exp L]
|
||||
U --> V{Best<br/>Model?}
|
||||
V -->|Yes| W[Save Checkpoint]
|
||||
V -->|No| X[Save Regular Checkpoint]
|
||||
W --> Y{More<br/>Epochs?}
|
||||
X --> Y
|
||||
Y -->|Yes| I
|
||||
Y -->|No| Z[Training Complete]
|
||||
|
||||
style A fill:#e1f5ff
|
||||
style Z fill:#fff4e1
|
||||
style K fill:#e8f5e9
|
||||
style L fill:#ffe1f5
|
||||
style P fill:#fff4e1
|
||||
```
|
||||
|
||||
### Inference Flow with Sampling
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
A[Load Checkpoint] --> B[Initialize Model]
|
||||
B --> C[Set Eval Mode]
|
||||
C --> D[Input Prompt]
|
||||
D --> E[Tokenize<br/>t = t1, t2, ..., tk]
|
||||
E --> F[Encode<br/>h = E t + PE]
|
||||
|
||||
F --> G[Forward Pass<br/>y = Model h]
|
||||
G --> H[Get Logits<br/>y ∈ ℝ^V]
|
||||
H --> I[Apply Temperature<br/>p = softmax y/τ]
|
||||
|
||||
I --> J{Top-k<br/>Filter?}
|
||||
J -->|Yes| K[Keep Top-k<br/>p' = filter p, k]
|
||||
J -->|No| L[p' = p]
|
||||
K --> M{Top-p<br/>Filter?}
|
||||
L --> M
|
||||
M -->|Yes| N[Nucleus Sampling<br/>p' = filter p', p]
|
||||
M -->|No| O[p'' = p']
|
||||
N --> O
|
||||
|
||||
O --> P[Sample Token<br/>t_i ~ p'']
|
||||
P --> Q[Append to Sequence<br/>s = s ∪ t_i]
|
||||
Q --> R{Max Length<br/>Reached?}
|
||||
R -->|No| G
|
||||
R -->|Yes| S[Decode Tokens<br/>text = decode s]
|
||||
S --> T[Output Text]
|
||||
|
||||
style A fill:#e1f5ff
|
||||
style T fill:#fff4e1
|
||||
style G fill:#e8f5e9
|
||||
style P fill:#ffe1f5
|
||||
```
|
||||
|
||||
### Component Interaction
|
||||
|
||||
```mermaid
|
||||
graph TB
|
||||
subgraph "Configuration"
|
||||
CFG[config.py<br/>Config Classes]
|
||||
CFGJSON[config.json<br/>JSON Settings]
|
||||
end
|
||||
|
||||
subgraph "Models"
|
||||
TRANS[transformer.py<br/>TransformerModel]
|
||||
ATT[attention.py<br/>MultiHeadAttention]
|
||||
BLOCKS[blocks.py<br/>TransformerBlock, FFN]
|
||||
OPT[optimized_attention.py<br/>KV Cache, Optimized Inference]
|
||||
end
|
||||
|
||||
subgraph "Data"
|
||||
DATA[data/__init__.py<br/>TextDataset, SimpleTokenizer]
|
||||
TOKEN[SimpleTokenizer<br/>encode/decode]
|
||||
DATASET[TextDataset<br/>PyTorch Dataset]
|
||||
end
|
||||
|
||||
subgraph "Training"
|
||||
TRAIN[training/__init__.py<br/>Trainer]
|
||||
METRICS[training/metrics.py<br/>TrainingMetrics]
|
||||
end
|
||||
|
||||
subgraph "Scripts"
|
||||
TRAIN_SCRIPT[train.py<br/>Training Entry Point]
|
||||
INFER[inference.py<br/>Generation Script]
|
||||
EX[example.py<br/>Usage Examples]
|
||||
end
|
||||
|
||||
subgraph "Utils"
|
||||
UTILS[utils.py<br/>Evaluation Functions]
|
||||
end
|
||||
|
||||
CFG --> TRAIN_SCRIPT
|
||||
CFGJSON --> TRAIN_SCRIPT
|
||||
TRANS --> TRAIN_SCRIPT
|
||||
TRANS --> INFER
|
||||
ATT --> TRANS
|
||||
BLOCKS --> TRANS
|
||||
OPT --> TRANS
|
||||
DATA --> TRAIN_SCRIPT
|
||||
DATA --> INFER
|
||||
TOKEN --> DATASET
|
||||
DATASET --> TRAINER
|
||||
TRAIN --> TRAIN_SCRIPT
|
||||
METRICS --> TRAIN
|
||||
UTILS --> TRAIN_SCRIPT
|
||||
|
||||
style TRANS fill:#e1f5ff
|
||||
style TRAINER fill:#fff4e1
|
||||
style DATA fill:#e8f5e9
|
||||
style ATT fill:#ffe1f5
|
||||
```
|
||||
|
||||
## Mathematical Summary Table
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<th>Concept</th>
|
||||
<th>Formula</th>
|
||||
<th>Description</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><strong>Token Embedding</strong></td>
|
||||
<td>$$\mathbf{E}_t = \text{EmbeddingTable}[t]$$</td>
|
||||
<td>Maps token ID to vector</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><strong>Positional Encoding</strong></td>
|
||||
<td>$$PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d}}\right)$$</td>
|
||||
<td>Adds position information</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><strong>Attention</strong></td>
|
||||
<td>$$\text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$</td>
|
||||
<td>Computes attention weights</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><strong>Feed-Forward</strong></td>
|
||||
<td>$$\text{FFN}(x) = \text{GELU}(xW_1 + b_1)W_2 + b_2$$</td>
|
||||
<td>Non-linear transformation</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><strong>Layer Norm</strong></td>
|
||||
<td>$$\text{LayerNorm}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$</td>
|
||||
<td>Normalizes activations</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><strong>Loss</strong></td>
|
||||
<td>$$\mathcal{L} = -\frac{1}{n}\sum_{i=1}^{n} \log p(t_{i+1} \mid t_1, \ldots, t_i)$$</td>
|
||||
<td>Cross-entropy loss</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><strong>Perplexity</strong></td>
|
||||
<td>$$\text{Perplexity} = \exp(\mathcal{L})$$</td>
|
||||
<td>Measure of model confidence</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><strong>AdamW Update</strong></td>
|
||||
<td>$$\theta_t = \theta_{t-1} - \eta\left(\frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda\theta_{t-1}\right)$$</td>
|
||||
<td>Optimizer step</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## References
|
||||
|
||||
- **Attention Is All You Need** (Vaswani et al., 2017) - Original Transformer paper
|
||||
- **Optimizing LLM Inference and Retrieval** - Production RAG systems optimizations
|
||||
- **RoPE: Rotary Position Embedding** - Efficient positional encoding for long sequences
|
||||
- Various papers on LLM training, hallucinations, and long context handling
|
||||
|
||||
---
|
||||
|
||||
**For Kids**: You've learned how computers can read and write! Just like you practice writing stories, computers practice by reading millions of books. The more they practice, the better they get! 🎉
|
||||
|
||||
**For Scientists**: This implementation follows modern best practices including pre-norm architecture, weight tying, mixed precision training, and efficient inference optimizations. The codebase is modular, well-documented, and optimized for both training and production deployment.
|
||||
2574
docs/CONTROL_SYSTEM_MODEL.md
Normal file
2574
docs/CONTROL_SYSTEM_MODEL.md
Normal file
File diff suppressed because it is too large
Load Diff
217
docs/DATABASE_EXTRACTION_GUIDE.md
Normal file
217
docs/DATABASE_EXTRACTION_GUIDE.md
Normal file
@@ -0,0 +1,217 @@
|
||||
# Database Extraction Guide
|
||||
|
||||
This guide shows you how to extract text from your 1TB database for training.
|
||||
|
||||
## Quick Start
|
||||
|
||||
### SQLite Database
|
||||
|
||||
```bash
|
||||
# Extract from SQLite database
|
||||
python3 extract_from_database.py \
|
||||
--type sqlite \
|
||||
--db-path /path/to/your/database.db \
|
||||
--table your_table_name \
|
||||
--column text_column_name \
|
||||
--output data/database_training.txt \
|
||||
--limit 1000000 # Limit to 1M samples (or omit for all)
|
||||
```
|
||||
|
||||
### PostgreSQL Database
|
||||
|
||||
```bash
|
||||
# Install PostgreSQL driver first
|
||||
pip install psycopg2-binary
|
||||
|
||||
# Extract with SQL query
|
||||
python3 extract_from_database.py \
|
||||
--type sql \
|
||||
--connection "host=localhost dbname=mydb user=myuser password=mypass" \
|
||||
--query "SELECT text_column FROM your_table WHERE length(text_column) > 50" \
|
||||
--output data/database_training.txt \
|
||||
--limit 1000000
|
||||
```
|
||||
|
||||
### MySQL Database
|
||||
|
||||
```bash
|
||||
# Install MySQL driver first
|
||||
pip install pymysql
|
||||
|
||||
# Extract with SQL query
|
||||
python3 extract_from_database.py \
|
||||
--type sql \
|
||||
--connection "mysql+pymysql://user:pass@localhost/dbname" \
|
||||
--query "SELECT text_column FROM your_table" \
|
||||
--output data/database_training.txt
|
||||
```
|
||||
|
||||
### JSON/JSONL Files
|
||||
|
||||
```bash
|
||||
# Extract from JSON Lines file
|
||||
python3 extract_from_database.py \
|
||||
--type json \
|
||||
--json-path /path/to/data.jsonl \
|
||||
--text-field content \
|
||||
--output data/database_training.txt \
|
||||
--limit 1000000
|
||||
```
|
||||
|
||||
## Examples
|
||||
|
||||
### Example 1: Extract All Text from SQLite Table
|
||||
|
||||
```bash
|
||||
python3 extract_from_database.py \
|
||||
--type sqlite \
|
||||
--db-path /Volumes/YourDisk/database.db \
|
||||
--table articles \
|
||||
--column body_text \
|
||||
--output data/training_data.txt
|
||||
```
|
||||
|
||||
### Example 2: Extract Filtered Data (Longer Texts Only)
|
||||
|
||||
```bash
|
||||
python3 extract_from_database.py \
|
||||
--type sqlite \
|
||||
--db-path /Volumes/YourDisk/database.db \
|
||||
--table articles \
|
||||
--column body_text \
|
||||
--where "WHERE length(body_text) > 200" \
|
||||
--output data/training_data.txt \
|
||||
--min-length 50
|
||||
```
|
||||
|
||||
### Example 3: Extract from Multiple Tables
|
||||
|
||||
```bash
|
||||
# Extract from table 1
|
||||
python3 extract_from_database.py \
|
||||
--type sqlite \
|
||||
--db-path /Volumes/YourDisk/database.db \
|
||||
--table articles \
|
||||
--column content \
|
||||
--output data/articles.txt
|
||||
|
||||
# Extract from table 2
|
||||
python3 extract_from_database.py \
|
||||
--type sqlite \
|
||||
--db-path /Volumes/YourDisk/database.db \
|
||||
--table comments \
|
||||
--column text \
|
||||
--output data/comments.txt
|
||||
|
||||
# Combine files
|
||||
cat data/articles.txt data/comments.txt > data/combined_training.txt
|
||||
```
|
||||
|
||||
### Example 4: PostgreSQL with Complex Query
|
||||
|
||||
```bash
|
||||
python3 extract_from_database.py \
|
||||
--type sql \
|
||||
--connection "host=localhost dbname=mydb user=myuser password=mypass" \
|
||||
--query "SELECT description FROM products WHERE description IS NOT NULL AND length(description) > 100 UNION SELECT review_text FROM reviews WHERE review_text IS NOT NULL" \
|
||||
--output data/products_and_reviews.txt
|
||||
```
|
||||
|
||||
## Options
|
||||
|
||||
### Filtering Options
|
||||
|
||||
```bash
|
||||
# Only extract texts longer than 100 characters
|
||||
--min-length 100
|
||||
|
||||
# Limit total samples
|
||||
--limit 1000000
|
||||
|
||||
# Add WHERE clause (SQLite)
|
||||
--where "WHERE created_at > '2024-01-01' AND length(text) > 200"
|
||||
```
|
||||
|
||||
### Output Options
|
||||
|
||||
```bash
|
||||
# Custom output path
|
||||
--output data/my_training_data.txt
|
||||
|
||||
# Don't clean/split text (preserve original format)
|
||||
--no-clean
|
||||
```
|
||||
|
||||
## Performance Tips
|
||||
|
||||
1. **Use LIMIT for Testing**: Start with `--limit 10000` to test
|
||||
2. **Filter in Database**: Use `--where` clause to filter at database level (faster)
|
||||
3. **Batch Processing**: The script processes in batches automatically
|
||||
4. **Monitor Progress**: Progress updates every 1000 texts
|
||||
|
||||
## Data Format
|
||||
|
||||
The output file will have:
|
||||
- One text sample per line
|
||||
- Cleaned and split into sentences
|
||||
- Minimum length filtering applied
|
||||
- UTF-8 encoding
|
||||
|
||||
## Next Steps
|
||||
|
||||
After extraction:
|
||||
|
||||
```bash
|
||||
# Check how much data you extracted
|
||||
wc -l data/database_training.txt
|
||||
|
||||
# Train with the extracted data
|
||||
python3 train.py --data data/database_training.txt --config config.json --device mps
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### SQLite Database Locked
|
||||
- Close any applications using the database
|
||||
- Copy database to a local location first
|
||||
|
||||
### Large Database (1TB)
|
||||
- Use `--limit` to extract in batches
|
||||
- Use `--where` to filter at database level
|
||||
- Consider extracting to multiple files and combining
|
||||
|
||||
### Memory Issues
|
||||
- The script processes in batches (streaming)
|
||||
- Use `--limit` to control size
|
||||
- Process in chunks if needed
|
||||
|
||||
## Example Workflow
|
||||
|
||||
```bash
|
||||
# 1. Extract 1M samples for testing
|
||||
python3 extract_from_database.py \
|
||||
--type sqlite \
|
||||
--db-path /Volumes/YourDisk/database.db \
|
||||
--table your_table \
|
||||
--column text_column \
|
||||
--output data/test_extraction.txt \
|
||||
--limit 1000000
|
||||
|
||||
# 2. Check the data
|
||||
head -20 data/test_extraction.txt
|
||||
wc -l data/test_extraction.txt
|
||||
|
||||
# 3. If good, extract more (or all)
|
||||
python3 extract_from_database.py \
|
||||
--type sqlite \
|
||||
--db-path /Volumes/YourDisk/database.db \
|
||||
--table your_table \
|
||||
--column text_column \
|
||||
--output data/full_training.txt
|
||||
|
||||
# 4. Train with the data
|
||||
python3 train.py --data data/full_training.txt --config config.json --device mps
|
||||
```
|
||||
|
||||
Good luck extracting your 1TB database! 🚀
|
||||
|
||||
225
docs/DATA_GUIDE.md
Normal file
225
docs/DATA_GUIDE.md
Normal file
@@ -0,0 +1,225 @@
|
||||
# Data Collection Guide
|
||||
|
||||
This guide shows you how to get training data from the internet or create your own data.txt file.
|
||||
|
||||
## Option 1: Use the Download Script
|
||||
|
||||
### Quick Start
|
||||
|
||||
```bash
|
||||
# Download Shakespeare text (recommended for testing)
|
||||
python download_data.py --type shakespeare
|
||||
|
||||
# Create a sample data file
|
||||
python download_data.py --type sample --output data/my_data.txt --samples 200
|
||||
|
||||
# Download Wikipedia article (requires: pip install wikipedia)
|
||||
python download_data.py --type wikipedia --title "Artificial Intelligence" --output data/ai_article.txt
|
||||
```
|
||||
|
||||
### Available Options
|
||||
|
||||
**Shakespeare Dataset:**
|
||||
```bash
|
||||
python download_data.py --type shakespeare
|
||||
```
|
||||
Downloads classic Shakespeare text - great for testing!
|
||||
|
||||
**Create Sample Data:**
|
||||
```bash
|
||||
python download_data.py --type sample --output data/my_data.txt --samples 100
|
||||
```
|
||||
Creates a file with sample sentences about ML/AI.
|
||||
|
||||
**Wikipedia Article:**
|
||||
```bash
|
||||
python download_data.py --type wikipedia --title "Machine Learning" --output data/ml_article.txt
|
||||
```
|
||||
Downloads a Wikipedia article (requires `pip install wikipedia`).
|
||||
|
||||
## Option 2: Manual Data Collection
|
||||
|
||||
### Method A: Create Your Own data.txt
|
||||
|
||||
1. **Create a text file:**
|
||||
```bash
|
||||
nano data/my_data.txt
|
||||
# or
|
||||
vim data/my_data.txt
|
||||
```
|
||||
|
||||
2. **Add your text** (one sentence per line):
|
||||
```
|
||||
This is my first training sample.
|
||||
This is my second training sample.
|
||||
Add as many lines as you want.
|
||||
```
|
||||
|
||||
3. **Save and use:**
|
||||
```bash
|
||||
python train.py --data data/my_data.txt
|
||||
```
|
||||
|
||||
### Method B: Download from Public Datasets
|
||||
|
||||
**Shakespeare Text:**
|
||||
```bash
|
||||
curl -o data/shakespeare.txt https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
|
||||
```
|
||||
|
||||
**Book Corpus Sample:**
|
||||
```bash
|
||||
# Download Project Gutenberg books
|
||||
curl -o data/book.txt https://www.gutenberg.org/files/1342/1342-0.txt # Pride and Prejudice
|
||||
```
|
||||
|
||||
**News Articles:**
|
||||
```bash
|
||||
# Download news text
|
||||
curl -o data/news.txt https://raw.githubusercontent.com/sunnysai12345/News_Summary/master/news_summary_more.csv
|
||||
```
|
||||
|
||||
### Method C: Scrape Your Own Data
|
||||
|
||||
**From Wikipedia (Python):**
|
||||
```python
|
||||
import wikipedia
|
||||
|
||||
page = wikipedia.page("Machine Learning")
|
||||
with open("data/ml_article.txt", "w") as f:
|
||||
f.write(page.content)
|
||||
```
|
||||
|
||||
**From a Website:**
|
||||
```python
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
url = "https://example.com/article"
|
||||
response = requests.get(url)
|
||||
soup = BeautifulSoup(response.text, 'html.parser')
|
||||
text = soup.get_text()
|
||||
|
||||
with open("data/scraped.txt", "w") as f:
|
||||
f.write(text)
|
||||
```
|
||||
|
||||
## Option 3: Use Existing Datasets
|
||||
|
||||
### Popular NLP Datasets
|
||||
|
||||
**WikiText-2:**
|
||||
```bash
|
||||
# Download WikiText-2
|
||||
wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip
|
||||
unzip wikitext-2-v1.zip
|
||||
# Use: wikitext-2/wiki.train.tokens
|
||||
```
|
||||
|
||||
**OpenWebText Sample:**
|
||||
```bash
|
||||
# Download sample
|
||||
curl -o data/openwebtext_sample.txt https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
|
||||
```
|
||||
|
||||
**BookCorpus:**
|
||||
```bash
|
||||
# Various book sources available
|
||||
# Check: https://github.com/soskek/bookcorpus
|
||||
```
|
||||
|
||||
## Data Format Requirements
|
||||
|
||||
Your `data.txt` file should:
|
||||
- Have **one text sample per line**
|
||||
- Use **UTF-8 encoding**
|
||||
- Be **plain text** (no special formatting)
|
||||
|
||||
**Example format:**
|
||||
```
|
||||
This is the first training example.
|
||||
This is the second training example.
|
||||
Each line becomes one training sample.
|
||||
```
|
||||
|
||||
**Good:**
|
||||
```
|
||||
Hello world!
|
||||
This is a sentence.
|
||||
Machine learning is cool.
|
||||
```
|
||||
|
||||
**Bad:**
|
||||
```
|
||||
This is paragraph 1 with multiple sentences. This is sentence 2.
|
||||
This is paragraph 2.
|
||||
```
|
||||
|
||||
## Preprocessing Tips
|
||||
|
||||
1. **Clean your data:**
|
||||
```python
|
||||
import re
|
||||
|
||||
with open("raw_data.txt", "r") as f:
|
||||
text = f.read()
|
||||
|
||||
# Remove extra whitespace
|
||||
text = re.sub(r'\s+', ' ', text)
|
||||
|
||||
# Split into sentences
|
||||
sentences = text.split('.')
|
||||
|
||||
# Write one per line
|
||||
with open("data/cleaned_data.txt", "w") as f:
|
||||
for sentence in sentences:
|
||||
if sentence.strip():
|
||||
f.write(sentence.strip() + '\n')
|
||||
```
|
||||
|
||||
2. **Split long texts:**
|
||||
```python
|
||||
# If you have long texts, split them into sentences
|
||||
text = "Long paragraph here. Another sentence. More text."
|
||||
sentences = text.split('.')
|
||||
for sentence in sentences:
|
||||
if sentence.strip():
|
||||
print(sentence.strip())
|
||||
```
|
||||
|
||||
## Quick Test
|
||||
|
||||
1. **Create a small test file:**
|
||||
```bash
|
||||
cat > data/test.txt << EOF
|
||||
Hello world!
|
||||
This is a test.
|
||||
Language models are cool.
|
||||
EOF
|
||||
```
|
||||
|
||||
2. **Train with it:**
|
||||
```bash
|
||||
python train.py --data data/test.txt --output ./checkpoints
|
||||
```
|
||||
|
||||
## Recommended Data Sources
|
||||
|
||||
- **Small (for testing):** Shakespeare text, sample_data.txt
|
||||
- **Medium (for training):** Wikipedia articles, news articles
|
||||
- **Large (for serious training):** WikiText-2, BookCorpus, OpenWebText
|
||||
|
||||
## Next Steps
|
||||
|
||||
Once you have your data.txt file:
|
||||
|
||||
```bash
|
||||
# Train your model
|
||||
python train.py --data data/your_data.txt --output ./checkpoints
|
||||
|
||||
# Or use the sample data
|
||||
python train.py --data data/sample_data.txt --output ./checkpoints
|
||||
```
|
||||
|
||||
Happy training! 🚀
|
||||
|
||||
914
docs/DATA_PROCESSING_EXPLAINED.md
Normal file
914
docs/DATA_PROCESSING_EXPLAINED.md
Normal file
@@ -0,0 +1,914 @@
|
||||
# Data Processing Explained: Step-by-Step Guide
|
||||
|
||||
Complete guide to understanding data processing in the SheepOp LLM project, explaining what happens to your data from raw files to training-ready text.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
1. [What is Data Processing?](#1-what-is-data-processing)
|
||||
2. [Why Do We Need Data Processing?](#2-why-do-we-need-data-processing)
|
||||
3. [The Data Processing Pipeline](#3-the-data-processing-pipeline)
|
||||
4. [Step-by-Step: How Each File Type is Processed](#4-step-by-step-how-each-file-type-is-processed)
|
||||
5. [Data Transformation Stages](#5-data-transformation-stages)
|
||||
6. [Complete Example: Processing "Hello World.pdf"](#6-complete-example-processing-hello-worldpdf)
|
||||
7. [Data Quality and Filtering](#7-data-quality-and-filtering)
|
||||
8. [Common Questions](#8-common-questions)
|
||||
|
||||
---
|
||||
|
||||
## 1. What is Data Processing?
|
||||
|
||||
**Data processing** is the transformation of raw, unstructured data into a format that machine learning models can understand and learn from.
|
||||
|
||||
### Simple Analogy
|
||||
|
||||
Think of data processing like preparing ingredients for cooking:
|
||||
|
||||
**Raw Ingredients (Your Files):**
|
||||
- PDF documents
|
||||
- Text files
|
||||
- Images with text
|
||||
- Code files
|
||||
|
||||
**Prepared Ingredients (Processed Data):**
|
||||
- Clean text lines
|
||||
- Consistent format
|
||||
- Ready for training
|
||||
|
||||
**The Recipe (Training):**
|
||||
- The model learns from the prepared ingredients
|
||||
|
||||
### In Our Context
|
||||
|
||||
**Input:** Mixed file types (PDFs, images, code, text)
|
||||
**Output:** List of text strings ready for tokenization
|
||||
**Purpose:** Extract meaningful text that the model can learn from
|
||||
|
||||
---
|
||||
|
||||
## 2. Why Do We Need Data Processing?
|
||||
|
||||
### 2.1 The Problem
|
||||
|
||||
Machine learning models (like our transformer) understand **numbers**, not:
|
||||
- PDF files
|
||||
- Images
|
||||
- Raw text files
|
||||
- Code files
|
||||
|
||||
### 2.2 The Solution
|
||||
|
||||
We need to:
|
||||
1. **Extract** text from different file formats
|
||||
2. **Clean** the text (remove noise, handle encoding)
|
||||
3. **Standardize** the format (consistent structure)
|
||||
4. **Prepare** for tokenization (split into manageable pieces)
|
||||
|
||||
### 2.3 Benefits
|
||||
|
||||
✅ **Unified Format**: All data becomes text lines
|
||||
✅ **Easy to Process**: Simple format for tokenization
|
||||
✅ **Flexible**: Works with many file types
|
||||
✅ **Scalable**: Can process thousands of files automatically
|
||||
|
||||
---
|
||||
|
||||
## 3. The Data Processing Pipeline
|
||||
|
||||
### 3.1 High-Level Overview
|
||||
|
||||
```
|
||||
Raw Files
|
||||
↓
|
||||
[File Type Detection]
|
||||
↓
|
||||
[Text Extraction]
|
||||
↓
|
||||
[Text Cleaning]
|
||||
↓
|
||||
[Line Splitting]
|
||||
↓
|
||||
[Filtering]
|
||||
↓
|
||||
Clean Text Lines
|
||||
↓
|
||||
[Tokenization] ← Not part of data processing
|
||||
↓
|
||||
[Training] ← Not part of data processing
|
||||
```
|
||||
|
||||
### 3.2 Detailed Pipeline
|
||||
|
||||
```
|
||||
Step 1: Directory Scan
|
||||
└─→ Find all files in data/ directory
|
||||
└─→ Categorize by file type (.pdf, .txt, .png, etc.)
|
||||
|
||||
Step 2: File Type Detection
|
||||
└─→ Check file extension
|
||||
└─→ Route to appropriate processor
|
||||
|
||||
Step 3: Text Extraction
|
||||
├─→ PDF files → PDF text extraction
|
||||
├─→ Text files → Read as text
|
||||
├─→ Image files → OCR (Optical Character Recognition)
|
||||
└─→ Code files → Read as text
|
||||
|
||||
Step 4: Text Cleaning
|
||||
└─→ Remove extra whitespace
|
||||
└─→ Handle encoding issues
|
||||
└─→ Normalize line endings
|
||||
|
||||
Step 5: Line Splitting
|
||||
└─→ Split text into individual lines
|
||||
└─→ Each line becomes one training sample
|
||||
|
||||
Step 6: Filtering
|
||||
└─→ Remove empty lines
|
||||
└─→ Filter by minimum length
|
||||
└─→ Remove lines that are too short
|
||||
|
||||
Step 7: Output
|
||||
└─→ List of text strings
|
||||
└─→ Ready for tokenization
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 4. Step-by-Step: How Each File Type is Processed
|
||||
|
||||
### 4.1 Text Files (.txt, .md, .log, etc.)
|
||||
|
||||
**What happens:**
|
||||
1. File is opened
|
||||
2. Content is read line by line
|
||||
3. Each line becomes a separate text sample
|
||||
|
||||
**Example:**
|
||||
|
||||
**Input:** `document.txt`
|
||||
```
|
||||
Hello world
|
||||
This is a sentence.
|
||||
Machine learning is fascinating.
|
||||
```
|
||||
|
||||
**Processing:**
|
||||
```
|
||||
Line 1: "Hello world"
|
||||
Line 2: "This is a sentence."
|
||||
Line 3: "Machine learning is fascinating."
|
||||
```
|
||||
|
||||
**Output:**
|
||||
```python
|
||||
[
|
||||
"Hello world",
|
||||
"This is a sentence.",
|
||||
"Machine learning is fascinating."
|
||||
]
|
||||
```
|
||||
|
||||
**Why this works:** Text files are already in plain text format, so extraction is straightforward.
|
||||
|
||||
---
|
||||
|
||||
### 4.2 Code Files (.py, .js, .java, etc.)
|
||||
|
||||
**What happens:**
|
||||
1. File is opened
|
||||
2. Content is read line by line
|
||||
3. Each line becomes a separate text sample
|
||||
|
||||
**Example:**
|
||||
|
||||
**Input:** `example.py`
|
||||
```python
|
||||
def hello():
|
||||
print("Hello")
|
||||
return True
|
||||
```
|
||||
|
||||
**Processing:**
|
||||
```
|
||||
Line 1: "def hello():"
|
||||
Line 2: " print("Hello")"
|
||||
Line 3: " return True"
|
||||
```
|
||||
|
||||
**Output:**
|
||||
```python
|
||||
[
|
||||
"def hello():",
|
||||
" print("Hello")",
|
||||
" return True"
|
||||
]
|
||||
```
|
||||
|
||||
**Why this works:** Code files are text files, so they're processed the same way. The model learns code patterns and syntax.
|
||||
|
||||
---
|
||||
|
||||
### 4.3 PDF Files (.pdf)
|
||||
|
||||
**What happens:**
|
||||
1. PDF file is opened
|
||||
2. Text is extracted from each page
|
||||
3. Text is split into lines
|
||||
4. Lines are filtered for quality
|
||||
|
||||
**Example:**
|
||||
|
||||
**Input:** `document.pdf` (3 pages)
|
||||
|
||||
**Page 1:**
|
||||
```
|
||||
Introduction to Machine Learning
|
||||
Machine learning is a subset of artificial intelligence.
|
||||
```
|
||||
|
||||
**Page 2:**
|
||||
```
|
||||
Neural Networks
|
||||
Neural networks are computing systems inspired by biological neural networks.
|
||||
```
|
||||
|
||||
**Page 3:**
|
||||
```
|
||||
Conclusion
|
||||
In conclusion, machine learning has revolutionized technology.
|
||||
```
|
||||
|
||||
**Processing:**
|
||||
|
||||
**Step 1: Extract text from each page**
|
||||
```
|
||||
Page 1 text: "Introduction to Machine Learning\nMachine learning is a subset of artificial intelligence."
|
||||
Page 2 text: "Neural Networks\nNeural networks are computing systems inspired by biological neural networks."
|
||||
Page 3 text: "Conclusion\nIn conclusion, machine learning has revolutionized technology."
|
||||
```
|
||||
|
||||
**Step 2: Split by newlines**
|
||||
```
|
||||
Line 1: "Introduction to Machine Learning"
|
||||
Line 2: "Machine learning is a subset of artificial intelligence."
|
||||
Line 3: "Neural Networks"
|
||||
Line 4: "Neural networks are computing systems inspired by biological neural networks."
|
||||
Line 5: "Conclusion"
|
||||
Line 6: "In conclusion, machine learning has revolutionized technology."
|
||||
```
|
||||
|
||||
**Step 3: Filter short lines**
|
||||
```
|
||||
Remove: "Introduction to Machine Learning" (too short for context)
|
||||
Keep: "Machine learning is a subset of artificial intelligence."
|
||||
Remove: "Neural Networks" (too short)
|
||||
Keep: "Neural networks are computing systems inspired by biological neural networks."
|
||||
Remove: "Conclusion" (too short)
|
||||
Keep: "In conclusion, machine learning has revolutionized technology."
|
||||
```
|
||||
|
||||
**Output:**
|
||||
```python
|
||||
[
|
||||
"Machine learning is a subset of artificial intelligence.",
|
||||
"Neural networks are computing systems inspired by biological neural networks.",
|
||||
"In conclusion, machine learning has revolutionized technology."
|
||||
]
|
||||
```
|
||||
|
||||
**Why this works:** PDFs contain text embedded in the file structure. Libraries like PyPDF2 or pdfplumber extract this text, preserving the content but losing formatting.
|
||||
|
||||
---
|
||||
|
||||
### 4.4 Image Files (.png, .jpg, etc.)
|
||||
|
||||
**What happens:**
|
||||
1. Image file is opened
|
||||
2. OCR (Optical Character Recognition) reads text from the image
|
||||
3. Extracted text is split into lines
|
||||
4. Lines are filtered for quality
|
||||
|
||||
**Example:**
|
||||
|
||||
**Input:** `screenshot.png` containing:
|
||||
```
|
||||
Hello World
|
||||
This is text in an image.
|
||||
```
|
||||
|
||||
**Processing:**
|
||||
|
||||
**Step 1: OCR Processing**
|
||||
```
|
||||
Image → OCR Engine → Text
|
||||
"Hello World\nThis is text in an image."
|
||||
```
|
||||
|
||||
**Step 2: Split by newlines**
|
||||
```
|
||||
Line 1: "Hello World"
|
||||
Line 2: "This is text in an image."
|
||||
```
|
||||
|
||||
**Step 3: Filter short lines**
|
||||
```
|
||||
Remove: "Hello World" (might be too short)
|
||||
Keep: "This is text in an image."
|
||||
```
|
||||
|
||||
**Output:**
|
||||
```python
|
||||
[
|
||||
"This is text in an image."
|
||||
]
|
||||
```
|
||||
|
||||
**Why this works:** OCR software analyzes the image pixel by pixel, identifies characters, and converts them to text. Accuracy depends on image quality.
|
||||
|
||||
---
|
||||
|
||||
## 5. Data Transformation Stages
|
||||
|
||||
### 5.1 Stage 1: File Discovery
|
||||
|
||||
**Purpose:** Find all files to process
|
||||
|
||||
**Process:**
|
||||
```
|
||||
Directory: data/
|
||||
├── document.pdf
|
||||
├── code.py
|
||||
├── screenshot.png
|
||||
└── notes.txt
|
||||
|
||||
Scan recursively:
|
||||
├── Find: document.pdf
|
||||
├── Find: code.py
|
||||
├── Find: screenshot.png
|
||||
└── Find: notes.txt
|
||||
|
||||
Total: 4 files found
|
||||
```
|
||||
|
||||
**Result:** List of file paths to process
|
||||
|
||||
---
|
||||
|
||||
### 5.2 Stage 2: File Type Classification
|
||||
|
||||
**Purpose:** Determine how to process each file
|
||||
|
||||
**Process:**
|
||||
```
|
||||
File: document.pdf
|
||||
├── Extension: .pdf
|
||||
├── Type: PDF
|
||||
└── Processor: PDF Extractor
|
||||
|
||||
File: code.py
|
||||
├── Extension: .py
|
||||
├── Type: Code
|
||||
└── Processor: Text Reader
|
||||
|
||||
File: screenshot.png
|
||||
├── Extension: .png
|
||||
├── Type: Image
|
||||
└── Processor: OCR
|
||||
|
||||
File: notes.txt
|
||||
├── Extension: .txt
|
||||
├── Type: Text
|
||||
└── Processor: Text Reader
|
||||
```
|
||||
|
||||
**Result:** Each file assigned to appropriate processor
|
||||
|
||||
---
|
||||
|
||||
### 5.3 Stage 3: Text Extraction
|
||||
|
||||
**Purpose:** Get raw text from each file
|
||||
|
||||
**Process:**
|
||||
|
||||
**PDF File:**
|
||||
```
|
||||
document.pdf
|
||||
→ Open PDF
|
||||
→ Extract Page 1: "Introduction..."
|
||||
→ Extract Page 2: "Chapter 1..."
|
||||
→ Extract Page 3: "Conclusion..."
|
||||
→ Combine: "Introduction...\nChapter 1...\nConclusion..."
|
||||
```
|
||||
|
||||
**Text File:**
|
||||
```
|
||||
notes.txt
|
||||
→ Open file
|
||||
→ Read content: "Hello\nWorld\nTest"
|
||||
```
|
||||
|
||||
**Image File:**
|
||||
```
|
||||
screenshot.png
|
||||
→ Open image
|
||||
→ Run OCR
|
||||
→ Extract: "Hello World\nThis is text"
|
||||
```
|
||||
|
||||
**Code File:**
|
||||
```
|
||||
code.py
|
||||
→ Open file
|
||||
→ Read content: "def hello():\n print('Hi')"
|
||||
```
|
||||
|
||||
**Result:** Raw text strings from each file
|
||||
|
||||
---
|
||||
|
||||
### 5.4 Stage 4: Text Cleaning
|
||||
|
||||
**Purpose:** Standardize and clean the extracted text
|
||||
|
||||
**Process:**
|
||||
|
||||
**Input:**
|
||||
```
|
||||
"Hello World\n\n\nThis is a test. "
|
||||
```
|
||||
|
||||
**Step 1: Remove Extra Whitespace**
|
||||
```
|
||||
"Hello World\n\n\nThis is a test. "
|
||||
↓
|
||||
"Hello World\n\n\nThis is a test."
|
||||
```
|
||||
|
||||
**Step 2: Normalize Line Endings**
|
||||
```
|
||||
"Hello World\n\n\nThis is a test."
|
||||
↓
|
||||
"Hello World\n\n\nThis is a test."
|
||||
```
|
||||
|
||||
**Step 3: Handle Encoding**
|
||||
```
|
||||
"Hello World" (UTF-8)
|
||||
↓
|
||||
"Hello World" (checked and valid)
|
||||
```
|
||||
|
||||
**Result:** Cleaned text strings
|
||||
|
||||
---
|
||||
|
||||
### 5.5 Stage 5: Line Splitting
|
||||
|
||||
**Purpose:** Break text into individual training samples
|
||||
|
||||
**Process:**
|
||||
|
||||
**Input:**
|
||||
```
|
||||
"Hello World\nThis is a test.\nMachine learning is cool."
|
||||
```
|
||||
|
||||
**Split by newlines:**
|
||||
```
|
||||
Line 1: "Hello World"
|
||||
Line 2: "This is a test."
|
||||
Line 3: "Machine learning is cool."
|
||||
```
|
||||
|
||||
**Result:** List of individual text lines
|
||||
|
||||
---
|
||||
|
||||
### 5.6 Stage 6: Filtering
|
||||
|
||||
**Purpose:** Keep only useful text samples
|
||||
|
||||
**Process:**
|
||||
|
||||
**Input:**
|
||||
```python
|
||||
[
|
||||
"Hello World", # Length: 11
|
||||
"Hi", # Length: 2 (too short)
|
||||
"This is a sentence.", # Length: 19
|
||||
"", # Empty (remove)
|
||||
"A" # Length: 1 (too short)
|
||||
]
|
||||
```
|
||||
|
||||
**Filter criteria:**
|
||||
- Minimum length: 10 characters
|
||||
- Non-empty strings
|
||||
|
||||
**Filtering:**
|
||||
```
|
||||
Keep: "Hello World" (length 11 ≥ 10)
|
||||
Remove: "Hi" (length 2 < 10)
|
||||
Keep: "This is a sentence." (length 19 ≥ 10)
|
||||
Remove: "" (empty)
|
||||
Remove: "A" (length 1 < 10)
|
||||
```
|
||||
|
||||
**Output:**
|
||||
```python
|
||||
[
|
||||
"Hello World",
|
||||
"This is a sentence."
|
||||
]
|
||||
```
|
||||
|
||||
**Result:** Filtered list of quality text samples
|
||||
|
||||
---
|
||||
|
||||
## 6. Complete Example: Processing "Hello World.pdf"
|
||||
|
||||
Let's trace through processing a complete PDF file step-by-step.
|
||||
|
||||
### Input
|
||||
**File:** `Hello World.pdf`
|
||||
**Location:** `data/documents/Hello World.pdf`
|
||||
**Content:** 2 pages with text
|
||||
|
||||
### Step-by-Step Processing
|
||||
|
||||
#### Step 1: File Discovery
|
||||
|
||||
```
|
||||
Scanning: data/
|
||||
├── documents/
|
||||
│ └── Hello World.pdf ← Found
|
||||
├── images/
|
||||
└── code/
|
||||
|
||||
File found: data/documents/Hello World.pdf
|
||||
```
|
||||
|
||||
#### Step 2: File Type Detection
|
||||
|
||||
```
|
||||
File: Hello World.pdf
|
||||
Extension: .pdf
|
||||
Type: PDF
|
||||
Processor: PDF Extractor
|
||||
```
|
||||
|
||||
#### Step 3: PDF Text Extraction
|
||||
|
||||
**Page 1 Content:**
|
||||
```
|
||||
Hello World
|
||||
This is a simple example document.
|
||||
It contains multiple sentences.
|
||||
```
|
||||
|
||||
**Page 2 Content:**
|
||||
```
|
||||
Second Page
|
||||
Here is more content.
|
||||
The end.
|
||||
```
|
||||
|
||||
**Extraction Process:**
|
||||
```
|
||||
Open PDF file
|
||||
↓
|
||||
Extract Page 1:
|
||||
Text: "Hello World\nThis is a simple example document.\nIt contains multiple sentences."
|
||||
↓
|
||||
Extract Page 2:
|
||||
Text: "Second Page\nHere is more content.\nThe end."
|
||||
↓
|
||||
Combine pages:
|
||||
"Hello World\nThis is a simple example document.\nIt contains multiple sentences.\nSecond Page\nHere is more content.\nThe end."
|
||||
```
|
||||
|
||||
#### Step 4: Text Cleaning
|
||||
|
||||
**Input:**
|
||||
```
|
||||
"Hello World\nThis is a simple example document.\nIt contains multiple sentences.\nSecond Page\nHere is more content.\nThe end."
|
||||
```
|
||||
|
||||
**Process:**
|
||||
```
|
||||
Remove extra whitespace: ✓ (already clean)
|
||||
Normalize encoding: ✓ (UTF-8)
|
||||
Handle special characters: ✓ (none found)
|
||||
```
|
||||
|
||||
**Output:**
|
||||
```
|
||||
"Hello World\nThis is a simple example document.\nIt contains multiple sentences.\nSecond Page\nHere is more content.\nThe end."
|
||||
```
|
||||
|
||||
#### Step 5: Line Splitting
|
||||
|
||||
**Input:**
|
||||
```
|
||||
"Hello World\nThis is a simple example document.\nIt contains multiple sentences.\nSecond Page\nHere is more content.\nThe end."
|
||||
```
|
||||
|
||||
**Split by newline character (`\n`):**
|
||||
```
|
||||
Line 1: "Hello World"
|
||||
Line 2: "This is a simple example document."
|
||||
Line 3: "It contains multiple sentences."
|
||||
Line 4: "Second Page"
|
||||
Line 5: "Here is more content."
|
||||
Line 6: "The end."
|
||||
```
|
||||
|
||||
#### Step 6: Filtering
|
||||
|
||||
**Input:**
|
||||
```python
|
||||
[
|
||||
"Hello World", # Length: 11
|
||||
"This is a simple example document.", # Length: 36
|
||||
"It contains multiple sentences.", # Length: 31
|
||||
"Second Page", # Length: 11
|
||||
"Here is more content.", # Length: 21
|
||||
"The end." # Length: 8 (too short!)
|
||||
]
|
||||
```
|
||||
|
||||
**Filter: Minimum length = 10**
|
||||
```
|
||||
✓ Keep: "Hello World" (11 ≥ 10)
|
||||
✓ Keep: "This is a simple example document." (36 ≥ 10)
|
||||
✓ Keep: "It contains multiple sentences." (31 ≥ 10)
|
||||
✓ Keep: "Second Page" (11 ≥ 10)
|
||||
✓ Keep: "Here is more content." (21 ≥ 10)
|
||||
✗ Remove: "The end." (8 < 10)
|
||||
```
|
||||
|
||||
#### Step 7: Final Output
|
||||
|
||||
**Result:**
|
||||
```python
|
||||
[
|
||||
"Hello World",
|
||||
"This is a simple example document.",
|
||||
"It contains multiple sentences.",
|
||||
"Second Page",
|
||||
"Here is more content."
|
||||
]
|
||||
```
|
||||
|
||||
**Statistics:**
|
||||
- Files processed: 1
|
||||
- Pages extracted: 2
|
||||
- Lines extracted: 6
|
||||
- Lines kept: 5
|
||||
- Lines filtered: 1
|
||||
|
||||
---
|
||||
|
||||
## 7. Data Quality and Filtering
|
||||
|
||||
### 7.1 Why Filter?
|
||||
|
||||
**Problem:** Not all text is useful for training
|
||||
|
||||
**Examples of Low-Quality Text:**
|
||||
|
||||
```
|
||||
✗ "" (empty line)
|
||||
✗ " " (just whitespace)
|
||||
✗ "Hi" (too short, no context)
|
||||
✗ "A" (single character)
|
||||
✗ "..." (ellipsis, no meaning)
|
||||
✗ "---" (separator line)
|
||||
```
|
||||
|
||||
**Examples of High-Quality Text:**
|
||||
|
||||
```
|
||||
✓ "Machine learning is a subset of artificial intelligence."
|
||||
✓ "The transformer architecture uses self-attention mechanisms."
|
||||
✓ "Gradient descent optimizes neural network parameters."
|
||||
```
|
||||
|
||||
### 7.2 Filtering Criteria
|
||||
|
||||
**Minimum Length Filter:**
|
||||
|
||||
**Purpose:** Remove very short lines that don't provide context
|
||||
|
||||
**Example:**
|
||||
```
|
||||
Minimum length: 10 characters
|
||||
|
||||
Keep:
|
||||
✓ "Hello world" (11 chars)
|
||||
✓ "This is a test." (15 chars)
|
||||
|
||||
Remove:
|
||||
✗ "Hi" (2 chars)
|
||||
✗ "Test" (4 chars)
|
||||
✗ "OK" (2 chars)
|
||||
```
|
||||
|
||||
**Why 10 characters?**
|
||||
- Provides enough context for meaningful learning
|
||||
- Filters out headers, separators, and noise
|
||||
- Ensures each sample has semantic value
|
||||
|
||||
### 7.3 Encoding Handling
|
||||
|
||||
**Problem:** Files may have different encodings
|
||||
|
||||
**Solution:** Try multiple encodings
|
||||
|
||||
**Process:**
|
||||
```
|
||||
Try UTF-8 first:
|
||||
✓ Success → Use UTF-8
|
||||
✗ Failure → Try Latin-1
|
||||
✓ Success → Use Latin-1
|
||||
✗ Failure → Log error and skip file
|
||||
```
|
||||
|
||||
**Example:**
|
||||
|
||||
**UTF-8 file:**
|
||||
```
|
||||
"Hello 世界" → Reads correctly
|
||||
```
|
||||
|
||||
**Latin-1 file:**
|
||||
```
|
||||
"Hello café" → Reads correctly with Latin-1
|
||||
```
|
||||
|
||||
### 7.4 Error Handling
|
||||
|
||||
**What happens when processing fails?**
|
||||
|
||||
**Examples:**
|
||||
|
||||
**Corrupted PDF:**
|
||||
```
|
||||
File: corrupted.pdf
|
||||
→ Try to extract text
|
||||
→ Error: "Cannot read PDF"
|
||||
→ Log warning: "Failed to process corrupted.pdf"
|
||||
→ Skip file
|
||||
→ Continue with next file
|
||||
```
|
||||
|
||||
**Unsupported File Type:**
|
||||
```
|
||||
File: presentation.pptx
|
||||
→ Extension: .pptx
|
||||
→ Type: Not supported
|
||||
→ Warning: "Unsupported file type: .pptx"
|
||||
→ Skip file
|
||||
→ Continue with next file
|
||||
```
|
||||
|
||||
**Image OCR Failure:**
|
||||
```
|
||||
File: blurry_image.png
|
||||
→ Try OCR
|
||||
→ OCR returns empty or garbled text
|
||||
→ Filter removes empty lines
|
||||
→ No text extracted
|
||||
→ File processed (no output)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 8. Common Questions
|
||||
|
||||
### Q1: Why process PDFs instead of using them directly?
|
||||
|
||||
**Answer:**
|
||||
Models work with numbers (token IDs), not file formats. PDFs have:
|
||||
- Complex structure (fonts, layouts, metadata)
|
||||
- Embedded formatting
|
||||
- Binary data mixed with text
|
||||
|
||||
Processing extracts just the text content, which is what the model needs.
|
||||
|
||||
### Q2: What if OCR doesn't work well on an image?
|
||||
|
||||
**Answer:**
|
||||
- Low-quality images produce poor OCR results
|
||||
- The system will extract what it can
|
||||
- Poor OCR output is filtered out (too short or garbled)
|
||||
- The file is processed but may contribute little or no text
|
||||
|
||||
**Solution:** Use high-quality images with clear text for best results.
|
||||
|
||||
### Q3: Why split text into lines?
|
||||
|
||||
**Answer:**
|
||||
- Each line becomes a training sample
|
||||
- Models predict next tokens in sequences
|
||||
- Shorter sequences are easier to process
|
||||
- Allows the model to learn from diverse sentence structures
|
||||
|
||||
### Q4: What happens to code formatting?
|
||||
|
||||
**Answer:**
|
||||
- Code is processed as text
|
||||
- Indentation and structure are preserved
|
||||
- Each line becomes a sample
|
||||
- The model learns code patterns and syntax
|
||||
|
||||
**Example:**
|
||||
```python
|
||||
def hello():
|
||||
print("Hi")
|
||||
```
|
||||
|
||||
Becomes:
|
||||
```
|
||||
"def hello():"
|
||||
" print("Hi")"
|
||||
```
|
||||
|
||||
### Q5: Can I process files in parallel?
|
||||
|
||||
**Answer:**
|
||||
Currently, files are processed sequentially. Future improvements could include:
|
||||
- Parallel processing of multiple files
|
||||
- Multi-threaded extraction
|
||||
- Batch processing for efficiency
|
||||
|
||||
### Q6: What if a file is very large?
|
||||
|
||||
**Answer:**
|
||||
- Large files are processed line by line
|
||||
- Memory usage stays manageable
|
||||
- Progress is logged every 100 files
|
||||
- System can handle files of any size (within memory limits)
|
||||
|
||||
### Q7: How is data from different file types combined?
|
||||
|
||||
**Answer:**
|
||||
All extracted text is combined into a single list:
|
||||
|
||||
```
|
||||
PDF file → 50 lines extracted
|
||||
Text file → 30 lines extracted
|
||||
Code file → 100 lines extracted
|
||||
Image → 5 lines extracted
|
||||
|
||||
Combined: 185 text lines total
|
||||
```
|
||||
|
||||
All lines are treated equally, regardless of source file type.
|
||||
|
||||
---
|
||||
|
||||
## Summary
|
||||
|
||||
### What is Data Processing?
|
||||
|
||||
**Data processing** is the transformation of raw files (PDFs, images, code, text) into clean text lines that can be tokenized and used for training.
|
||||
|
||||
### Key Steps
|
||||
|
||||
1. **Find Files**: Scan directory for all files
|
||||
2. **Classify**: Determine file type (.pdf, .txt, .png, etc.)
|
||||
3. **Extract**: Get text content from each file
|
||||
4. **Clean**: Remove noise and standardize format
|
||||
5. **Split**: Break into individual lines
|
||||
6. **Filter**: Keep only quality text samples
|
||||
|
||||
### Result
|
||||
|
||||
A list of text strings ready for:
|
||||
- Tokenization (converting to numbers)
|
||||
- Training (teaching the model)
|
||||
- Learning (model understanding patterns)
|
||||
|
||||
### Example Flow
|
||||
|
||||
```
|
||||
PDF file "document.pdf"
|
||||
↓
|
||||
Extract text from pages
|
||||
↓
|
||||
Clean and split into lines
|
||||
↓
|
||||
Filter by length
|
||||
↓
|
||||
["Sentence 1.", "Sentence 2.", "Sentence 3."]
|
||||
↓
|
||||
Ready for tokenization and training!
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
*This document explains what data processing means and how it transforms your raw files into training-ready text, step by step.*
|
||||
|
||||
287
docs/EMBEDDINGS_EXPLAINED.md
Normal file
287
docs/EMBEDDINGS_EXPLAINED.md
Normal file
@@ -0,0 +1,287 @@
|
||||
# What are Embeddings? Step-by-Step Explanation
|
||||
|
||||
Complete step-by-step explanation of embeddings in transformer models: how words become numbers that capture meaning.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
1. [The Problem Embeddings Solve](#11-the-problem-embeddings-solve)
|
||||
2. [What is an Embedding?](#12-what-is-an-embedding)
|
||||
3. [How Embeddings Work](#13-how-embeddings-work)
|
||||
4. [Step-by-Step Example: Embedding "Hello"](#14-step-by-step-example-embedding-hello)
|
||||
5. [Why Embeddings Matter](#15-why-embeddings-matter)
|
||||
6. [Complete Example: Embedding Multiple Words](#16-complete-example-embedding-multiple-words)
|
||||
7. [Visual Representation](#17-visual-representation)
|
||||
8. [Key Takeaways](#18-key-takeaways)
|
||||
|
||||
---
|
||||
|
||||
## 1.1 The Problem Embeddings Solve
|
||||
|
||||
### The Challenge
|
||||
|
||||
**Computers understand numbers, not words.**
|
||||
|
||||
Your model receives:
|
||||
- Input: `"Hello"` (a word, not a number)
|
||||
|
||||
But neural networks need:
|
||||
- Input: Numbers (like `[0.1, -0.2, 0.3, ...]`)
|
||||
|
||||
### The Solution: Embeddings
|
||||
|
||||
**Embeddings convert words (or tokens) into numbers (vectors) that capture meaning.**
|
||||
|
||||
---
|
||||
|
||||
## 1.2 What is an Embedding?
|
||||
|
||||
### Simple Definition
|
||||
|
||||
An **embedding** is a numerical representation of a word or token that captures its semantic meaning.
|
||||
|
||||
**Think of it like this:**
|
||||
- Each word gets a unique "address" in a high-dimensional space
|
||||
- Similar words end up close together
|
||||
- Different words are far apart
|
||||
|
||||
### Visual Analogy
|
||||
|
||||
Imagine a map where:
|
||||
- Words are cities
|
||||
- Similar words are nearby cities
|
||||
- Different words are distant cities
|
||||
|
||||
```
|
||||
Semantic Space (2D visualization)
|
||||
|
||||
"cat" "dog"
|
||||
● ●
|
||||
|
||||
"car" "vehicle"
|
||||
● ●
|
||||
|
||||
"king" "queen"
|
||||
● ●
|
||||
```
|
||||
|
||||
In reality, embeddings use **512 dimensions** (not 2D), but the concept is the same.
|
||||
|
||||
---
|
||||
|
||||
## 1.3 How Embeddings Work
|
||||
|
||||
### Step 1: Vocabulary Mapping
|
||||
|
||||
**Create a mapping from words to numbers:**
|
||||
|
||||
```
|
||||
Vocabulary:
|
||||
"Hello" → Token ID: 72
|
||||
"World" → Token ID: 87
|
||||
"the" → Token ID: 32
|
||||
...
|
||||
```
|
||||
|
||||
**Result:** Each word has a unique ID number
|
||||
|
||||
### Step 2: Embedding Matrix
|
||||
|
||||
**Create a matrix where each row represents a word:**
|
||||
|
||||
```
|
||||
Embedding Matrix E:
|
||||
|
||||
Dimension 0 Dimension 1 Dimension 2 ... Dimension 511
|
||||
Token 0 [ 0.05 , -0.10 , 0.20 , ..., 0.15 ]
|
||||
Token 1 [ -0.08 , 0.12 , -0.05 , ..., 0.08 ]
|
||||
Token 2 [ 0.10 , -0.15 , 0.25 , ..., 0.12 ]
|
||||
...
|
||||
Token 72 [ 0.10 , -0.20 , 0.30 , ..., 0.05 ] ← "Hello"
|
||||
...
|
||||
Token 87 [ 0.15 , -0.18 , 0.28 , ..., 0.10 ] ← "World"
|
||||
```
|
||||
|
||||
**Key Points:**
|
||||
- Each row is a 512-dimensional vector
|
||||
- Each row represents one token/word
|
||||
- The values are learned during training
|
||||
|
||||
### Step 3: Lookup Operation
|
||||
|
||||
**When you need an embedding, look it up:**
|
||||
|
||||
```
|
||||
Input: Token ID = 72 ("Hello")
|
||||
↓
|
||||
Lookup: E[72]
|
||||
↓
|
||||
Output: [0.10, -0.20, 0.30, ..., 0.05] (512 numbers)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 1.4 Step-by-Step Example: Embedding "Hello"
|
||||
|
||||
### Input
|
||||
|
||||
```
|
||||
Word: "Hello"
|
||||
Token ID: 72
|
||||
```
|
||||
|
||||
### Process
|
||||
|
||||
**Step 1: Get Token ID**
|
||||
```
|
||||
"Hello" → Lookup in vocabulary → 72
|
||||
```
|
||||
|
||||
**Step 2: Lookup Embedding**
|
||||
```
|
||||
E[72] = [0.10, -0.20, 0.30, 0.15, -0.05, ..., 0.05]
|
||||
```
|
||||
|
||||
**Step 3: Result**
|
||||
```
|
||||
Embedding vector: [0.10, -0.20, 0.30, ..., 0.05]
|
||||
Dimension: 512 numbers
|
||||
Meaning: Numerical representation of "Hello"
|
||||
```
|
||||
|
||||
### What These Numbers Mean
|
||||
|
||||
**Individual numbers don't mean much by themselves**, but **together** they represent:
|
||||
- Semantic meaning (what the word means)
|
||||
- Contextual relationships (how it relates to other words)
|
||||
- Syntactic information (grammatical role)
|
||||
|
||||
**Key Insight:** The model learns these values during training to capture meaning.
|
||||
|
||||
---
|
||||
|
||||
## 1.5 Why Embeddings Matter
|
||||
|
||||
### Benefit 1: Continuous Space
|
||||
|
||||
**Before Embeddings:**
|
||||
```
|
||||
"Hello" = 72
|
||||
"World" = 87
|
||||
Distance: |72 - 87| = 15 (meaningless!)
|
||||
```
|
||||
|
||||
**After Embeddings:**
|
||||
```
|
||||
"Hello" = [0.10, -0.20, 0.30, ...]
|
||||
"World" = [0.15, -0.18, 0.28, ...]
|
||||
Distance: Can measure similarity mathematically!
|
||||
```
|
||||
|
||||
### Benefit 2: Semantic Relationships
|
||||
|
||||
**Similar words have similar embeddings:**
|
||||
|
||||
```
|
||||
"cat" ≈ [0.8, 0.2, 0.1, ...]
|
||||
"dog" ≈ [0.7, 0.3, 0.1, ...] ← Similar to "cat"
|
||||
"car" ≈ [0.1, 0.9, 0.8, ...] ← Different from "cat"
|
||||
```
|
||||
|
||||
**Distance in embedding space = semantic similarity**
|
||||
|
||||
### Benefit 3: Mathematical Operations
|
||||
|
||||
**You can do math with embeddings:**
|
||||
|
||||
```
|
||||
"king" - "man" + "woman" ≈ "queen"
|
||||
```
|
||||
|
||||
This works because embeddings capture semantic relationships!
|
||||
|
||||
---
|
||||
|
||||
## 1.6 Complete Example: Embedding Multiple Words
|
||||
|
||||
### Input Sentence
|
||||
|
||||
```
|
||||
"Hello World"
|
||||
```
|
||||
|
||||
### Step-by-Step Processing
|
||||
|
||||
**Step 1: Tokenize**
|
||||
```
|
||||
"Hello" → Token ID: 72
|
||||
"World" → Token ID: 87
|
||||
```
|
||||
|
||||
**Step 2: Lookup Embeddings**
|
||||
```
|
||||
E[72] = [0.10, -0.20, 0.30, ..., 0.05] (512 numbers)
|
||||
E[87] = [0.15, -0.18, 0.28, ..., 0.10] (512 numbers)
|
||||
```
|
||||
|
||||
**Step 3: Stack Together**
|
||||
```
|
||||
Embedding Matrix:
|
||||
[
|
||||
[0.10, -0.20, 0.30, ..., 0.05], ← "Hello"
|
||||
[0.15, -0.18, 0.28, ..., 0.10] ← "World"
|
||||
]
|
||||
Shape: [2, 512]
|
||||
```
|
||||
|
||||
**Result:** Each word becomes a 512-dimensional vector
|
||||
|
||||
---
|
||||
|
||||
## 1.7 Visual Representation
|
||||
|
||||
### Embedding Space Visualization
|
||||
|
||||
```
|
||||
2D Projection of 512-Dimensional Embedding Space:
|
||||
|
||||
0.3 │ "World"
|
||||
│ ●
|
||||
0.2 │ "Hello"
|
||||
│ ●
|
||||
0.1 │
|
||||
│
|
||||
0.0 ├───────────────────────────
|
||||
│
|
||||
-0.1 │
|
||||
│
|
||||
-0.2 │
|
||||
│
|
||||
-0.3 │
|
||||
```
|
||||
|
||||
**Reality:** Embeddings exist in 512-dimensional space, but we can visualize them in 2D or 3D projections.
|
||||
|
||||
### Similarity Visualization
|
||||
|
||||
```
|
||||
Word Similarities (distance in embedding space):
|
||||
|
||||
"cat" ──── 0.15 distance ──── "dog" (similar)
|
||||
"cat" ──── 2.5 distance ──── "car" (different)
|
||||
"king" ──── 0.8 distance ──── "queen" (related)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 1.8 Key Takeaways: Embeddings
|
||||
|
||||
✅ **Embeddings convert words to numbers**
|
||||
✅ **Each word becomes a vector (list of numbers)**
|
||||
✅ **Similar words have similar vectors**
|
||||
✅ **Enables mathematical operations on words**
|
||||
✅ **Learned during training to capture meaning**
|
||||
|
||||
---
|
||||
|
||||
*This document provides a step-by-step explanation of embeddings, the fundamental component that converts words into numerical representations in transformer models.*
|
||||
|
||||
470
docs/FEED_FORWARD_EXPLAINED.md
Normal file
470
docs/FEED_FORWARD_EXPLAINED.md
Normal file
@@ -0,0 +1,470 @@
|
||||
# What is Feed-Forward? Step-by-Step Explanation
|
||||
|
||||
Complete step-by-step explanation of feed-forward networks in transformer models: how models transform and refine features.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
1. [The Problem Feed-Forward Solves](#31-the-problem-feed-forward-solves)
|
||||
2. [What is Feed-Forward?](#32-what-is-feed-forward)
|
||||
3. [How Feed-Forward Works: Step-by-Step](#33-how-feed-forward-works-step-by-step)
|
||||
4. [Complete Example: Feed-Forward on "Hello"](#34-complete-example-feed-forward-on-hello)
|
||||
5. [Why Feed-Forward Matters](#35-why-feed-forward-matters)
|
||||
6. [Complete Feed-Forward Formula](#36-complete-feed-forward-formula)
|
||||
7. [Visual Representation](#37-visual-representation)
|
||||
8. [Why Expand and Compress?](#38-why-expand-and-compress)
|
||||
9. [Key Takeaways](#39-key-takeaways)
|
||||
|
||||
---
|
||||
|
||||
## 3.1 The Problem Feed-Forward Solves
|
||||
|
||||
### The Challenge
|
||||
|
||||
**Attention provides context, but we need to process and transform that information.**
|
||||
|
||||
Think of it like cooking:
|
||||
- **Attention:** Gathers ingredients (context)
|
||||
- **Feed-Forward:** Cooks and transforms ingredients (processing)
|
||||
|
||||
### The Solution: Feed-Forward Network
|
||||
|
||||
**Feed-Forward applies complex transformations to each position independently.**
|
||||
|
||||
---
|
||||
|
||||
## 3.2 What is Feed-Forward?
|
||||
|
||||
### Simple Definition
|
||||
|
||||
A **Feed-Forward Network (FFN)** is a two-layer neural network that:
|
||||
1. **Expands** the input to a larger dimension
|
||||
2. **Applies** a nonlinear transformation
|
||||
3. **Compresses** back to original dimension
|
||||
|
||||
### Visual Analogy
|
||||
|
||||
**Think of it like a funnel:**
|
||||
|
||||
```
|
||||
Input (512 dimensions)
|
||||
↓
|
||||
┌─────────────┐
|
||||
│ EXPAND │
|
||||
│ 512 → 2048 │
|
||||
└──────┬──────┘
|
||||
↓
|
||||
┌─────────────┐
|
||||
│ TRANSFORM │
|
||||
│ (GELU) │
|
||||
└──────┬──────┘
|
||||
↓
|
||||
┌─────────────┐
|
||||
│ COMPRESS │
|
||||
│ 2048 → 512 │
|
||||
└──────┬──────┘
|
||||
↓
|
||||
Output (512 dimensions)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 3.3 How Feed-Forward Works: Step-by-Step
|
||||
|
||||
### High-Level Overview
|
||||
|
||||
```
|
||||
Step 1: Expand dimension (512 → 2048)
|
||||
Step 2: Apply nonlinear activation (GELU)
|
||||
Step 3: Compress dimension (2048 → 512)
|
||||
```
|
||||
|
||||
### Detailed Step-by-Step
|
||||
|
||||
#### Step 1: Expansion (First Linear Layer)
|
||||
|
||||
**Input:** Vector of size 512
|
||||
**Output:** Vector of size 2048
|
||||
|
||||
**Mathematical Operation:**
|
||||
```
|
||||
H = X × W₁ + b₁
|
||||
```
|
||||
|
||||
**Example:**
|
||||
|
||||
**Input X:**
|
||||
```
|
||||
[0.10, -0.20, 0.30, ..., 0.05] (512 numbers)
|
||||
```
|
||||
|
||||
**Weight Matrix W₁:**
|
||||
```
|
||||
Shape: [512, 2048]
|
||||
Each column transforms input to one output dimension
|
||||
```
|
||||
|
||||
**Process:**
|
||||
```
|
||||
H[0] = X[0]×W₁[0,0] + X[1]×W₁[1,0] + ... + X[511]×W₁[511,0]
|
||||
H[1] = X[0]×W₁[0,1] + X[1]×W₁[1,1] + ... + X[511]×W₁[511,1]
|
||||
...
|
||||
H[2047] = X[0]×W₁[0,2047] + ... + X[511]×W₁[511,2047]
|
||||
```
|
||||
|
||||
**Result:**
|
||||
```
|
||||
H = [0.12, -0.08, 0.25, ..., 0.18] (2048 numbers)
|
||||
```
|
||||
|
||||
**Why Expand?**
|
||||
- More dimensions = more capacity for complex transformations
|
||||
- Allows the model to learn intricate patterns
|
||||
- Think of it as "more room to work"
|
||||
|
||||
#### Step 2: Nonlinear Activation (GELU)
|
||||
|
||||
**Apply GELU to each element:**
|
||||
|
||||
**GELU Function:**
|
||||
```
|
||||
GELU(x) = x × Φ(x)
|
||||
|
||||
Where Φ(x) is the cumulative distribution function of standard normal distribution
|
||||
```
|
||||
|
||||
**Simplified Understanding:**
|
||||
- Values near zero → suppressed (close to 0)
|
||||
- Positive values → pass through (modified)
|
||||
- Negative values → suppressed more
|
||||
|
||||
**Example:**
|
||||
|
||||
**Input H:**
|
||||
```
|
||||
H = [0.12, -0.08, 0.25, ..., 0.18]
|
||||
```
|
||||
|
||||
**Apply GELU element-wise:**
|
||||
|
||||
```
|
||||
GELU(0.12) ≈ 0.12 × 0.548 ≈ 0.066
|
||||
GELU(-0.08) ≈ -0.08 × 0.468 ≈ -0.037
|
||||
GELU(0.25) ≈ 0.25 × 0.599 ≈ 0.150
|
||||
...
|
||||
GELU(0.18) ≈ 0.18 × 0.572 ≈ 0.103
|
||||
```
|
||||
|
||||
**Result:**
|
||||
```
|
||||
H' = [0.066, -0.037, 0.150, ..., 0.103] (2048 numbers)
|
||||
```
|
||||
|
||||
**Why Nonlinear?**
|
||||
- Linear transformations can only do so much
|
||||
- Nonlinearity enables complex function approximation
|
||||
- Essential for learning patterns
|
||||
|
||||
#### Step 3: Compression (Second Linear Layer)
|
||||
|
||||
**Input:** Vector of size 2048
|
||||
**Output:** Vector of size 512
|
||||
|
||||
**Mathematical Operation:**
|
||||
```
|
||||
O = H' × W₂ + b₂
|
||||
```
|
||||
|
||||
**Process:**
|
||||
```
|
||||
O[0] = H'[0]×W₂[0,0] + H'[1]×W₂[1,0] + ... + H'[2047]×W₂[2047,0]
|
||||
O[1] = H'[0]×W₂[0,1] + H'[1]×W₂[1,1] + ... + H'[2047]×W₂[2047,1]
|
||||
...
|
||||
O[511] = H'[0]×W₂[0,511] + ... + H'[2047]×W₂[2047,511]
|
||||
```
|
||||
|
||||
**Result:**
|
||||
```
|
||||
O = [0.15, -0.10, 0.22, ..., 0.12] (512 numbers)
|
||||
```
|
||||
|
||||
**Why Compress?**
|
||||
- Project back to original dimension
|
||||
- Maintains consistent size throughout model
|
||||
- Combines expanded features into compact representation
|
||||
|
||||
---
|
||||
|
||||
## 3.4 Complete Example: Feed-Forward on "Hello"
|
||||
|
||||
### Input
|
||||
|
||||
```
|
||||
Word: "Hello"
|
||||
After Attention: [0.146, 0.108, 0.192, ..., 0.11]
|
||||
Dimension: 512
|
||||
```
|
||||
|
||||
### Step-by-Step Processing
|
||||
|
||||
#### Step 1: Expansion
|
||||
|
||||
**Input X:**
|
||||
```
|
||||
[0.146, 0.108, 0.192, ..., 0.11] (512 numbers)
|
||||
```
|
||||
|
||||
**Weight Matrix W₁:**
|
||||
```
|
||||
Shape: [512, 2048]
|
||||
Values: Learned during training
|
||||
```
|
||||
|
||||
**Compute:**
|
||||
```
|
||||
H = X × W₁
|
||||
```
|
||||
|
||||
**Result:**
|
||||
```
|
||||
H = [0.21, -0.15, 0.28, ..., 0.19] (2048 numbers)
|
||||
```
|
||||
|
||||
**Visualization:**
|
||||
```
|
||||
512 dimensions ──→ ┌──────────┐ ──→ 2048 dimensions
|
||||
│ W₁ │
|
||||
└──────────┘
|
||||
```
|
||||
|
||||
#### Step 2: Activation
|
||||
|
||||
**Input H:**
|
||||
```
|
||||
[0.21, -0.15, 0.28, ..., 0.19] (2048 numbers)
|
||||
```
|
||||
|
||||
**Apply GELU element-wise:**
|
||||
|
||||
```
|
||||
GELU(0.21) ≈ 0.115
|
||||
GELU(-0.15) ≈ -0.058
|
||||
GELU(0.28) ≈ 0.168
|
||||
...
|
||||
GELU(0.19) ≈ 0.109
|
||||
```
|
||||
|
||||
**Result:**
|
||||
```
|
||||
H' = [0.115, -0.058, 0.168, ..., 0.109] (2048 numbers)
|
||||
```
|
||||
|
||||
**Visualization:**
|
||||
```
|
||||
2048 dimensions ──→ ┌──────────┐ ──→ 2048 dimensions
|
||||
│ GELU │
|
||||
└──────────┘
|
||||
```
|
||||
|
||||
#### Step 3: Compression
|
||||
|
||||
**Input H':**
|
||||
```
|
||||
[0.115, -0.058, 0.168, ..., 0.109] (2048 numbers)
|
||||
```
|
||||
|
||||
**Weight Matrix W₂:**
|
||||
```
|
||||
Shape: [2048, 512]
|
||||
Values: Learned during training
|
||||
```
|
||||
|
||||
**Compute:**
|
||||
```
|
||||
O = H' × W₂
|
||||
```
|
||||
|
||||
**Result:**
|
||||
```
|
||||
O = [0.18, -0.12, 0.24, ..., 0.14] (512 numbers)
|
||||
```
|
||||
|
||||
**Visualization:**
|
||||
```
|
||||
2048 dimensions ──→ ┌──────────┐ ──→ 512 dimensions
|
||||
│ W₂ │
|
||||
└──────────┘
|
||||
```
|
||||
|
||||
#### Final Output
|
||||
|
||||
```
|
||||
Output: [0.18, -0.12, 0.24, ..., 0.14] (512 numbers)
|
||||
```
|
||||
|
||||
**Meaning:** Transformed representation that captures processed features
|
||||
|
||||
---
|
||||
|
||||
## 3.5 Why Feed-Forward Matters
|
||||
|
||||
### Benefit 1: Feature Transformation
|
||||
|
||||
**Before FFN:**
|
||||
```
|
||||
Input: Raw attention output
|
||||
Information: Contextual relationships
|
||||
```
|
||||
|
||||
**After FFN:**
|
||||
```
|
||||
Output: Transformed features
|
||||
Information: Processed and refined understanding
|
||||
```
|
||||
|
||||
### Benefit 2: Non-Linear Processing
|
||||
|
||||
**Linear operations** (like attention) can only do limited transformations.
|
||||
**Non-linear operations** (like GELU in FFN) enable complex function learning.
|
||||
|
||||
**Analogy:**
|
||||
- Linear: Can only draw straight lines
|
||||
- Non-linear: Can draw curves, circles, complex shapes
|
||||
|
||||
### Benefit 3: Position-Wise Processing
|
||||
|
||||
**FFN processes each position independently:**
|
||||
|
||||
```
|
||||
Position 0 ("Hello"): FFN → Transformed representation
|
||||
Position 1 ("World"): FFN → Transformed representation
|
||||
```
|
||||
|
||||
**Each word gets its own transformation!**
|
||||
|
||||
---
|
||||
|
||||
## 3.6 Complete Feed-Forward Formula
|
||||
|
||||
### Mathematical Expression
|
||||
|
||||
```
|
||||
FFN(X) = GELU(X × W₁ + b₁) × W₂ + b₂
|
||||
```
|
||||
|
||||
**Breaking it down:**
|
||||
|
||||
**Part 1: First Linear Transformation**
|
||||
```
|
||||
H = X × W₁ + b₁
|
||||
```
|
||||
- Expands from 512 to 2048 dimensions
|
||||
|
||||
**Part 2: Non-Linear Activation**
|
||||
```
|
||||
H' = GELU(H)
|
||||
```
|
||||
- Applies non-linear transformation
|
||||
|
||||
**Part 3: Second Linear Transformation**
|
||||
```
|
||||
O = H' × W₂ + b₂
|
||||
```
|
||||
- Compresses from 2048 back to 512 dimensions
|
||||
|
||||
**Complete:**
|
||||
```
|
||||
FFN(X) = O
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 3.7 Visual Representation
|
||||
|
||||
### Feed-Forward Pipeline
|
||||
|
||||
```
|
||||
Input Vector (512D)
|
||||
│
|
||||
│ [0.146, 0.108, 0.192, ..., 0.11]
|
||||
↓
|
||||
┌─────────────────────────────┐
|
||||
│ Linear Layer 1 │
|
||||
│ (512 → 2048 expansion) │
|
||||
│ │
|
||||
│ H = X × W₁ │
|
||||
└──────────┬──────────────────┘
|
||||
│
|
||||
│ [0.21, -0.15, 0.28, ..., 0.19] (2048D)
|
||||
↓
|
||||
┌─────────────────────────────┐
|
||||
│ GELU Activation │
|
||||
│ (Non-linear transformation) │
|
||||
│ │
|
||||
│ H' = GELU(H) │
|
||||
└──────────┬──────────────────┘
|
||||
│
|
||||
│ [0.115, -0.058, 0.168, ..., 0.109] (2048D)
|
||||
↓
|
||||
┌─────────────────────────────┐
|
||||
│ Linear Layer 2 │
|
||||
│ (2048 → 512 compression) │
|
||||
│ │
|
||||
│ O = H' × W₂ │
|
||||
└──────────┬──────────────────┘
|
||||
│
|
||||
│ [0.18, -0.12, 0.24, ..., 0.14] (512D)
|
||||
↓
|
||||
Output Vector (512D)
|
||||
```
|
||||
|
||||
### Dimension Flow
|
||||
|
||||
```
|
||||
512 ──→ [Expand] ──→ 2048 ──→ [Transform] ──→ 2048 ──→ [Compress] ──→ 512
|
||||
```
|
||||
|
||||
**Like a funnel:** Expand → Transform → Compress
|
||||
|
||||
---
|
||||
|
||||
## 3.8 Why Expand and Compress?
|
||||
|
||||
### The Expansion-Compression Strategy
|
||||
|
||||
**Why not stay at 512 dimensions?**
|
||||
|
||||
**Answer:** Expansion provides "working space"
|
||||
|
||||
**Analogy:**
|
||||
- Think of doing math on paper
|
||||
- Small paper (512D) = limited space
|
||||
- Large paper (2048D) = room to work
|
||||
- Then copy results back to small paper (512D)
|
||||
|
||||
**Benefits:**
|
||||
1. **More capacity:** 2048 dimensions = more parameters to learn
|
||||
2. **Better transformations:** More space = more complex functions
|
||||
3. **Feature refinement:** Transformation happens in expanded space
|
||||
|
||||
**Why compress back?**
|
||||
|
||||
**Answer:** Maintain consistent size throughout the model
|
||||
|
||||
- All layers use 512 dimensions
|
||||
- Consistent size enables stacking layers
|
||||
- Easier to manage and optimize
|
||||
|
||||
---
|
||||
|
||||
## 3.9 Key Takeaways: Feed-Forward
|
||||
|
||||
✅ **FFN transforms features through expansion and compression**
|
||||
✅ **Expands to larger dimension for processing**
|
||||
✅ **Applies non-linear transformation (GELU)**
|
||||
✅ **Compresses back to original dimension**
|
||||
✅ **Processes each position independently**
|
||||
|
||||
---
|
||||
|
||||
*This document provides a step-by-step explanation of feed-forward networks, the component that transforms and refines features in transformer models.*
|
||||
|
||||
748
docs/GENERATION_EXPLAINED.md
Normal file
748
docs/GENERATION_EXPLAINED.md
Normal file
@@ -0,0 +1,748 @@
|
||||
# What is Generation? Step-by-Step Explanation
|
||||
|
||||
Complete step-by-step explanation of text generation: how models generate text using autoregressive generation, sampling, and decoding strategies.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
1. [What is Generation?](#91-what-is-generation)
|
||||
2. [Autoregressive Generation](#92-autoregressive-generation)
|
||||
3. [Sampling Strategies](#93-sampling-strategies)
|
||||
4. [Temperature](#94-temperature)
|
||||
5. [Top-k Sampling](#95-top-k-sampling)
|
||||
6. [Top-p (Nucleus) Sampling](#96-top-p-nucleus-sampling)
|
||||
7. [Step-by-Step Generation Process](#97-step-by-step-generation-process)
|
||||
8. [Exercise: Complete Generation Example](#98-exercise-complete-generation-example)
|
||||
9. [Key Takeaways](#99-key-takeaways)
|
||||
|
||||
---
|
||||
|
||||
## 9.1 What is Generation?
|
||||
|
||||
### Simple Definition
|
||||
|
||||
**Generation** (text generation) is the process of using a trained model to produce new text, one token at a time, based on a given prompt.
|
||||
|
||||
### Visual Analogy
|
||||
|
||||
**Think of generation like writing a story:**
|
||||
|
||||
```
|
||||
Prompt: "Once upon a time"
|
||||
|
||||
Model generates:
|
||||
"Once upon a time" → "there"
|
||||
"Once upon a time there" → "was"
|
||||
"Once upon a time there was" → "a"
|
||||
"Once upon a time there was a" → "princess"
|
||||
...
|
||||
|
||||
Final: "Once upon a time there was a princess..."
|
||||
```
|
||||
|
||||
**Model predicts next word, one at a time!**
|
||||
|
||||
### What Generation Does
|
||||
|
||||
**Generation:**
|
||||
1. **Takes** a prompt (starting text)
|
||||
2. **Predicts** next token probabilities
|
||||
3. **Samples** a token from distribution
|
||||
4. **Appends** token to sequence
|
||||
5. **Repeats** until complete
|
||||
|
||||
**Result:** Generated text continuation!
|
||||
|
||||
---
|
||||
|
||||
## 9.2 Autoregressive Generation
|
||||
|
||||
### What is Autoregressive?
|
||||
|
||||
**Autoregressive** means the model uses its own previous outputs as inputs for the next prediction.
|
||||
|
||||
### How It Works
|
||||
|
||||
**Step 1: Initial Prompt**
|
||||
```
|
||||
Prompt: "Hello"
|
||||
Sequence: ["Hello"]
|
||||
```
|
||||
|
||||
**Step 2: First Prediction**
|
||||
```
|
||||
Input: ["Hello"]
|
||||
Model output: Probabilities for next token
|
||||
"World": 0.4
|
||||
"there": 0.3
|
||||
"friend": 0.2
|
||||
...
|
||||
```
|
||||
|
||||
**Step 3: Sample Token**
|
||||
```
|
||||
Sample: "World" (selected)
|
||||
Sequence: ["Hello", "World"]
|
||||
```
|
||||
|
||||
**Step 4: Second Prediction**
|
||||
```
|
||||
Input: ["Hello", "World"]
|
||||
Model output: Probabilities for next token
|
||||
"!": 0.5
|
||||
".": 0.3
|
||||
",": 0.1
|
||||
...
|
||||
```
|
||||
|
||||
**Step 5: Continue**
|
||||
```
|
||||
Sample: "!"
|
||||
Sequence: ["Hello", "World", "!"]
|
||||
Continue until max length or stop token...
|
||||
```
|
||||
|
||||
### Mathematical Formulation
|
||||
|
||||
**For prompt $\mathbf{P} = [p_1, ..., p_k]$:**
|
||||
|
||||
**Initialization:**
|
||||
```math
|
||||
\mathbf{T}_0 = \mathbf{P}
|
||||
```
|
||||
|
||||
**For each step $t \geq k+1$:**
|
||||
|
||||
1. **Forward pass:**
|
||||
```math
|
||||
\mathbf{L}_t = \text{Model}(\mathbf{T}_{t-1})
|
||||
```
|
||||
|
||||
2. **Get next token probabilities:**
|
||||
```math
|
||||
\mathbf{p}_t = \text{softmax}(\mathbf{L}_t[:, -1, :])
|
||||
```
|
||||
|
||||
3. **Sample token:**
|
||||
```math
|
||||
t_t \sim \text{Categorical}(\mathbf{p}_t)
|
||||
```
|
||||
|
||||
4. **Append token:**
|
||||
```math
|
||||
\mathbf{T}_t = [\mathbf{T}_{t-1}, t_t]
|
||||
```
|
||||
|
||||
**Repeat until stop condition!**
|
||||
|
||||
---
|
||||
|
||||
## 9.3 Sampling Strategies
|
||||
|
||||
### Deterministic vs Stochastic
|
||||
|
||||
**Deterministic (Greedy):**
|
||||
```
|
||||
Always pick highest probability:
|
||||
"World": 0.4 ← Highest
|
||||
"there": 0.3
|
||||
"friend": 0.2
|
||||
|
||||
→ Always picks "World"
|
||||
→ Same output every time
|
||||
```
|
||||
|
||||
**Stochastic (Sampling):**
|
||||
```
|
||||
Sample from distribution:
|
||||
"World": 0.4 (40% chance)
|
||||
"there": 0.3 (30% chance)
|
||||
"friend": 0.2 (20% chance)
|
||||
|
||||
→ Different output each time
|
||||
→ More diverse generations
|
||||
```
|
||||
|
||||
### Why Sampling?
|
||||
|
||||
**Greedy (Deterministic):**
|
||||
- Same output every time
|
||||
- Can be repetitive
|
||||
- Less creative
|
||||
|
||||
**Sampling:**
|
||||
- Different outputs each time
|
||||
- More diverse
|
||||
- More creative
|
||||
- Better for creative tasks
|
||||
|
||||
---
|
||||
|
||||
## 9.4 Temperature
|
||||
|
||||
### What is Temperature?
|
||||
|
||||
**Temperature** controls the randomness of sampling by scaling the logits before applying softmax.
|
||||
|
||||
### Formula
|
||||
|
||||
```math
|
||||
\mathbf{p}_t = \text{softmax}\left(\frac{\mathbf{l}_t}{T}\right)
|
||||
```
|
||||
|
||||
**Where:**
|
||||
- $\mathbf{l}_t$ = logits (raw scores)
|
||||
- $T$ = temperature
|
||||
- $\mathbf{p}_t$ = probabilities
|
||||
|
||||
### How Temperature Works
|
||||
|
||||
**T = 0.5 (Low Temperature - More Deterministic):**
|
||||
```
|
||||
Logits: [2.0, 1.0, 0.5]
|
||||
After scaling: [4.0, 2.0, 1.0]
|
||||
After softmax: [0.88, 0.11, 0.01]
|
||||
→ Sharp distribution (one token dominates)
|
||||
→ More deterministic
|
||||
```
|
||||
|
||||
**T = 1.0 (Standard Temperature):**
|
||||
```
|
||||
Logits: [2.0, 1.0, 0.5]
|
||||
After scaling: [2.0, 1.0, 0.5]
|
||||
After softmax: [0.66, 0.24, 0.10]
|
||||
→ Moderate distribution
|
||||
→ Balanced
|
||||
```
|
||||
|
||||
**T = 2.0 (High Temperature - More Random):**
|
||||
```
|
||||
Logits: [2.0, 1.0, 0.5]
|
||||
After scaling: [1.0, 0.5, 0.25]
|
||||
After softmax: [0.52, 0.31, 0.17]
|
||||
→ Flat distribution (more uniform)
|
||||
→ More random
|
||||
```
|
||||
|
||||
### Visual Comparison
|
||||
|
||||
```
|
||||
Probability
|
||||
│
|
||||
1.0│ T=0.5: ●
|
||||
│
|
||||
0.8│
|
||||
│
|
||||
0.6│ T=1.0: ●
|
||||
│
|
||||
0.4│
|
||||
│
|
||||
0.2│ T=2.0: ●
|
||||
│
|
||||
0.0├───────────────────────── Token
|
||||
"World" "there" "friend"
|
||||
```
|
||||
|
||||
**Lower T = Sharper distribution = More deterministic**
|
||||
**Higher T = Flatter distribution = More random**
|
||||
|
||||
### When to Use Different Temperatures
|
||||
|
||||
**Low Temperature (T < 1.0):**
|
||||
- Factual tasks
|
||||
- Reproducible outputs
|
||||
- When you want consistent results
|
||||
|
||||
**Standard Temperature (T = 1.0):**
|
||||
- Default setting
|
||||
- Balanced behavior
|
||||
- Good for most tasks
|
||||
|
||||
**High Temperature (T > 1.0):**
|
||||
- Creative writing
|
||||
- Diverse outputs
|
||||
- When you want variety
|
||||
|
||||
---
|
||||
|
||||
## 9.5 Top-k Sampling
|
||||
|
||||
### What is Top-k?
|
||||
|
||||
**Top-k sampling** limits the sampling to only the top k most likely tokens.
|
||||
|
||||
### How It Works
|
||||
|
||||
**Step 1: Get Probabilities**
|
||||
```
|
||||
All tokens:
|
||||
"World": 0.4
|
||||
"there": 0.3
|
||||
"friend": 0.2
|
||||
"hello": 0.05
|
||||
"cat": 0.03
|
||||
"dog": 0.02
|
||||
...
|
||||
```
|
||||
|
||||
**Step 2: Select Top-k (e.g., k=3)**
|
||||
```
|
||||
Top 3:
|
||||
"World": 0.4
|
||||
"there": 0.3
|
||||
"friend": 0.2
|
||||
```
|
||||
|
||||
**Step 3: Remove Others**
|
||||
```
|
||||
Set others to 0:
|
||||
"World": 0.4
|
||||
"there": 0.3
|
||||
"friend": 0.2
|
||||
"hello": 0.0
|
||||
"cat": 0.0
|
||||
"dog": 0.0
|
||||
...
|
||||
```
|
||||
|
||||
**Step 4: Renormalize**
|
||||
```
|
||||
Sum = 0.4 + 0.3 + 0.2 = 0.9
|
||||
Renormalize:
|
||||
"World": 0.4/0.9 = 0.44
|
||||
"there": 0.3/0.9 = 0.33
|
||||
"friend": 0.2/0.9 = 0.22
|
||||
```
|
||||
|
||||
**Step 5: Sample from Top-k**
|
||||
```
|
||||
Sample from these 3 tokens only
|
||||
```
|
||||
|
||||
### Mathematical Formulation
|
||||
|
||||
**Given probabilities $\mathbf{p}_t$ and top-k:**
|
||||
|
||||
```math
|
||||
\mathbf{p}_t^{topk}[v] = \begin{cases}
|
||||
\frac{\mathbf{p}_t[v]}{\sum_{u \in \text{top-k}} \mathbf{p}_t[u]} & \text{if } v \in \text{top-k} \\
|
||||
0 & \text{otherwise}
|
||||
\end{cases}
|
||||
```
|
||||
|
||||
### Why Top-k?
|
||||
|
||||
**Benefits:**
|
||||
- Removes low-probability tokens
|
||||
- Focuses on likely candidates
|
||||
- Reduces randomness from unlikely tokens
|
||||
- Better quality generations
|
||||
|
||||
**Example:**
|
||||
```
|
||||
Without top-k: Might sample "xyz" (very unlikely)
|
||||
With top-k=50: Only samples from top 50 tokens
|
||||
→ Better quality!
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 9.6 Top-p (Nucleus) Sampling
|
||||
|
||||
### What is Top-p?
|
||||
|
||||
**Top-p (nucleus) sampling** keeps the smallest set of tokens whose cumulative probability is at least p.
|
||||
|
||||
### How It Works
|
||||
|
||||
**Step 1: Sort Probabilities**
|
||||
```
|
||||
Sorted (descending):
|
||||
"World": 0.4
|
||||
"there": 0.3
|
||||
"friend": 0.2
|
||||
"hello": 0.05
|
||||
"cat": 0.03
|
||||
"dog": 0.02
|
||||
...
|
||||
```
|
||||
|
||||
**Step 2: Compute Cumulative Probabilities**
|
||||
```
|
||||
Cumulative:
|
||||
"World": 0.4
|
||||
"there": 0.7 (0.4 + 0.3)
|
||||
"friend": 0.9 (0.7 + 0.2)
|
||||
"hello": 0.95 (0.9 + 0.05)
|
||||
"cat": 0.98 (0.95 + 0.03)
|
||||
...
|
||||
```
|
||||
|
||||
**Step 3: Find Nucleus (e.g., p=0.9)**
|
||||
```
|
||||
Find smallest set where sum ≥ 0.9:
|
||||
"World": 0.4
|
||||
"there": 0.3
|
||||
"friend": 0.2
|
||||
Cumulative: 0.9 ✓
|
||||
|
||||
→ Keep these 3 tokens
|
||||
```
|
||||
|
||||
**Step 4: Remove Others**
|
||||
```
|
||||
Keep:
|
||||
"World": 0.4
|
||||
"there": 0.3
|
||||
"friend": 0.2
|
||||
Others: 0.0
|
||||
```
|
||||
|
||||
**Step 5: Renormalize and Sample**
|
||||
```
|
||||
Renormalize and sample
|
||||
```
|
||||
|
||||
### Mathematical Formulation
|
||||
|
||||
**Given probabilities $\mathbf{p}_t$ and top-p:**
|
||||
|
||||
**Find smallest set S:**
|
||||
```math
|
||||
S = \arg\min \{ |S'| : \sum_{v \in S'} \mathbf{p}_t[v] \geq p \}
|
||||
```
|
||||
|
||||
**Then:**
|
||||
```math
|
||||
\mathbf{p}_t^{topp}[v] = \begin{cases}
|
||||
\frac{\mathbf{p}_t[v]}{\sum_{u \in S} \mathbf{p}_t[u]} & \text{if } v \in S \\
|
||||
0 & \text{otherwise}
|
||||
\end{cases}
|
||||
```
|
||||
|
||||
### Why Top-p?
|
||||
|
||||
**Benefits:**
|
||||
- Adapts to distribution shape
|
||||
- Keeps relevant tokens dynamically
|
||||
- Better than fixed k in some cases
|
||||
- More flexible than top-k
|
||||
|
||||
**Example:**
|
||||
```
|
||||
Sharp distribution: Top-p=0.9 might keep 3 tokens
|
||||
Flat distribution: Top-p=0.9 might keep 50 tokens
|
||||
→ Adapts automatically!
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 9.7 Step-by-Step Generation Process
|
||||
|
||||
### Complete Process
|
||||
|
||||
**Given prompt: "Hello"**
|
||||
|
||||
#### Step 1: Encode Prompt
|
||||
|
||||
```
|
||||
Prompt: "Hello"
|
||||
Token IDs: [72]
|
||||
```
|
||||
|
||||
#### Step 2: Forward Pass
|
||||
|
||||
```
|
||||
Input: [72]
|
||||
Model processes through layers
|
||||
Output: Logits for all tokens
|
||||
Token 72: 5.2
|
||||
Token 87: 4.8 ← "World"
|
||||
Token 101: 3.2 ← "there"
|
||||
Token 108: 2.1 ← "friend"
|
||||
...
|
||||
```
|
||||
|
||||
#### Step 3: Apply Temperature
|
||||
|
||||
```
|
||||
Temperature: T = 1.0
|
||||
Scaled logits: Same as above
|
||||
```
|
||||
|
||||
#### Step 4: Apply Top-k (Optional)
|
||||
|
||||
```
|
||||
Top-k: k = 50
|
||||
Keep top 50 tokens, remove others
|
||||
```
|
||||
|
||||
#### Step 5: Apply Top-p (Optional)
|
||||
|
||||
```
|
||||
Top-p: p = 0.95
|
||||
Keep tokens with cumulative prob ≥ 0.95
|
||||
```
|
||||
|
||||
#### Step 6: Compute Probabilities
|
||||
|
||||
```
|
||||
Apply softmax:
|
||||
"World": 0.4
|
||||
"there": 0.3
|
||||
"friend": 0.2
|
||||
...
|
||||
```
|
||||
|
||||
#### Step 7: Sample Token
|
||||
|
||||
```
|
||||
Sample from distribution:
|
||||
Selected: "World" (token 87)
|
||||
```
|
||||
|
||||
#### Step 8: Append Token
|
||||
|
||||
```
|
||||
Sequence: [72, 87]
|
||||
Text: "Hello World"
|
||||
```
|
||||
|
||||
#### Step 9: Repeat
|
||||
|
||||
```
|
||||
Input: [72, 87]
|
||||
→ Predict next token
|
||||
→ Sample
|
||||
→ Append
|
||||
→ Repeat...
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 9.8 Exercise: Complete Generation Example
|
||||
|
||||
### Problem
|
||||
|
||||
**Given:**
|
||||
- Prompt: "The"
|
||||
- Model logits for next token: `[10.0, 8.0, 5.0, 2.0, 1.0, 0.5, ...]` (for tokens: "cat", "dog", "car", "house", "tree", "book", ...)
|
||||
- Temperature: T = 1.0
|
||||
- Top-k: k = 3
|
||||
- Top-p: p = 0.9
|
||||
|
||||
**Generate the next token step-by-step.**
|
||||
|
||||
### Step-by-Step Solution
|
||||
|
||||
#### Step 1: Initial Setup
|
||||
|
||||
**Prompt:**
|
||||
```
|
||||
"The"
|
||||
Token IDs: [32] (assuming "The" = token 32)
|
||||
```
|
||||
|
||||
**Logits:**
|
||||
```
|
||||
Token "cat": 10.0
|
||||
Token "dog": 8.0
|
||||
Token "car": 5.0
|
||||
Token "house": 2.0
|
||||
Token "tree": 1.0
|
||||
Token "book": 0.5
|
||||
...
|
||||
```
|
||||
|
||||
#### Step 2: Apply Temperature
|
||||
|
||||
**Temperature: T = 1.0**
|
||||
|
||||
**Scaled logits (divide by T):**
|
||||
```
|
||||
Token "cat": 10.0 / 1.0 = 10.0
|
||||
Token "dog": 8.0 / 1.0 = 8.0
|
||||
Token "car": 5.0 / 1.0 = 5.0
|
||||
Token "house": 2.0 / 1.0 = 2.0
|
||||
Token "tree": 1.0 / 1.0 = 1.0
|
||||
Token "book": 0.5 / 1.0 = 0.5
|
||||
```
|
||||
|
||||
**No change (T=1.0 is identity)**
|
||||
|
||||
#### Step 3: Apply Top-k Filtering
|
||||
|
||||
**Top-k: k = 3**
|
||||
|
||||
**Select top 3 tokens:**
|
||||
```
|
||||
Top 3:
|
||||
"cat": 10.0
|
||||
"dog": 8.0
|
||||
"car": 5.0
|
||||
```
|
||||
|
||||
**Set others to -∞:**
|
||||
```
|
||||
Token "cat": 10.0
|
||||
Token "dog": 8.0
|
||||
Token "car": 5.0
|
||||
Token "house": -∞
|
||||
Token "tree": -∞
|
||||
Token "book": -∞
|
||||
```
|
||||
|
||||
#### Step 4: Apply Top-p Filtering
|
||||
|
||||
**First, compute probabilities from top-k tokens:**
|
||||
|
||||
**Apply softmax:**
|
||||
```
|
||||
exp(10.0) = 22026.47
|
||||
exp(8.0) = 2980.96
|
||||
exp(5.0) = 148.41
|
||||
Sum = 25155.84
|
||||
|
||||
P("cat") = 22026.47 / 25155.84 ≈ 0.875
|
||||
P("dog") = 2980.96 / 25155.84 ≈ 0.119
|
||||
P("car") = 148.41 / 25155.84 ≈ 0.006
|
||||
```
|
||||
|
||||
**Cumulative probabilities:**
|
||||
```
|
||||
"cat": 0.875
|
||||
"dog": 0.994 (0.875 + 0.119)
|
||||
"car": 1.000 (0.994 + 0.006)
|
||||
```
|
||||
|
||||
**Find smallest set where sum ≥ 0.9:**
|
||||
```
|
||||
"cat": 0.875 < 0.9
|
||||
"cat" + "dog": 0.994 ≥ 0.9 ✓
|
||||
|
||||
→ Keep "cat" and "dog"
|
||||
→ Remove "car"
|
||||
```
|
||||
|
||||
**Result:**
|
||||
```
|
||||
Token "cat": 10.0
|
||||
Token "dog": 8.0
|
||||
Token "car": -∞ (removed)
|
||||
```
|
||||
|
||||
#### Step 5: Compute Final Probabilities
|
||||
|
||||
**Apply softmax to remaining tokens:**
|
||||
```
|
||||
exp(10.0) = 22026.47
|
||||
exp(8.0) = 2980.96
|
||||
Sum = 25007.43
|
||||
|
||||
P("cat") = 22026.47 / 25007.43 ≈ 0.881
|
||||
P("dog") = 2980.96 / 25007.43 ≈ 0.119
|
||||
```
|
||||
|
||||
#### Step 6: Sample Token
|
||||
|
||||
**Sample from distribution:**
|
||||
```
|
||||
Random number: 0.75
|
||||
|
||||
Cumulative:
|
||||
"cat": 0.881 ← 0.75 falls here
|
||||
"dog": 1.000
|
||||
|
||||
→ Selected: "cat"
|
||||
```
|
||||
|
||||
### Answer
|
||||
|
||||
**Generated token: "cat"**
|
||||
|
||||
**Final sequence:**
|
||||
```
|
||||
Prompt: "The"
|
||||
Generated: "cat"
|
||||
Full text: "The cat"
|
||||
```
|
||||
|
||||
### Summary
|
||||
|
||||
| Step | Operation | Result |
|
||||
|------|-----------|--------|
|
||||
| 1 | Initial logits | [10.0, 8.0, 5.0, 2.0, ...] |
|
||||
| 2 | Apply temperature (T=1.0) | [10.0, 8.0, 5.0, 2.0, ...] |
|
||||
| 3 | Top-k filtering (k=3) | Keep top 3: [10.0, 8.0, 5.0] |
|
||||
| 4 | Top-p filtering (p=0.9) | Keep cumulative ≥0.9: [10.0, 8.0] |
|
||||
| 5 | Compute probabilities | [0.881, 0.119] |
|
||||
| 6 | Sample | "cat" selected |
|
||||
|
||||
**The model generated "cat" following "The"!**
|
||||
|
||||
---
|
||||
|
||||
## 9.9 Key Takeaways
|
||||
|
||||
### Generation
|
||||
|
||||
✅ **Generation produces text one token at a time**
|
||||
✅ **Autoregressive: uses previous outputs as inputs**
|
||||
✅ **Iterative process: predict → sample → append → repeat**
|
||||
|
||||
### Sampling Strategies
|
||||
|
||||
✅ **Temperature: Controls randomness (lower = deterministic, higher = random)**
|
||||
✅ **Top-k: Limits to top k tokens**
|
||||
✅ **Top-p: Keeps smallest set with cumulative probability ≥ p**
|
||||
✅ **Combined: Often use temperature + top-k or top-p**
|
||||
|
||||
### Why Important
|
||||
|
||||
✅ **Enables text generation from trained models**
|
||||
✅ **Different strategies produce different outputs**
|
||||
✅ **Essential for language model deployment**
|
||||
|
||||
---
|
||||
|
||||
## Mathematical Summary
|
||||
|
||||
### Generation Process
|
||||
|
||||
**Initialization:**
|
||||
```math
|
||||
\mathbf{T}_0 = \mathbf{P}
|
||||
```
|
||||
|
||||
**For each step $t$:**
|
||||
```math
|
||||
\mathbf{l}_t = \text{Model}(\mathbf{T}_{t-1})[:, -1, :]
|
||||
```
|
||||
|
||||
```math
|
||||
\mathbf{l}_t' = \frac{\mathbf{l}_t}{T} \quad \text{(temperature)}
|
||||
```
|
||||
|
||||
```math
|
||||
\mathbf{l}_t'' = \text{Top-k}(\mathbf{l}_t') \quad \text{(optional)}
|
||||
```
|
||||
|
||||
```math
|
||||
\mathbf{l}_t''' = \text{Top-p}(\mathbf{l}_t'') \quad \text{(optional)}
|
||||
```
|
||||
|
||||
```math
|
||||
\mathbf{p}_t = \text{softmax}(\mathbf{l}_t''')
|
||||
```
|
||||
|
||||
```math
|
||||
t_t \sim \text{Categorical}(\mathbf{p}_t)
|
||||
```
|
||||
|
||||
```math
|
||||
\mathbf{T}_t = [\mathbf{T}_{t-1}, t_t]
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
*This document provides a comprehensive explanation of text generation, including autoregressive generation, sampling strategies, temperature, top-k, and top-p with mathematical formulations and solved exercises.*
|
||||
|
||||
1087
docs/MATHEMATICS.md
Normal file
1087
docs/MATHEMATICS.md
Normal file
File diff suppressed because it is too large
Load Diff
195
docs/MULTI_FORMAT_DATA_GUIDE.md
Normal file
195
docs/MULTI_FORMAT_DATA_GUIDE.md
Normal file
@@ -0,0 +1,195 @@
|
||||
# Multi-Format Data Processing Guide
|
||||
|
||||
## Overview
|
||||
|
||||
The training script now supports processing multiple file types from your `data/` directory:
|
||||
|
||||
- **Text files**: `.txt`, `.md`, `.rst`, `.log`, `.csv`, `.json`, `.jsonl`, `.xml`, `.html`, `.htm`
|
||||
- **Code files**: `.py`, `.js`, `.ts`, `.java`, `.cpp`, `.c`, `.go`, `.rs`, `.rb`, `.php`, `.swift`, and many more
|
||||
- **PDF files**: `.pdf` (requires PyPDF2 or pdfplumber)
|
||||
- **Images**: `.png`, `.jpg`, `.jpeg`, `.gif`, `.bmp`, `.tiff`, `.webp` (requires pytesseract for OCR)
|
||||
|
||||
## Basic Usage
|
||||
|
||||
Simply point the training script to your data directory:
|
||||
|
||||
```bash
|
||||
python train.py --data /path/to/your/data/directory
|
||||
```
|
||||
|
||||
The script will automatically:
|
||||
1. Scan the directory (recursively by default)
|
||||
2. Extract text from all supported file types
|
||||
3. Process and tokenize the text
|
||||
4. Train the model on all extracted content
|
||||
|
||||
## Installation
|
||||
|
||||
### Core Dependencies
|
||||
|
||||
The core dependencies are already in `requirements.txt`. Install them with:
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### Optional Dependencies for PDF and Image Processing
|
||||
|
||||
If you want to process PDFs or images, install the optional dependencies:
|
||||
|
||||
```bash
|
||||
# For PDF processing (choose one):
|
||||
pip install PyPDF2
|
||||
# OR
|
||||
pip install pdfplumber # Alternative, often better for complex PDFs
|
||||
|
||||
# For image OCR:
|
||||
pip install pytesseract Pillow
|
||||
|
||||
# Also install Tesseract OCR engine:
|
||||
# macOS: brew install tesseract
|
||||
# Ubuntu/Debian: sudo apt-get install tesseract-ocr
|
||||
# Windows: Download from https://github.com/UB-Mannheim/tesseract/wiki
|
||||
```
|
||||
|
||||
## How It Works
|
||||
|
||||
### 1. Text Files
|
||||
|
||||
Text files are read line by line. Each non-empty line becomes a training sample.
|
||||
|
||||
### 2. Code Files
|
||||
|
||||
Code files are processed as text. Each line of code becomes a training sample. This allows the model to learn code patterns and syntax.
|
||||
|
||||
### 3. PDF Files
|
||||
|
||||
PDFs are processed page by page:
|
||||
- Text is extracted from each page
|
||||
- Split into lines
|
||||
- Filtered to remove very short lines
|
||||
- Each line becomes a training sample
|
||||
|
||||
**Note**: PDF extraction works best with text-based PDFs. Scanned PDFs (images) should use OCR instead.
|
||||
|
||||
### 4. Image Files
|
||||
|
||||
Images are processed using OCR (Optical Character Recognition):
|
||||
- Images are opened using PIL/Pillow
|
||||
- pytesseract extracts text from the image
|
||||
- Text is split into lines
|
||||
- Each line becomes a training sample
|
||||
|
||||
**Note**: OCR quality depends on image quality. For best results:
|
||||
- Use high-resolution images
|
||||
- Ensure good contrast between text and background
|
||||
- Avoid images with complex layouts
|
||||
|
||||
## Configuration Options
|
||||
|
||||
You can customize the data processing behavior:
|
||||
|
||||
```python
|
||||
from pathlib import Path
|
||||
from data import DataProcessor
|
||||
|
||||
processor = DataProcessor(
|
||||
use_ocr=True, # Enable OCR for images
|
||||
use_pdf_extraction=True # Enable PDF extraction
|
||||
)
|
||||
|
||||
# Process directory
|
||||
texts = processor.process_to_list(
|
||||
directory=Path("data/"),
|
||||
recursive=True, # Process subdirectories
|
||||
min_length=10, # Minimum line length
|
||||
max_samples=None, # Limit number of samples (None = all)
|
||||
)
|
||||
```
|
||||
|
||||
## Examples
|
||||
|
||||
### Example 1: Process all files in directory
|
||||
|
||||
```bash
|
||||
python train.py --data /mnt/storage/sheepOp/data
|
||||
```
|
||||
|
||||
### Example 2: Process single file
|
||||
|
||||
```bash
|
||||
python train.py --data /mnt/storage/sheepOp/data/document.pdf
|
||||
```
|
||||
|
||||
### Example 3: Using Python API
|
||||
|
||||
```python
|
||||
from pathlib import Path
|
||||
from data import extract_text_from_directory
|
||||
|
||||
# Extract text from all supported files
|
||||
texts = extract_text_from_directory(
|
||||
directory=Path("data/"),
|
||||
recursive=True,
|
||||
use_ocr=True,
|
||||
use_pdf_extraction=True,
|
||||
min_length=10,
|
||||
)
|
||||
|
||||
print(f"Extracted {len(texts)} text samples")
|
||||
```
|
||||
|
||||
## Supported File Types Summary
|
||||
|
||||
| Category | Extensions | Requirements |
|
||||
|----------|-----------|--------------|
|
||||
| Text | `.txt`, `.md`, `.rst`, `.log`, `.csv`, `.json`, `.jsonl`, `.xml`, `.html`, `.htm` | None |
|
||||
| Code | `.py`, `.js`, `.ts`, `.java`, `.cpp`, `.c`, `.go`, `.rs`, `.rb`, `.php`, `.swift`, and 30+ more | None |
|
||||
| PDF | `.pdf` | PyPDF2 or pdfplumber |
|
||||
| Images | `.png`, `.jpg`, `.jpeg`, `.gif`, `.bmp`, `.tiff`, `.webp` | pytesseract + Pillow + Tesseract OCR |
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### PDF extraction not working
|
||||
|
||||
- Install PyPDF2: `pip install PyPDF2`
|
||||
- Or install pdfplumber (better for complex PDFs): `pip install pdfplumber`
|
||||
- If PDFs are scanned images, use OCR instead
|
||||
|
||||
### OCR not working
|
||||
|
||||
1. Install pytesseract: `pip install pytesseract Pillow`
|
||||
2. Install Tesseract OCR engine (see installation instructions above)
|
||||
3. On some systems, you may need to set the tesseract path:
|
||||
```python
|
||||
import pytesseract
|
||||
pytesseract.pytesseract.tesseract_cmd = '/usr/local/bin/tesseract' # macOS example
|
||||
```
|
||||
|
||||
### No text extracted
|
||||
|
||||
- Check that files are in supported formats
|
||||
- Verify file permissions
|
||||
- Check logs for error messages
|
||||
- Try processing a single file first to debug
|
||||
|
||||
## Performance Tips
|
||||
|
||||
1. **Large directories**: Processing can take time for large directories. Progress is logged every 100 files.
|
||||
|
||||
2. **Parallel processing**: Consider processing files in parallel if you have many large files.
|
||||
|
||||
3. **Filtering**: Use `min_length` to filter out very short lines that may not be useful for training.
|
||||
|
||||
4. **Caching**: For repeated processing, consider saving extracted text to a file first.
|
||||
|
||||
## Next Steps
|
||||
|
||||
Once your data is processed:
|
||||
|
||||
1. The training script will automatically tokenize the text
|
||||
2. Create training batches
|
||||
3. Train your model
|
||||
|
||||
For more information on training, see `RETRAINING_GUIDE.md`.
|
||||
|
||||
948
docs/NEURAL_NETWORK_EXPLAINED.md
Normal file
948
docs/NEURAL_NETWORK_EXPLAINED.md
Normal file
@@ -0,0 +1,948 @@
|
||||
# What is a Neural Network? Step-by-Step Explanation
|
||||
|
||||
Complete step-by-step explanation of neural networks: what neurons are, what weights are, how calculations work, why they're important, with mathematical derivations and solved exercises.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
1. [What is a Neural Network?](#61-what-is-a-neural-network)
|
||||
2. [What is a Neuron?](#62-what-is-a-neuron)
|
||||
3. [What are Weights?](#63-what-are-weights)
|
||||
4. [How Neurons Calculate](#64-how-neurons-calculate)
|
||||
5. [Why Weights are Important](#65-why-weights-are-important)
|
||||
6. [Complete Mathematical Formulation](#66-complete-mathematical-formulation)
|
||||
7. [Multi-Layer Neural Networks](#67-multi-layer-neural-networks)
|
||||
8. [Exercise 1: Single Neuron Calculation](#68-exercise-1-single-neuron-calculation)
|
||||
9. [Exercise 2: Multi-Layer Network](#69-exercise-2-multi-layer-network)
|
||||
10. [Exercise 3: Learning Weights](#610-exercise-3-learning-weights)
|
||||
11. [Key Takeaways](#611-key-takeaways)
|
||||
|
||||
---
|
||||
|
||||
## 6.1 What is a Neural Network?
|
||||
|
||||
### Simple Definition
|
||||
|
||||
A **neural network** is a computational model inspired by biological neurons that processes information through interconnected nodes (neurons) to make predictions or decisions.
|
||||
|
||||
### Visual Analogy
|
||||
|
||||
**Think of a neural network like a factory:**
|
||||
|
||||
```
|
||||
Input → Worker 1 → Worker 2 → Worker 3 → Output
|
||||
```
|
||||
|
||||
**Neural Network:**
|
||||
|
||||
```
|
||||
Input → Neuron 1 → Neuron 2 → Neuron 3 → Output
|
||||
```
|
||||
|
||||
**Each worker (neuron) does a specific job, and they work together to produce the final result.**
|
||||
|
||||
### Basic Structure
|
||||
|
||||
```
|
||||
Input Layer Hidden Layer Output Layer
|
||||
● ● ●
|
||||
● ● ●
|
||||
● ● ●
|
||||
● ●
|
||||
```
|
||||
|
||||
**Key Components:**
|
||||
|
||||
- **Input Layer:** Receives data
|
||||
- **Hidden Layers:** Process information
|
||||
- **Output Layer:** Produces predictions
|
||||
- **Connections:** Weights between neurons
|
||||
|
||||
---
|
||||
|
||||
## 6.2 What is a Neuron?
|
||||
|
||||
### Simple Definition
|
||||
|
||||
A **neuron** (also called a node or unit) is the basic processing unit of a neural network. It receives inputs, performs calculations, and produces an output.
|
||||
|
||||
### Biological Inspiration
|
||||
|
||||
**Biological Neuron:**
|
||||
|
||||
```
|
||||
Dendrites → Cell Body → Axon → Synapses
|
||||
(inputs) (process) (output) (connections)
|
||||
```
|
||||
|
||||
**Artificial Neuron:**
|
||||
|
||||
```
|
||||
Inputs → Weighted Sum → Activation → Output
|
||||
```
|
||||
|
||||
### Structure of a Neuron
|
||||
|
||||
```
|
||||
Input 1 (x₁) ────┐
|
||||
│
|
||||
Input 2 (x₂) ────┼──→ [Σ] ─→ [f] ─→ Output (y)
|
||||
│
|
||||
Input 3 (x₃) ────┘
|
||||
```
|
||||
|
||||
**Components:**
|
||||
|
||||
1. **Inputs:** Values fed into the neuron
|
||||
2. **Weights:** Strength of connections
|
||||
3. **Weighted Sum:** Sum of inputs × weights
|
||||
4. **Bias:** Added constant
|
||||
5. **Activation Function:** Applies nonlinearity
|
||||
6. **Output:** Final result
|
||||
|
||||
### Visual Representation
|
||||
|
||||
```
|
||||
Neuron:
|
||||
┌─────────────────────┐
|
||||
│ Inputs: x₁, x₂, x₃ │
|
||||
│ Weights: w₁, w₂, w₃│
|
||||
│ │
|
||||
│ z = Σ(xᵢ × wᵢ) + b │
|
||||
│ y = f(z) │
|
||||
│ │
|
||||
│ Output: y │
|
||||
└─────────────────────┘
|
||||
```
|
||||
|
||||
**Where:**
|
||||
|
||||
- `z` = weighted sum (before activation)
|
||||
- `f` = activation function
|
||||
- `y` = output (after activation)
|
||||
|
||||
---
|
||||
|
||||
## 6.3 What are Weights?
|
||||
|
||||
### Simple Definition
|
||||
|
||||
**Weights** are numerical values that determine the strength of connections between neurons. They control how much each input contributes to the output.
|
||||
|
||||
### Visual Analogy
|
||||
|
||||
**Think of weights like volume controls:**
|
||||
|
||||
```
|
||||
Music Source 1 ──[Volume: 0.8]──→ Speakers
|
||||
Music Source 2 ──[Volume: 0.3]──→ Speakers
|
||||
Music Source 3 ──[Volume: 0.5]──→ Speakers
|
||||
```
|
||||
|
||||
**Higher weight = Louder contribution**
|
||||
|
||||
**Neural Network:**
|
||||
|
||||
```
|
||||
Input 1 ──[Weight: 0.8]──→ Neuron
|
||||
Input 2 ──[Weight: 0.3]──→ Neuron
|
||||
Input 3 ──[Weight: 0.5]──→ Neuron
|
||||
```
|
||||
|
||||
**Higher weight = Stronger influence**
|
||||
|
||||
### What Weights Do
|
||||
|
||||
**Weights determine:**
|
||||
|
||||
1. **How much each input matters**
|
||||
2. **The relationship between inputs and outputs**
|
||||
3. **What patterns the neuron learns**
|
||||
|
||||
**Example:**
|
||||
|
||||
**Weight = 0.1:**
|
||||
|
||||
- Input has small influence
|
||||
- Weak connection
|
||||
|
||||
**Weight = 5.0:**
|
||||
|
||||
- Input has large influence
|
||||
- Strong connection
|
||||
|
||||
**Weight = -2.0:**
|
||||
|
||||
- Input has negative influence
|
||||
- Inverts the relationship
|
||||
|
||||
**Weight = 0.0:**
|
||||
|
||||
- Input has no influence
|
||||
- Connection is cut
|
||||
|
||||
### Weight Matrix
|
||||
|
||||
**In a layer with multiple neurons:**
|
||||
|
||||
```
|
||||
Input Layer Weights Matrix Output Layer
|
||||
x₁ ───────────────────┐
|
||||
│ w₁₁ w₁₂ y₁
|
||||
x₂ ───────────────────┼─ w₂₁ w₂₂ ──── y₂
|
||||
│ w₃₁ w₃₂
|
||||
x₃ ───────────────────┘
|
||||
```
|
||||
|
||||
**Weight Matrix:**
|
||||
|
||||
```
|
||||
W = [w₁₁ w₁₂]
|
||||
[w₂₁ w₂₂]
|
||||
[w₃₁ w₃₂]
|
||||
```
|
||||
|
||||
**Each row:** Connections from one input
|
||||
**Each column:** Connections to one output
|
||||
|
||||
---
|
||||
|
||||
## 6.4 How Neurons Calculate
|
||||
|
||||
### Step-by-Step Calculation
|
||||
|
||||
#### Step 1: Weighted Sum
|
||||
|
||||
**Multiply each input by its weight:**
|
||||
|
||||
```math
|
||||
z = x_1 \times w_1 + x_2 \times w_2 + x_3 \times w_3 + ... + b
|
||||
```
|
||||
|
||||
**Or in vector form:**
|
||||
|
||||
```math
|
||||
z = \mathbf{x} \cdot \mathbf{w} + b = \sum_{i=1}^{n} x_i w_i + b
|
||||
```
|
||||
|
||||
**Where:**
|
||||
|
||||
- $x_i$ = input value
|
||||
- $w_i$ = weight for input $i$
|
||||
- $b$ = bias (constant)
|
||||
- $n$ = number of inputs
|
||||
|
||||
#### Step 2: Add Bias
|
||||
|
||||
**Bias shifts the activation:**
|
||||
|
||||
```math
|
||||
z = \sum_{i=1}^{n} x_i w_i + b
|
||||
```
|
||||
|
||||
**Bias allows the neuron to:**
|
||||
|
||||
- Shift activation threshold
|
||||
- Learn patterns independent of inputs
|
||||
- Adjust baseline output
|
||||
|
||||
#### Step 3: Apply Activation Function
|
||||
|
||||
**Apply nonlinear function:**
|
||||
|
||||
```math
|
||||
y = f(z)
|
||||
```
|
||||
|
||||
**Common activation functions:**
|
||||
|
||||
**ReLU (Rectified Linear Unit):**
|
||||
|
||||
```math
|
||||
f(z) = \max(0, z)
|
||||
```
|
||||
|
||||
**Sigmoid:**
|
||||
|
||||
```math
|
||||
f(z) = \frac{1}{1 + e^{-z}}
|
||||
```
|
||||
|
||||
**Tanh:**
|
||||
|
||||
```math
|
||||
f(z) = \tanh(z) = \frac{e^z - e^{-z}}{e^z + e^{-z}}
|
||||
```
|
||||
|
||||
**GELU (used in transformers):**
|
||||
|
||||
```math
|
||||
f(z) = z \cdot \Phi(z)
|
||||
```
|
||||
|
||||
**Where $\Phi(z)$ is the CDF of standard normal distribution**
|
||||
|
||||
### Complete Example
|
||||
|
||||
**Given:**
|
||||
|
||||
- Inputs: $x_1 = 0.5, x_2 = 0.3, x_3 = 0.8$
|
||||
- Weights: $w_1 = 0.6, w_2 = 0.4, w_3 = 0.2$
|
||||
- Bias: $b = 0.1$
|
||||
- Activation: ReLU
|
||||
|
||||
**Step 1: Weighted Sum**
|
||||
|
||||
```
|
||||
z = (0.5 × 0.6) + (0.3 × 0.4) + (0.8 × 0.2) + 0.1
|
||||
= 0.3 + 0.12 + 0.16 + 0.1
|
||||
= 0.68
|
||||
```
|
||||
|
||||
**Step 2: Apply Activation**
|
||||
|
||||
```
|
||||
y = ReLU(0.68)
|
||||
= max(0, 0.68)
|
||||
= 0.68
|
||||
```
|
||||
|
||||
**Result:** Output = 0.68
|
||||
|
||||
---
|
||||
|
||||
## 6.5 Why Weights are Important
|
||||
|
||||
### Reason 1: They Determine What the Neuron Learns
|
||||
|
||||
**Different weights = Different patterns:**
|
||||
|
||||
**Pattern 1: Emphasis on Input 1**
|
||||
|
||||
```
|
||||
w₁ = 5.0, w₂ = 0.1, w₃ = 0.1
|
||||
→ Neuron cares mostly about input 1
|
||||
```
|
||||
|
||||
**Pattern 2: Balanced Weights**
|
||||
|
||||
```
|
||||
w₁ = 0.5, w₂ = 0.5, w₃ = 0.5
|
||||
→ Neuron treats all inputs equally
|
||||
```
|
||||
|
||||
**Pattern 3: Inverted Relationship**
|
||||
|
||||
```
|
||||
w₁ = -2.0, w₂ = 1.0, w₃ = 1.0
|
||||
→ Neuron inverses input 1's effect
|
||||
```
|
||||
|
||||
### Reason 2: They Enable Learning
|
||||
|
||||
**Training adjusts weights:**
|
||||
|
||||
**Before Training:**
|
||||
|
||||
```
|
||||
Weights: Random values
|
||||
→ Random predictions
|
||||
```
|
||||
|
||||
**After Training:**
|
||||
|
||||
```
|
||||
Weights: Learned values
|
||||
→ Accurate predictions
|
||||
```
|
||||
|
||||
**Weights are what the model learns!**
|
||||
|
||||
### Reason 3: They Control Information Flow
|
||||
|
||||
**High weights:** Information flows easily
|
||||
**Low weights:** Information flows weakly
|
||||
**Zero weights:** Information blocked
|
||||
**Negative weights:** Information inverted
|
||||
|
||||
### Reason 4: They Enable Complex Patterns
|
||||
|
||||
**Multiple neurons with different weights:**
|
||||
|
||||
```
|
||||
Neuron 1: w₁ = 1.0, w₂ = 0.0 → Detects pattern A
|
||||
Neuron 2: w₁ = 0.0, w₂ = 1.0 → Detects pattern B
|
||||
Neuron 3: w₁ = 0.5, w₂ = 0.5 → Detects pattern C
|
||||
```
|
||||
|
||||
**Together:** Model learns complex relationships!
|
||||
|
||||
---
|
||||
|
||||
## 6.6 Complete Mathematical Formulation
|
||||
|
||||
### Single Neuron Formula
|
||||
|
||||
**Complete neuron calculation:**
|
||||
|
||||
```math
|
||||
z = \sum_{i=1}^{n} x_i w_i + b
|
||||
```
|
||||
|
||||
```math
|
||||
y = f(z)
|
||||
```
|
||||
|
||||
**Where:**
|
||||
|
||||
- $\mathbf{x} = [x_1, x_2, ..., x_n]$ = input vector
|
||||
- $\mathbf{w} = [w_1, w_2, ..., w_n]$ = weight vector
|
||||
- $b$ = bias (scalar)
|
||||
- $f$ = activation function
|
||||
- $z$ = weighted sum (before activation)
|
||||
- $y$ = output (after activation)
|
||||
|
||||
### Matrix Formulation
|
||||
|
||||
**For multiple neurons:**
|
||||
|
||||
```math
|
||||
\mathbf{z} = \mathbf{X} \mathbf{W} + \mathbf{b}
|
||||
```
|
||||
|
||||
```math
|
||||
\mathbf{Y} = f(\mathbf{z})
|
||||
```
|
||||
|
||||
**Where:**
|
||||
|
||||
- $\mathbf{X} \in \mathbb{R}^{B \times n}$ = input matrix (B samples, n features)
|
||||
- $\mathbf{W} \in \mathbb{R}^{n \times m}$ = weight matrix (n inputs, m neurons)
|
||||
- $\mathbf{b} \in \mathbb{R}^{1 \times m}$ = bias vector
|
||||
- $\mathbf{z} \in \mathbb{R}^{B \times m}$ = weighted sums
|
||||
- $\mathbf{Y} \in \mathbb{R}^{B \times m}$ = outputs
|
||||
|
||||
**Example:**
|
||||
|
||||
**Input Matrix:**
|
||||
|
||||
```
|
||||
X = [x₁₁ x₁₂] (2 samples, 2 features)
|
||||
[x₂₁ x₂₂]
|
||||
```
|
||||
|
||||
**Weight Matrix:**
|
||||
|
||||
```
|
||||
W = [w₁₁ w₁₂] (2 inputs, 2 neurons)
|
||||
[w₂₁ w₂₂]
|
||||
```
|
||||
|
||||
**Bias Vector:**
|
||||
|
||||
```
|
||||
b = [b₁ b₂] (2 neurons)
|
||||
```
|
||||
|
||||
**Calculation:**
|
||||
|
||||
```
|
||||
z = X × W + b
|
||||
|
||||
z₁₁ = x₁₁×w₁₁ + x₁₂×w₂₁ + b₁
|
||||
z₁₂ = x₁₁×w₁₂ + x₁₂×w₂₂ + b₂
|
||||
z₂₁ = x₂₁×w₁₁ + x₂₂×w₂₁ + b₁
|
||||
z₂₂ = x₂₁×w₁₂ + x₂₂×w₂₂ + b₂
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 6.7 Multi-Layer Neural Networks
|
||||
|
||||
### Structure
|
||||
|
||||
```
|
||||
Input Layer → Hidden Layer 1 → Hidden Layer 2 → Output Layer
|
||||
x₁ h₁₁ h₂₁ y₁
|
||||
x₂ h₁₂ h₂₂ y₂
|
||||
x₃ h₁₃ h₂₃
|
||||
```
|
||||
|
||||
### Forward Pass
|
||||
|
||||
**Layer 1:**
|
||||
|
||||
```math
|
||||
\mathbf{h}_1 = f_1(\mathbf{X} \mathbf{W}_1 + \mathbf{b}_1)
|
||||
```
|
||||
|
||||
**Layer 2:**
|
||||
|
||||
```math
|
||||
\mathbf{h}_2 = f_2(\mathbf{h}_1 \mathbf{W}_2 + \mathbf{b}_2)
|
||||
```
|
||||
|
||||
**Output Layer:**
|
||||
|
||||
```math
|
||||
\mathbf{Y} = f_3(\mathbf{h}_2 \mathbf{W}_3 + \mathbf{b}_3)
|
||||
```
|
||||
|
||||
**Chained together:**
|
||||
|
||||
```math
|
||||
\mathbf{Y} = f_3(f_2(f_1(\mathbf{X} \mathbf{W}_1 + \mathbf{b}_1) \mathbf{W}_2 + \mathbf{b}_2) \mathbf{W}_3 + \mathbf{b}_3)
|
||||
```
|
||||
|
||||
**Each layer transforms the input!**
|
||||
|
||||
---
|
||||
|
||||
## 6.8 Exercise 1: Single Neuron Calculation
|
||||
|
||||
### Problem
|
||||
|
||||
**Given a single neuron with:**
|
||||
|
||||
- Inputs: $x_1 = 2.0, x_2 = -1.0, x_3 = 0.5$
|
||||
- Weights: $w_1 = 0.5, w_2 = -0.3, w_3 = 0.8$
|
||||
- Bias: $b = 0.2$
|
||||
- Activation function: ReLU $f(z) = \max(0, z)$
|
||||
|
||||
**Calculate the output of this neuron.**
|
||||
|
||||
### Step-by-Step Solution
|
||||
|
||||
#### Step 1: Weighted Sum
|
||||
|
||||
**Compute:**
|
||||
|
||||
```math
|
||||
z = \sum_{i=1}^{3} x_i w_i + b
|
||||
```
|
||||
|
||||
**Substitute values:**
|
||||
|
||||
```math
|
||||
z = (2.0 \times 0.5) + (-1.0 \times -0.3) + (0.5 \times 0.8) + 0.2
|
||||
```
|
||||
|
||||
**Calculate each term:**
|
||||
|
||||
```math
|
||||
z = (1.0) + (0.3) + (0.4) + 0.2
|
||||
```
|
||||
|
||||
**Sum:**
|
||||
|
||||
```math
|
||||
z = 1.0 + 0.3 + 0.4 + 0.2 = 1.9
|
||||
```
|
||||
|
||||
#### Step 2: Apply Activation Function
|
||||
|
||||
**Apply ReLU:**
|
||||
|
||||
```math
|
||||
y = \text{ReLU}(z) = \max(0, z) = \max(0, 1.9) = 1.9
|
||||
```
|
||||
|
||||
### Answer
|
||||
|
||||
**The output of the neuron is $y = 1.9$.**
|
||||
|
||||
### Verification
|
||||
|
||||
**Check calculation:**
|
||||
|
||||
- Input contribution 1: $2.0 \times 0.5 = 1.0$
|
||||
- Input contribution 2: $-1.0 \times -0.3 = 0.3$
|
||||
- Input contribution 3: $0.5 \times 0.8 = 0.4$
|
||||
- Bias: $0.2$
|
||||
- Total: $1.0 + 0.3 + 0.4 + 0.2 = 1.9$ ✓
|
||||
- ReLU(1.9) = 1.9 ✓
|
||||
|
||||
---
|
||||
|
||||
## 6.9 Exercise 2: Multi-Layer Network
|
||||
|
||||
### Problem
|
||||
|
||||
**Given a neural network with 2 layers:**
|
||||
|
||||
**Layer 1:**
|
||||
|
||||
- Inputs: $x_1 = 1.0, x_2 = 0.5$
|
||||
- Weights: $W_1 = \begin{bmatrix} 0.6 & 0.4 \\ 0.2 & 0.8 \end{bmatrix}$
|
||||
- Bias: $b_1 = [0.1, -0.1]$
|
||||
- Activation: ReLU
|
||||
|
||||
**Layer 2:**
|
||||
|
||||
- Inputs: Outputs from Layer 1
|
||||
- Weights: $W_2 = \begin{bmatrix} 0.5 \\ 0.7 \end{bmatrix}$
|
||||
- Bias: $b_2 = 0.2$
|
||||
- Activation: ReLU
|
||||
|
||||
**Calculate the final output.**
|
||||
|
||||
### Step-by-Step Solution
|
||||
|
||||
#### Step 1: Layer 1 - Weighted Sum
|
||||
|
||||
**Input vector:**
|
||||
|
||||
```math
|
||||
\mathbf{x} = [1.0, 0.5]
|
||||
```
|
||||
|
||||
**Weight matrix:**
|
||||
|
||||
```math
|
||||
\mathbf{W}_1 = \begin{bmatrix} 0.6 & 0.4 \\ 0.2 & 0.8 \end{bmatrix}
|
||||
```
|
||||
|
||||
**Bias vector:**
|
||||
|
||||
```math
|
||||
\mathbf{b}_1 = [0.1, -0.1]
|
||||
```
|
||||
|
||||
**Calculate:**
|
||||
|
||||
```math
|
||||
\mathbf{z}_1 = \mathbf{x} \mathbf{W}_1 + \mathbf{b}_1
|
||||
```
|
||||
|
||||
**Matrix multiplication:**
|
||||
|
||||
```math
|
||||
\mathbf{z}_1 = [1.0, 0.5] \begin{bmatrix} 0.6 & 0.4 \\ 0.2 & 0.8 \end{bmatrix} + [0.1, -0.1]
|
||||
```
|
||||
|
||||
**Compute:**
|
||||
|
||||
```math
|
||||
z_{1,1} = 1.0 \times 0.6 + 0.5 \times 0.2 + 0.1 = 0.6 + 0.1 + 0.1 = 0.8
|
||||
```
|
||||
|
||||
```math
|
||||
z_{1,2} = 1.0 \times 0.4 + 0.5 \times 0.8 + (-0.1) = 0.4 + 0.4 - 0.1 = 0.7
|
||||
```
|
||||
|
||||
```math
|
||||
\mathbf{z}_1 = [0.8, 0.7]
|
||||
```
|
||||
|
||||
#### Step 2: Layer 1 - Apply Activation
|
||||
|
||||
**Apply ReLU:**
|
||||
|
||||
```math
|
||||
\mathbf{h}_1 = \text{ReLU}(\mathbf{z}_1) = [\max(0, 0.8), \max(0, 0.7)] = [0.8, 0.7]
|
||||
```
|
||||
|
||||
#### Step 3: Layer 2 - Weighted Sum
|
||||
|
||||
**Input (from Layer 1):**
|
||||
|
||||
```math
|
||||
\mathbf{h}_1 = [0.8, 0.7]
|
||||
```
|
||||
|
||||
**Weight matrix:**
|
||||
|
||||
```math
|
||||
\mathbf{W}_2 = \begin{bmatrix} 0.5 \\ 0.7 \end{bmatrix}
|
||||
```
|
||||
|
||||
**Bias:**
|
||||
|
||||
```math
|
||||
b_2 = 0.2
|
||||
```
|
||||
|
||||
**Calculate:**
|
||||
|
||||
```math
|
||||
z_2 = \mathbf{h}_1 \mathbf{W}_2 + b_2
|
||||
```
|
||||
|
||||
**Matrix multiplication:**
|
||||
|
||||
```math
|
||||
z_2 = [0.8, 0.7] \begin{bmatrix} 0.5 \\ 0.7 \end{bmatrix} + 0.2
|
||||
```
|
||||
|
||||
**Compute:**
|
||||
|
||||
```math
|
||||
z_2 = 0.8 \times 0.5 + 0.7 \times 0.7 + 0.2 = 0.4 + 0.49 + 0.2 = 1.09
|
||||
```
|
||||
|
||||
#### Step 4: Layer 2 - Apply Activation
|
||||
|
||||
**Apply ReLU:**
|
||||
|
||||
```math
|
||||
y = \text{ReLU}(z_2) = \max(0, 1.09) = 1.09
|
||||
```
|
||||
|
||||
### Answer
|
||||
|
||||
**The final output is $y = 1.09$.**
|
||||
|
||||
### Summary Table
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<th>Layer</th>
|
||||
<th>Input</th>
|
||||
<th>Weights</th>
|
||||
<th>Bias</th>
|
||||
<th>Weighted Sum</th>
|
||||
<th>Activation</th>
|
||||
<th>Output</th>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td>1</td>
|
||||
<td>[1.0, 0.5]</td>
|
||||
<td>$$\begin{bmatrix} 0.6 & 0.4 \\ 0.2 & 0.8 \end{bmatrix}$$</td>
|
||||
<td>[0.1, -0.1]</td>
|
||||
<td>[0.8, 0.7]</td>
|
||||
<td>ReLU</td>
|
||||
<td>[0.8, 0.7]</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td>2</td>
|
||||
<td>[0.8, 0.7]</td>
|
||||
<td>$$\begin{bmatrix} 0.5 \\ 0.7 \end{bmatrix}$$</td>
|
||||
<td>0.2</td>
|
||||
<td>1.09</td>
|
||||
<td>ReLU</td>
|
||||
<td><strong>1.09</strong></td>
|
||||
</tr>
|
||||
</table>
|
||||
---
|
||||
|
||||
## 6.10 Exercise 3: Learning Weights
|
||||
|
||||
### Problem
|
||||
|
||||
**Given a neuron that should output 1.0 when inputs are [1.0, 1.0] and output 0.0 when inputs are [0.0, 0.0], find appropriate weights and bias.**
|
||||
|
||||
**Use:**
|
||||
|
||||
- Activation: Sigmoid $f(z) = \frac{1}{1 + e^{-z}}$
|
||||
- Desired behavior: AND gate (output 1 only when both inputs are 1)
|
||||
|
||||
### Step-by-Step Solution
|
||||
|
||||
#### Step 1: Set Up Equations
|
||||
|
||||
**For input [1.0, 1.0], desired output ≈ 1.0:**
|
||||
|
||||
```math
|
||||
f(w_1 \times 1.0 + w_2 \times 1.0 + b) = 1.0
|
||||
```
|
||||
|
||||
**For input [0.0, 0.0], desired output ≈ 0.0:**
|
||||
|
||||
```math
|
||||
f(w_1 \times 0.0 + w_2 \times 0.0 + b) = 0.0
|
||||
```
|
||||
|
||||
**Note:** Sigmoid outputs range from 0 to 1, so:
|
||||
|
||||
- $f(z) \approx 1.0$ when $z \gg 0$ (e.g., $z > 5$)
|
||||
- $f(z) \approx 0.0$ when $z \ll 0$ (e.g., $z < -5$)
|
||||
|
||||
#### Step 2: Solve for Bias
|
||||
|
||||
**From equation 2:**
|
||||
|
||||
```math
|
||||
f(b) = 0.0
|
||||
```
|
||||
|
||||
**For sigmoid to output ≈ 0:**
|
||||
|
||||
```math
|
||||
b < -5
|
||||
```
|
||||
|
||||
**Let's use:**
|
||||
|
||||
```math
|
||||
b = -10
|
||||
```
|
||||
|
||||
#### Step 3: Solve for Weights
|
||||
|
||||
**From equation 1:**
|
||||
|
||||
```math
|
||||
f(w_1 + w_2 - 10) = 1.0
|
||||
```
|
||||
|
||||
**For sigmoid to output ≈ 1:**
|
||||
|
||||
```math
|
||||
w_1 + w_2 - 10 > 5
|
||||
```
|
||||
|
||||
```math
|
||||
w_1 + w_2 > 15
|
||||
```
|
||||
|
||||
**Let's use equal weights:**
|
||||
|
||||
```math
|
||||
w_1 = w_2 = 8.0
|
||||
```
|
||||
|
||||
**Check:**
|
||||
|
||||
```math
|
||||
w_1 + w_2 = 8.0 + 8.0 = 16.0 > 15 \quad ✓
|
||||
```
|
||||
|
||||
#### Step 4: Verify Solution
|
||||
|
||||
**Test Case 1: Input [1.0, 1.0]**
|
||||
|
||||
```math
|
||||
z = 1.0 \times 8.0 + 1.0 \times 8.0 + (-10) = 8.0 + 8.0 - 10 = 6.0
|
||||
```
|
||||
|
||||
```math
|
||||
y = \frac{1}{1 + e^{-6.0}} = \frac{1}{1 + 0.0025} \approx 0.9975 \approx 1.0 \quad ✓
|
||||
```
|
||||
|
||||
**Test Case 2: Input [0.0, 0.0]**
|
||||
|
||||
```math
|
||||
z = 0.0 \times 8.0 + 0.0 \times 8.0 + (-10) = -10
|
||||
```
|
||||
|
||||
```math
|
||||
y = \frac{1}{1 + e^{10}} = \frac{1}{1 + 22026} \approx 0.00005 \approx 0.0 \quad ✓
|
||||
```
|
||||
|
||||
**Test Case 3: Input [1.0, 0.0]**
|
||||
|
||||
```math
|
||||
z = 1.0 \times 8.0 + 0.0 \times 8.0 + (-10) = 8.0 - 10 = -2.0
|
||||
```
|
||||
|
||||
```math
|
||||
y = \frac{1}{1 + e^{2.0}} = \frac{1}{1 + 7.39} \approx 0.12 < 0.5 \quad ✓
|
||||
```
|
||||
|
||||
**Test Case 4: Input [0.0, 1.0]**
|
||||
|
||||
```math
|
||||
z = 0.0 \times 8.0 + 1.0 \times 8.0 + (-10) = 8.0 - 10 = -2.0
|
||||
```
|
||||
|
||||
```math
|
||||
y = \frac{1}{1 + e^{2.0}} \approx 0.12 < 0.5 \quad ✓
|
||||
```
|
||||
|
||||
### Answer
|
||||
|
||||
**Appropriate weights and bias:**
|
||||
|
||||
- $w_1 = 8.0$
|
||||
- $w_2 = 8.0$
|
||||
- $b = -10.0$
|
||||
|
||||
**The neuron implements an AND gate correctly!**
|
||||
|
||||
### Key Insight
|
||||
|
||||
**This demonstrates learning:**
|
||||
|
||||
- Training finds weights that produce desired behavior
|
||||
- Different weights = Different logic functions
|
||||
- Learning algorithms (like backpropagation) automatically find these weights from data!
|
||||
|
||||
---
|
||||
|
||||
## 6.11 Key Takeaways
|
||||
|
||||
### Neurons
|
||||
|
||||
✅ **Neurons are the basic processing units**
|
||||
✅ **Receive inputs, compute weighted sum, apply activation**
|
||||
✅ **Output is the result of activation function**
|
||||
|
||||
### Weights
|
||||
|
||||
✅ **Weights control connection strength**
|
||||
✅ **Determine what patterns neurons learn**
|
||||
✅ **Are what the model learns during training**
|
||||
✅ **Enable complex pattern recognition**
|
||||
|
||||
### Calculation
|
||||
|
||||
✅ **Weighted sum: $z = \sum x_i w_i + b$**
|
||||
✅ **Activation: $y = f(z)$**
|
||||
✅ **Matrix form enables efficient computation**
|
||||
|
||||
### Importance
|
||||
|
||||
✅ **Weights enable learning**
|
||||
✅ **Control information flow**
|
||||
✅ **Enable complex pattern recognition**
|
||||
✅ **Are adjusted during training to minimize error**
|
||||
|
||||
### Neural Networks
|
||||
|
||||
✅ **Multiple neurons form layers**
|
||||
✅ **Multiple layers form networks**
|
||||
✅ **Each layer transforms the input**
|
||||
✅ **Deep networks learn hierarchical features**
|
||||
|
||||
---
|
||||
|
||||
## Mathematical Summary
|
||||
|
||||
### Single Neuron
|
||||
|
||||
```math
|
||||
z = \sum_{i=1}^{n} x_i w_i + b
|
||||
```
|
||||
|
||||
```math
|
||||
y = f(z)
|
||||
```
|
||||
|
||||
### Multiple Neurons (Matrix Form)
|
||||
|
||||
```math
|
||||
\mathbf{z} = \mathbf{X} \mathbf{W} + \mathbf{b}
|
||||
```
|
||||
|
||||
```math
|
||||
\mathbf{Y} = f(\mathbf{z})
|
||||
```
|
||||
|
||||
### Multi-Layer Network
|
||||
|
||||
```math
|
||||
\mathbf{h}_1 = f_1(\mathbf{X} \mathbf{W}_1 + \mathbf{b}_1)
|
||||
```
|
||||
|
||||
```math
|
||||
\mathbf{h}_2 = f_2(\mathbf{h}_1 \mathbf{W}_2 + \mathbf{b}_2)
|
||||
```
|
||||
|
||||
```math
|
||||
\mathbf{Y} = f_3(\mathbf{h}_2 \mathbf{W}_3 + \mathbf{b}_3)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
_This document provides a comprehensive explanation of neural networks, neurons, weights, and calculations with mathematical derivations and solved exercises._
|
||||
640
docs/NORMALIZATION_EXPLAINED.md
Normal file
640
docs/NORMALIZATION_EXPLAINED.md
Normal file
@@ -0,0 +1,640 @@
|
||||
# What is Normalization? Step-by-Step Explanation
|
||||
|
||||
Complete step-by-step explanation of normalization in transformer models: how normalization stabilizes training and improves model performance.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
1. [The Problem Normalization Solves](#41-the-problem-normalization-solves)
|
||||
2. [What is Normalization?](#42-what-is-normalization)
|
||||
3. [How Layer Normalization Works: Step-by-Step](#43-how-layer-normalization-works-step-by-step)
|
||||
4. [Complete Example: Normalizing a Vector](#44-complete-example-normalizing-a-vector)
|
||||
5. [Why Normalization Matters](#45-why-normalization-matters)
|
||||
6. [Pre-Norm vs Post-Norm Architecture](#46-pre-norm-vs-post-norm-architecture)
|
||||
7. [Visual Representation](#47-visual-representation)
|
||||
8. [Key Takeaways](#48-key-takeaways)
|
||||
|
||||
---
|
||||
|
||||
## 4.1 The Problem Normalization Solves
|
||||
|
||||
### The Challenge
|
||||
|
||||
**During training, activations can become unstable:**
|
||||
|
||||
**Problem 1: Varying Activations**
|
||||
```
|
||||
Layer 1 output: [0.1, 0.2, 0.3, ...] (small values)
|
||||
Layer 2 output: [10.5, 20.3, 15.8, ...] (large values)
|
||||
Layer 3 output: [0.01, 0.02, 0.03, ...] (very small values)
|
||||
```
|
||||
|
||||
**Problem 2: Internal Covariate Shift**
|
||||
- Activations change distribution as weights update
|
||||
- Later layers struggle to adapt to changing inputs
|
||||
- Training becomes slower and less stable
|
||||
|
||||
**Problem 3: Gradient Problems**
|
||||
```
|
||||
Large activations → Large gradients → Exploding gradients
|
||||
Small activations → Small gradients → Vanishing gradients
|
||||
```
|
||||
|
||||
### The Solution: Normalization
|
||||
|
||||
**Normalization standardizes activations to have consistent statistics (mean zero, variance one), making training stable and efficient.**
|
||||
|
||||
---
|
||||
|
||||
## 4.2 What is Normalization?
|
||||
|
||||
### Simple Definition
|
||||
|
||||
**Normalization** is a technique that transforms activations to have:
|
||||
- **Mean of zero** (centered)
|
||||
- **Variance of one** (standardized scale)
|
||||
|
||||
**Think of it like standardization:**
|
||||
- Converts any distribution to a standard form
|
||||
- Makes values comparable across different scales
|
||||
- Helps the model learn faster and more reliably
|
||||
|
||||
### Visual Analogy
|
||||
|
||||
**Imagine weights on a scale:**
|
||||
|
||||
**Before Normalization:**
|
||||
```
|
||||
Bronze weight: 1 kg
|
||||
Silver weight: 100 kg
|
||||
Gold weight: 0.001 kg
|
||||
→ Hard to compare!
|
||||
```
|
||||
|
||||
**After Normalization:**
|
||||
```
|
||||
All weights standardized to mean 0, variance 1
|
||||
→ Easy to compare and work with!
|
||||
```
|
||||
|
||||
### Types of Normalization
|
||||
|
||||
**In transformers, we use Layer Normalization:**
|
||||
|
||||
- **Layer Normalization:** Normalizes across features (dimensions) for each sample
|
||||
- **Batch Normalization:** Normalizes across samples in a batch (not used in transformers)
|
||||
- **Instance Normalization:** Normalizes each sample independently
|
||||
|
||||
**Why Layer Normalization?**
|
||||
- Works well with variable sequence lengths
|
||||
- Doesn't depend on batch size
|
||||
- Suitable for autoregressive models
|
||||
|
||||
---
|
||||
|
||||
## 4.3 How Layer Normalization Works: Step-by-Step
|
||||
|
||||
### High-Level Overview
|
||||
|
||||
```
|
||||
Step 1: Compute mean of activations
|
||||
Step 2: Compute variance of activations
|
||||
Step 3: Normalize (subtract mean, divide by std)
|
||||
Step 4: Scale and shift (learnable parameters)
|
||||
```
|
||||
|
||||
### Detailed Step-by-Step
|
||||
|
||||
#### Step 1: Compute Mean
|
||||
|
||||
**Calculate the average value across all dimensions:**
|
||||
|
||||
```math
|
||||
\mu = \frac{1}{d} \sum_{i=1}^{d} x_i
|
||||
```
|
||||
|
||||
**Example:**
|
||||
|
||||
**Input vector:**
|
||||
```
|
||||
x = [1.0, 2.0, 3.0, 4.0]
|
||||
d = 4 (number of dimensions)
|
||||
```
|
||||
|
||||
**Compute mean:**
|
||||
```
|
||||
μ = (1.0 + 2.0 + 3.0 + 4.0) / 4
|
||||
= 10.0 / 4
|
||||
= 2.5
|
||||
```
|
||||
|
||||
**Meaning:** The center of the distribution is at 2.5
|
||||
|
||||
#### Step 2: Compute Variance
|
||||
|
||||
**Measure how spread out the values are:**
|
||||
|
||||
```math
|
||||
\sigma^2 = \frac{1}{d} \sum_{i=1}^{d} (x_i - \mu)^2
|
||||
```
|
||||
|
||||
**Example:**
|
||||
|
||||
**Using the same input:**
|
||||
```
|
||||
x = [1.0, 2.0, 3.0, 4.0]
|
||||
μ = 2.5
|
||||
```
|
||||
|
||||
**Compute variance:**
|
||||
```
|
||||
σ² = [(1.0 - 2.5)² + (2.0 - 2.5)² + (3.0 - 2.5)² + (4.0 - 2.5)²] / 4
|
||||
= [(-1.5)² + (-0.5)² + (0.5)² + (1.5)²] / 4
|
||||
= [2.25 + 0.25 + 0.25 + 2.25] / 4
|
||||
= 5.0 / 4
|
||||
= 1.25
|
||||
```
|
||||
|
||||
**Compute standard deviation:**
|
||||
```
|
||||
σ = √σ² = √1.25 ≈ 1.118
|
||||
```
|
||||
|
||||
**Meaning:** Values are spread out with standard deviation of 1.118
|
||||
|
||||
#### Step 3: Normalize
|
||||
|
||||
**Subtract mean and divide by standard deviation:**
|
||||
|
||||
```math
|
||||
\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}
|
||||
```
|
||||
|
||||
**Where:**
|
||||
- $\epsilon$ is a small constant (default: 1e-5) to prevent division by zero
|
||||
|
||||
**Example:**
|
||||
|
||||
**Using the same input:**
|
||||
```
|
||||
x = [1.0, 2.0, 3.0, 4.0]
|
||||
μ = 2.5
|
||||
σ ≈ 1.118
|
||||
ε = 0.00001
|
||||
```
|
||||
|
||||
**Normalize each element:**
|
||||
```
|
||||
x̂₁ = (1.0 - 2.5) / (1.118 + 0.00001) ≈ -1.341
|
||||
x̂₂ = (2.0 - 2.5) / (1.118 + 0.00001) ≈ -0.447
|
||||
x̂₃ = (3.0 - 2.5) / (1.118 + 0.00001) ≈ 0.447
|
||||
x̂₄ = (4.0 - 2.5) / (1.118 + 0.00001) ≈ 1.341
|
||||
```
|
||||
|
||||
**Result:**
|
||||
```
|
||||
x̂ = [-1.341, -0.447, 0.447, 1.341]
|
||||
```
|
||||
|
||||
**Check:**
|
||||
- Mean ≈ 0 ✓
|
||||
- Standard deviation ≈ 1 ✓
|
||||
|
||||
**Meaning:** Values are now standardized!
|
||||
|
||||
#### Step 4: Scale and Shift
|
||||
|
||||
**Apply learnable parameters:**
|
||||
|
||||
```math
|
||||
\text{LayerNorm}(x) = \gamma \odot \hat{x} + \beta
|
||||
```
|
||||
|
||||
**Where:**
|
||||
- $\gamma$ = learnable scale parameter (initialized to 1)
|
||||
- $\beta$ = learnable shift parameter (initialized to 0)
|
||||
- $\odot$ = element-wise multiplication
|
||||
|
||||
**Example:**
|
||||
|
||||
**Normalized vector:**
|
||||
```
|
||||
x̂ = [-1.341, -0.447, 0.447, 1.341]
|
||||
```
|
||||
|
||||
**Learnable parameters (initialized):**
|
||||
```
|
||||
γ = [1.0, 1.0, 1.0, 1.0] (scale)
|
||||
β = [0.0, 0.0, 0.0, 0.0] (shift)
|
||||
```
|
||||
|
||||
**Apply scale and shift:**
|
||||
```
|
||||
Output = γ ⊙ x̂ + β
|
||||
= [1.0, 1.0, 1.0, 1.0] ⊙ [-1.341, -0.447, 0.447, 1.341] + [0.0, 0.0, 0.0, 0.0]
|
||||
= [-1.341, -0.447, 0.447, 1.341] + [0.0, 0.0, 0.0, 0.0]
|
||||
= [-1.341, -0.447, 0.447, 1.341]
|
||||
```
|
||||
|
||||
**Initially, normalization is identity!**
|
||||
**During training, γ and β learn optimal scale and shift.**
|
||||
|
||||
---
|
||||
|
||||
## 4.4 Complete Example: Normalizing a Vector
|
||||
|
||||
### Input
|
||||
|
||||
```
|
||||
Word embedding after attention: [0.146, 0.108, 0.192, 0.155, ..., 0.11]
|
||||
Dimension: 512
|
||||
```
|
||||
|
||||
### Step-by-Step Processing
|
||||
|
||||
#### Step 1: Compute Mean
|
||||
|
||||
**Input:**
|
||||
```
|
||||
x = [0.146, 0.108, 0.192, ..., 0.11] (512 numbers)
|
||||
```
|
||||
|
||||
**Compute mean:**
|
||||
```
|
||||
μ = (0.146 + 0.108 + 0.192 + ... + 0.11) / 512
|
||||
≈ 0.135
|
||||
```
|
||||
|
||||
**Visualization:**
|
||||
```
|
||||
Values: [0.146, 0.108, 0.192, ..., 0.11]
|
||||
└────────────────────────────────┘
|
||||
Mean: 0.135 (center point)
|
||||
```
|
||||
|
||||
#### Step 2: Compute Variance
|
||||
|
||||
**Compute variance:**
|
||||
```
|
||||
σ² = [(0.146 - 0.135)² + (0.108 - 0.135)² + (0.192 - 0.135)² + ... + (0.11 - 0.135)²] / 512
|
||||
≈ 0.0023
|
||||
```
|
||||
|
||||
**Compute standard deviation:**
|
||||
```
|
||||
σ = √0.0023 ≈ 0.048
|
||||
```
|
||||
|
||||
**Visualization:**
|
||||
```
|
||||
Values: [0.146, 0.108, 0.192, ..., 0.11]
|
||||
Spread: └───────── σ ≈ 0.048 ──────────┘
|
||||
```
|
||||
|
||||
#### Step 3: Normalize
|
||||
|
||||
**Normalize each element:**
|
||||
```
|
||||
x̂₁ = (0.146 - 0.135) / (0.048 + 0.00001) ≈ 0.229
|
||||
x̂₂ = (0.108 - 0.135) / (0.048 + 0.00001) ≈ -0.562
|
||||
x̂₃ = (0.192 - 0.135) / (0.048 + 0.00001) ≈ 1.188
|
||||
...
|
||||
x̂₅₁₂ = (0.11 - 0.135) / (0.048 + 0.00001) ≈ -0.521
|
||||
```
|
||||
|
||||
**Result:**
|
||||
```
|
||||
x̂ = [0.229, -0.562, 1.188, ..., -0.521]
|
||||
```
|
||||
|
||||
**Properties:**
|
||||
- Mean ≈ 0 ✓
|
||||
- Standard deviation ≈ 1 ✓
|
||||
|
||||
#### Step 4: Scale and Shift
|
||||
|
||||
**Apply learnable parameters:**
|
||||
```
|
||||
γ = [1.0, 1.0, ..., 1.0] (512 values, may change during training)
|
||||
β = [0.0, 0.0, ..., 0.0] (512 values, may change during training)
|
||||
```
|
||||
|
||||
**Output:**
|
||||
```
|
||||
Output = γ ⊙ x̂ + β
|
||||
= [0.229, -0.562, 1.188, ..., -0.521]
|
||||
```
|
||||
|
||||
**After training, γ and β adapt to optimal values!**
|
||||
|
||||
---
|
||||
|
||||
## 4.5 Why Normalization Matters
|
||||
|
||||
### Benefit 1: Stable Training
|
||||
|
||||
**Without Normalization:**
|
||||
```
|
||||
Layer 1: activations = [0.1, 0.2, ...]
|
||||
Layer 2: activations = [50.0, 100.0, ...] ← Exploding!
|
||||
Layer 3: activations = [0.001, 0.002, ...] ← Vanishing!
|
||||
```
|
||||
|
||||
**With Normalization:**
|
||||
```
|
||||
Layer 1: activations = [0.1, -0.2, ...] (normalized)
|
||||
Layer 2: activations = [0.3, -0.1, ...] (normalized)
|
||||
Layer 3: activations = [0.2, 0.4, ...] (normalized)
|
||||
→ Consistent scale throughout!
|
||||
```
|
||||
|
||||
### Benefit 2: Better Gradient Flow
|
||||
|
||||
**Normalization helps gradients flow better:**
|
||||
|
||||
**Without Normalization:**
|
||||
```
|
||||
Gradient 1: 0.0001 (too small, vanishing)
|
||||
Gradient 2: 1000.0 (too large, exploding)
|
||||
Gradient 3: 0.001 (too small)
|
||||
```
|
||||
|
||||
**With Normalization:**
|
||||
```
|
||||
Gradient 1: 0.01 (reasonable)
|
||||
Gradient 2: 0.02 (reasonable)
|
||||
Gradient 3: 0.015 (reasonable)
|
||||
→ Stable gradients!
|
||||
```
|
||||
|
||||
### Benefit 3: Faster Convergence
|
||||
|
||||
**Normalized activations allow:**
|
||||
- Higher learning rates
|
||||
- Faster weight updates
|
||||
- Quicker convergence to good solutions
|
||||
|
||||
**Analogy:**
|
||||
- **Without normalization:** Walking on rough terrain (slow progress)
|
||||
- **With normalization:** Walking on smooth path (fast progress)
|
||||
|
||||
### Benefit 4: Regularization Effect
|
||||
|
||||
**Normalization acts as a form of regularization:**
|
||||
- Reduces internal covariate shift
|
||||
- Makes optimization easier
|
||||
- Helps prevent overfitting
|
||||
|
||||
---
|
||||
|
||||
## 4.6 Pre-Norm vs Post-Norm Architecture
|
||||
|
||||
### Post-Norm (Original Transformer)
|
||||
|
||||
**Order:**
|
||||
```
|
||||
Input → Attention → LayerNorm → Output
|
||||
```
|
||||
|
||||
**Equation:**
|
||||
```
|
||||
x_out = LayerNorm(x + Attention(x))
|
||||
```
|
||||
|
||||
**Problems:**
|
||||
- Can be unstable with many layers
|
||||
- Gradient flow can be difficult
|
||||
- Harder to train deep networks
|
||||
|
||||
### Pre-Norm (Modern Approach)
|
||||
|
||||
**Order:**
|
||||
```
|
||||
Input → LayerNorm → Attention → Output
|
||||
```
|
||||
|
||||
**Equation:**
|
||||
```
|
||||
x_out = x + Attention(LayerNorm(x))
|
||||
```
|
||||
|
||||
**Benefits:**
|
||||
- More stable training
|
||||
- Better gradient flow
|
||||
- Easier to train deep networks
|
||||
|
||||
**Visual Comparison:**
|
||||
|
||||
**Post-Norm:**
|
||||
```
|
||||
Input
|
||||
↓
|
||||
┌──────────────┐
|
||||
│ Attention │
|
||||
└──────┬───────┘
|
||||
↓
|
||||
┌──────────────┐
|
||||
│ LayerNorm │ ← Normalization after
|
||||
└──────┬───────┘
|
||||
↓
|
||||
Output
|
||||
```
|
||||
|
||||
**Pre-Norm:**
|
||||
```
|
||||
Input
|
||||
↓
|
||||
┌──────────────┐
|
||||
│ LayerNorm │ ← Normalization before
|
||||
└──────┬───────┘
|
||||
↓
|
||||
┌──────────────┐
|
||||
│ Attention │
|
||||
└──────┬───────┘
|
||||
↓
|
||||
Output
|
||||
```
|
||||
|
||||
**Our Model Uses Pre-Norm!**
|
||||
|
||||
---
|
||||
|
||||
## 4.7 Visual Representation
|
||||
|
||||
### Normalization Process
|
||||
|
||||
```
|
||||
Input Vector
|
||||
│
|
||||
│ [1.0, 2.0, 3.0, 4.0]
|
||||
↓
|
||||
┌─────────────────────────────┐
|
||||
│ Step 1: Compute Mean │
|
||||
│ μ = 2.5 │
|
||||
└──────────┬──────────────────┘
|
||||
│
|
||||
↓
|
||||
┌─────────────────────────────┐
|
||||
│ Step 2: Compute Variance │
|
||||
│ σ² = 1.25, σ ≈ 1.118 │
|
||||
└──────────┬──────────────────┘
|
||||
│
|
||||
↓
|
||||
┌────────────────────────────────┐
|
||||
│ Step 3: Normalize │
|
||||
│ x̂ = (x - μ) / σ │
|
||||
│ [-1.341, -0.447, 0.447, 1.341] │
|
||||
└──────────┬─────────────────────┘
|
||||
│
|
||||
↓
|
||||
┌─────────────────────────────┐
|
||||
│ Step 4: Scale and Shift │
|
||||
│ Output = γ ⊙ x̂ + β │
|
||||
└──────────┬──────────────────┘
|
||||
│
|
||||
↓
|
||||
Output Vector
|
||||
```
|
||||
|
||||
### Distribution Transformation
|
||||
|
||||
**Before Normalization:**
|
||||
```
|
||||
Distribution:
|
||||
│
|
||||
0.4│ ●
|
||||
│ ● ●
|
||||
0.3│ ● ●
|
||||
│ ● ●
|
||||
0.2│ ● ●
|
||||
│● ●
|
||||
0.1│ ●
|
||||
│
|
||||
0.0├─────────────────────────
|
||||
0 1 2 3 4 5
|
||||
Mean: 2.5, Std: 1.118
|
||||
```
|
||||
|
||||
**After Normalization:**
|
||||
```
|
||||
Distribution:
|
||||
│
|
||||
0.4│ ●
|
||||
│ ● ●
|
||||
0.3│ ● ●
|
||||
│ ● ●
|
||||
0.2│ ● ●
|
||||
│● ●
|
||||
0.1│ ●
|
||||
│
|
||||
0.0├─────────────────────────
|
||||
-2 -1 0 1 2 3
|
||||
Mean: 0, Std: 1
|
||||
```
|
||||
|
||||
**Standardized!**
|
||||
|
||||
### Gradient Flow Visualization
|
||||
|
||||
**Without Normalization:**
|
||||
```
|
||||
Gradient Magnitude:
|
||||
│
|
||||
1000│ ●
|
||||
│
|
||||
100│
|
||||
│
|
||||
10│
|
||||
│
|
||||
1│ ●
|
||||
│
|
||||
0.1│ ●
|
||||
│
|
||||
0.01│
|
||||
└──────────────────────── Layer
|
||||
1 2 3 4 5
|
||||
(Unstable, varying magnitudes)
|
||||
```
|
||||
|
||||
**With Normalization:**
|
||||
```
|
||||
Gradient Magnitude:
|
||||
│
|
||||
1000│
|
||||
│
|
||||
100│
|
||||
│
|
||||
10│
|
||||
│ ● ● ● ● ●
|
||||
1│
|
||||
│
|
||||
0.1│
|
||||
│
|
||||
0.01│
|
||||
└──────────────────────── Layer
|
||||
1 2 3 4 5
|
||||
(Stable, consistent magnitudes)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 4.8 Key Takeaways: Normalization
|
||||
|
||||
✅ **Normalization standardizes activations to mean 0, variance 1**
|
||||
✅ **Stabilizes training by preventing exploding/vanishing gradients**
|
||||
✅ **Enables faster convergence and higher learning rates**
|
||||
✅ **Pre-norm architecture is preferred for deep networks**
|
||||
✅ **Learnable parameters (γ, β) allow optimal scaling**
|
||||
|
||||
---
|
||||
|
||||
## Complete Mathematical Formula
|
||||
|
||||
### Layer Normalization Formula
|
||||
|
||||
For input $\mathbf{x} \in \mathbb{R}^d$:
|
||||
|
||||
```math
|
||||
\mu = \frac{1}{d} \sum_{i=1}^{d} x_i
|
||||
```
|
||||
|
||||
```math
|
||||
\sigma^2 = \frac{1}{d} \sum_{i=1}^{d} (x_i - \mu)^2
|
||||
```
|
||||
|
||||
```math
|
||||
\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}
|
||||
```
|
||||
|
||||
```math
|
||||
\text{LayerNorm}(\mathbf{x}) = \gamma \odot \hat{\mathbf{x}} + \beta
|
||||
```
|
||||
|
||||
**Where:**
|
||||
- $\epsilon$ = small constant (default: 1e-5) to prevent division by zero
|
||||
- $\gamma$ = learnable scale parameter (initialized to 1)
|
||||
- $\beta$ = learnable shift parameter (initialized to 0)
|
||||
- $\odot$ = element-wise multiplication
|
||||
- $d$ = number of dimensions
|
||||
|
||||
### In Transformer Block
|
||||
|
||||
**Pre-Norm Architecture:**
|
||||
|
||||
```math
|
||||
\mathbf{x}_{norm} = \text{LayerNorm}(\mathbf{x}_{in})
|
||||
```
|
||||
|
||||
```math
|
||||
\mathbf{x}_{attn} = \text{Attention}(\mathbf{x}_{norm})
|
||||
```
|
||||
|
||||
```math
|
||||
\mathbf{x}_{out} = \mathbf{x}_{in} + \mathbf{x}_{attn} \quad \text{(residual connection)}
|
||||
```
|
||||
|
||||
**Normalization happens before attention and feed-forward!**
|
||||
|
||||
---
|
||||
|
||||
*This document provides a step-by-step explanation of normalization, the critical component that stabilizes training and enables efficient learning in transformer models.*
|
||||
|
||||
223
docs/OPTIMIZATIONS.md
Normal file
223
docs/OPTIMIZATIONS.md
Normal file
@@ -0,0 +1,223 @@
|
||||
# Optimizations for Production RAG Systems
|
||||
|
||||
This document describes the optimizations implemented based on the paper "Optimizing LLM Inference and Retrieval: Novel Data Structures and Algorithms for Production RAG Systems".
|
||||
|
||||
## Implemented Optimizations
|
||||
|
||||
### 1. KV Cache for Efficient Autoregressive Generation
|
||||
|
||||
**Location**: `models/optimized_attention.py`
|
||||
|
||||
The KV (Key-Value) cache mechanism stores computed keys and values from previous tokens during autoregressive generation, eliminating redundant computations.
|
||||
|
||||
**Benefits**:
|
||||
- Reduces computational cost from O(n²) to O(n) for each new token
|
||||
- Significantly faster generation for long sequences
|
||||
- Lower memory bandwidth usage
|
||||
|
||||
**Usage**:
|
||||
```python
|
||||
from models import TransformerModel, OptimizedInference
|
||||
|
||||
model = TransformerModel(...)
|
||||
optimizer = model.get_optimized_inference()
|
||||
|
||||
# Generate with KV caching
|
||||
generated = optimizer.generate_with_cache(
|
||||
input_ids=input_ids,
|
||||
max_length=100,
|
||||
temperature=0.8,
|
||||
)
|
||||
```
|
||||
|
||||
### 2. Optimized Attention Computation
|
||||
|
||||
**Location**: `models/optimized_attention.py`
|
||||
|
||||
Implements optimized attention computation using PyTorch's `scaled_dot_product_attention` when available (similar to Flash Attention).
|
||||
|
||||
**Features**:
|
||||
- Uses PyTorch's optimized attention implementation
|
||||
- Supports causal masking efficiently
|
||||
- Reduces memory usage during attention computation
|
||||
|
||||
**Usage**:
|
||||
```python
|
||||
from models import TransformerModel
|
||||
from models.blocks import TransformerBlock
|
||||
|
||||
# Use optimized attention in transformer blocks
|
||||
block = TransformerBlock(
|
||||
d_model=512,
|
||||
num_heads=8,
|
||||
use_optimized_attention=True, # Enable optimized attention
|
||||
)
|
||||
```
|
||||
|
||||
### 3. Retrieval Cache for Similar Queries
|
||||
|
||||
**Location**: `models/optimized_attention.py`
|
||||
|
||||
Implements approximate caching for retrieval results, reducing expensive vector database lookups by caching similar queries.
|
||||
|
||||
**Features**:
|
||||
- Cosine similarity-based cache lookup
|
||||
- Configurable similarity threshold
|
||||
- Automatic cache eviction when full
|
||||
|
||||
**Usage**:
|
||||
```python
|
||||
from models.optimized_attention import RetrievalCache
|
||||
|
||||
cache = RetrievalCache(max_size=1000, similarity_threshold=0.9)
|
||||
|
||||
# Store retrieval results
|
||||
cache.set(query_hash, query_embedding, retrieved_docs)
|
||||
|
||||
# Retrieve cached results
|
||||
results = cache.get(query_hash, query_embedding)
|
||||
```
|
||||
|
||||
### 4. Prefetching Mechanisms
|
||||
|
||||
**Location**: `models/prefetching.py`
|
||||
|
||||
#### 4.1 PrefetchDataLoader
|
||||
Prefetches batches in background threads, reducing GPU idle time.
|
||||
|
||||
**Usage**:
|
||||
```python
|
||||
from models.prefetching import PrefetchDataLoader
|
||||
from data import create_dataloader
|
||||
|
||||
dataloader = create_dataloader(...)
|
||||
prefetch_loader = PrefetchDataLoader(
|
||||
dataloader=dataloader,
|
||||
prefetch_factor=2,
|
||||
device=device,
|
||||
)
|
||||
```
|
||||
|
||||
#### 4.2 LookaheadRetriever
|
||||
Prefetches retrieval results for anticipated queries.
|
||||
|
||||
**Usage**:
|
||||
```python
|
||||
from models.prefetching import LookaheadRetriever
|
||||
|
||||
def retrieve(query: str):
|
||||
# Your retrieval function
|
||||
return documents
|
||||
|
||||
retriever = LookaheadRetriever(
|
||||
retrieval_fn=retrieve,
|
||||
lookahead_window=3,
|
||||
)
|
||||
|
||||
# Start prefetching
|
||||
retriever.start_prefetching(query_stream)
|
||||
|
||||
# Get results (checks cache first)
|
||||
results = retriever.get(query)
|
||||
```
|
||||
|
||||
#### 4.3 BatchPrefetcher
|
||||
Groups queries into batches for efficient batch retrieval.
|
||||
|
||||
**Usage**:
|
||||
```python
|
||||
from models.prefetching import BatchPrefetcher
|
||||
|
||||
def batch_retrieve(queries: List[str]):
|
||||
# Batch retrieval function
|
||||
return [documents for each query]
|
||||
|
||||
prefetcher = BatchPrefetcher(
|
||||
batch_retrieval_fn=batch_retrieve,
|
||||
batch_size=8,
|
||||
)
|
||||
|
||||
prefetcher.start_prefetching(query_stream)
|
||||
results = prefetcher.get(query)
|
||||
```
|
||||
|
||||
### 5. Optimized Batch Inference
|
||||
|
||||
**Location**: `models/optimized_attention.py`
|
||||
|
||||
The `OptimizedInference` class provides batch generation utilities for processing multiple prompts efficiently.
|
||||
|
||||
**Features**:
|
||||
- Batch processing for multiple prompts
|
||||
- Automatic padding and batching
|
||||
- Efficient memory usage
|
||||
|
||||
**Usage**:
|
||||
```python
|
||||
from models import OptimizedInference
|
||||
|
||||
optimizer = model.get_optimized_inference()
|
||||
|
||||
# Generate for multiple prompts in batches
|
||||
results = optimizer.batch_generate(
|
||||
input_ids_list=[prompt1_ids, prompt2_ids, ...],
|
||||
max_length=100,
|
||||
batch_size=8,
|
||||
)
|
||||
```
|
||||
|
||||
## Performance Improvements
|
||||
|
||||
These optimizations provide the following benefits:
|
||||
|
||||
1. **Faster Inference**: KV caching reduces generation time by 2-5x for long sequences
|
||||
2. **Reduced Latency**: Prefetching reduces end-to-end latency by overlapping computation and I/O
|
||||
3. **Lower Costs**: Retrieval caching reduces expensive vector database calls
|
||||
4. **Better Throughput**: Batch processing increases throughput for multiple requests
|
||||
|
||||
## Integration
|
||||
|
||||
### Using Optimized Inference in Production
|
||||
|
||||
1. **Enable optimized attention** (for inference only):
|
||||
```python
|
||||
model = TransformerModel(
|
||||
...,
|
||||
use_optimized_attention=True, # Set in TransformerBlock
|
||||
)
|
||||
```
|
||||
|
||||
2. **Use optimized inference utility**:
|
||||
```python
|
||||
optimizer = model.get_optimized_inference()
|
||||
generated = optimizer.generate_with_cache(...)
|
||||
```
|
||||
|
||||
3. **Enable prefetching**:
|
||||
```python
|
||||
prefetch_loader = PrefetchDataLoader(dataloader, prefetch_factor=2)
|
||||
```
|
||||
|
||||
### CLI Usage
|
||||
|
||||
Use the `--optimized` flag when running inference:
|
||||
|
||||
```bash
|
||||
python inference.py \
|
||||
--checkpoint checkpoints/best_checkpoint.pt \
|
||||
--prompt "Your prompt here" \
|
||||
--optimized \
|
||||
--max-length 100
|
||||
```
|
||||
|
||||
## Example Script
|
||||
|
||||
See `example_optimized.py` for complete examples of all optimizations.
|
||||
|
||||
## References
|
||||
|
||||
Based on optimizations from:
|
||||
- "Optimizing LLM Inference and Retrieval: Novel Data Structures and Algorithms for Production RAG Systems"
|
||||
- TeleRAG: Lookahead Retrieval Mechanism
|
||||
- Flash Attention optimization techniques
|
||||
|
||||
848
docs/OPTIMIZATION_EXPLAINED.md
Normal file
848
docs/OPTIMIZATION_EXPLAINED.md
Normal file
@@ -0,0 +1,848 @@
|
||||
# What is Optimization? Step-by-Step Explanation
|
||||
|
||||
Complete step-by-step explanation of optimization in neural networks: how optimizers update weights to minimize loss.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
1. [What is Optimization?](#71-what-is-optimization)
|
||||
2. [The Optimization Problem](#72-the-optimization-problem)
|
||||
3. [Gradient Descent](#73-gradient-descent)
|
||||
4. [AdamW Optimizer](#74-adamw-optimizer)
|
||||
5. [Why Optimization Matters](#75-why-optimization-matters)
|
||||
6. [Complete Mathematical Formulation](#76-complete-mathematical-formulation)
|
||||
7. [Exercise: Optimizer Step-by-Step](#77-exercise-optimizer-step-by-step)
|
||||
8. [Key Takeaways](#78-key-takeaways)
|
||||
|
||||
---
|
||||
|
||||
## 7.1 What is Optimization?
|
||||
|
||||
### Simple Definition
|
||||
|
||||
**Optimization** is the process of finding the best set of weights (parameters) that minimize the loss function and make the model's predictions as accurate as possible.
|
||||
|
||||
### Visual Analogy
|
||||
|
||||
**Think of optimization like finding the lowest point in a valley:**
|
||||
|
||||
```
|
||||
Loss Landscape:
|
||||
|
||||
High Loss
|
||||
│
|
||||
│ ● (current position)
|
||||
│ ╱│╲
|
||||
│ ╱ │ ╲
|
||||
│ ╱ │ ╲
|
||||
│╱ │ ╲
|
||||
│ ▼ │
|
||||
│ (goal)│
|
||||
│ │
|
||||
Low Loss ─────┘
|
||||
```
|
||||
|
||||
**Goal:** Find the bottom of the valley (minimum loss)
|
||||
|
||||
**Optimizer:** Your guide down the mountain
|
||||
|
||||
### What Optimization Does
|
||||
|
||||
**Optimization:**
|
||||
1. **Measures** how wrong the model is (loss)
|
||||
2. **Calculates** direction to improve (gradients)
|
||||
3. **Updates** weights to reduce loss
|
||||
4. **Repeats** until convergence
|
||||
|
||||
**Result:** Model learns to make accurate predictions!
|
||||
|
||||
### Optimization Process Flow
|
||||
|
||||
```mermaid
|
||||
graph TB
|
||||
Start[Training Start] --> Init[Initialize Weights<br/>Random Values]
|
||||
Init --> Loop[Training Loop]
|
||||
|
||||
Loop --> Forward[Forward Pass<br/>Model Prediction]
|
||||
Forward --> Loss["Compute Loss<br/>L = loss(pred, target)"]
|
||||
Loss --> Check{Converged?}
|
||||
|
||||
Check -->|Yes| End[Training Complete]
|
||||
Check -->|No| Gradient["Compute Gradients<br/>∇L = ∂L/∂θ"]
|
||||
|
||||
Gradient --> Optimize[Optimizer<br/>Update Weights]
|
||||
Optimize --> Update["New Weights<br/>θ = θ - update"]
|
||||
Update --> Loop
|
||||
|
||||
style Start fill:#e1f5ff
|
||||
style End fill:#e1ffe1
|
||||
style Optimize fill:#fff4e1
|
||||
style Check fill:#ffe1f5
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 7.2 The Optimization Problem
|
||||
|
||||
### The Objective
|
||||
|
||||
**We want to minimize:**
|
||||
|
||||
```math
|
||||
L(\theta) = \frac{1}{N} \sum_{i=1}^{N} \ell(y_i, f(x_i; \theta))
|
||||
```
|
||||
|
||||
**Where:**
|
||||
- $\theta$ = all model parameters (weights)
|
||||
- $L$ = total loss
|
||||
- $\ell$ = loss function (e.g., cross-entropy)
|
||||
- $y_i$ = correct answer
|
||||
- $f(x_i; \theta)$ = model prediction
|
||||
- $N$ = number of examples
|
||||
|
||||
### The Challenge
|
||||
|
||||
**Problem:** Loss function is complex and high-dimensional
|
||||
|
||||
**Solution:** Use iterative optimization algorithms
|
||||
|
||||
**Process:**
|
||||
```
|
||||
Initialize weights randomly
|
||||
Repeat:
|
||||
1. Compute loss
|
||||
2. Compute gradients
|
||||
3. Update weights
|
||||
Until convergence
|
||||
```
|
||||
|
||||
### Optimization Problem Flowchart
|
||||
|
||||
```mermaid
|
||||
graph LR
|
||||
subgraph "Optimization Problem"
|
||||
A["Loss Function<br/>L(θ)"] --> B["Find Minimum<br/>min L(θ)"]
|
||||
B --> C["Optimal Weights<br/>θ*"]
|
||||
end
|
||||
|
||||
subgraph "Solution Approach"
|
||||
D["Initialize θ"] --> E[Iterative Updates]
|
||||
E --> F[Compute Loss]
|
||||
F --> G[Compute Gradient]
|
||||
G --> H[Update Weights]
|
||||
H --> I{Converged?}
|
||||
I -->|No| E
|
||||
I -->|Yes| C
|
||||
end
|
||||
|
||||
A -.-> F
|
||||
C -.-> C
|
||||
|
||||
style A fill:#ffcccc
|
||||
style B fill:#ffffcc
|
||||
style C fill:#ccffcc
|
||||
style E fill:#cce5ff
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 7.3 Gradient Descent
|
||||
|
||||
### What is Gradient Descent?
|
||||
|
||||
**Gradient Descent** is a basic optimization algorithm that updates weights by moving in the direction of steepest descent.
|
||||
|
||||
### How It Works
|
||||
|
||||
**Step 1: Compute Gradient**
|
||||
|
||||
```math
|
||||
\nabla_\theta L = \frac{\partial L}{\partial \theta}
|
||||
```
|
||||
|
||||
**Gradient tells us:**
|
||||
- Direction: Which way to go
|
||||
- Magnitude: How steep the slope
|
||||
|
||||
**Step 2: Update Weights**
|
||||
|
||||
```math
|
||||
\theta_{t+1} = \theta_t - \eta \nabla_\theta L
|
||||
```
|
||||
|
||||
**Where:**
|
||||
- $\theta_t$ = current weights
|
||||
- $\eta$ = learning rate (step size)
|
||||
- $\nabla_\theta L$ = gradient
|
||||
|
||||
**Meaning:** Move weights in direction opposite to gradient
|
||||
|
||||
### Visual Example
|
||||
|
||||
```
|
||||
Loss Landscape (2D):
|
||||
|
||||
Gradient
|
||||
Direction
|
||||
↓
|
||||
● ──────┼───── → Lower Loss
|
||||
│
|
||||
│
|
||||
```
|
||||
|
||||
**Move in direction of negative gradient!**
|
||||
|
||||
### Gradient Descent Flowchart
|
||||
|
||||
```mermaid
|
||||
graph TB
|
||||
subgraph "Gradient Descent Algorithm"
|
||||
Start["Start: Initialize θ₀"] --> Loop["For each iteration t"]
|
||||
|
||||
Loop --> Forward[Forward Pass<br/>Compute Predictions]
|
||||
Forward --> Loss["Compute Loss<br/>L(θₜ)"]
|
||||
Loss --> Grad["Compute Gradient<br/>g = ∇L(θₜ)"]
|
||||
|
||||
Grad --> Direction["Determine Direction<br/>-g points to minimum"]
|
||||
Direction --> Step["Take Step<br/>η × g"]
|
||||
Step --> Update["Update Weights<br/>θₜ₊₁ = θₜ - ηg"]
|
||||
|
||||
Update --> Check{"Converged?<br/>|g| < ε"}
|
||||
Check -->|No| Loop
|
||||
Check -->|Yes| End["Found Minimum<br/>θ*"]
|
||||
end
|
||||
|
||||
subgraph "Gradient Information"
|
||||
GradInfo["Gradient g contains:<br/>- Direction: Which way to go<br/>- Magnitude: How steep"]
|
||||
end
|
||||
|
||||
Grad -.-> GradInfo
|
||||
|
||||
style Start fill:#e1f5ff
|
||||
style End fill:#e1ffe1
|
||||
style Grad fill:#fff4e1
|
||||
style Check fill:#ffe1f5
|
||||
style Update fill:#ccffcc
|
||||
```
|
||||
|
||||
### Types of Gradient Descent
|
||||
|
||||
**1. Batch Gradient Descent:**
|
||||
- Uses all training examples
|
||||
- Most accurate gradients
|
||||
- Slow for large datasets
|
||||
|
||||
**2. Stochastic Gradient Descent (SGD):**
|
||||
- Uses one example at a time
|
||||
- Fast but noisy
|
||||
- Can bounce around
|
||||
|
||||
**3. Mini-Batch Gradient Descent:**
|
||||
- Uses small batch of examples
|
||||
- Balance of speed and accuracy
|
||||
- Most commonly used
|
||||
|
||||
### Gradient Descent Types Comparison
|
||||
|
||||
```mermaid
|
||||
graph TB
|
||||
subgraph "Batch Gradient Descent"
|
||||
B1[All Training Data] --> B2[Compute Gradient<br/>on Full Dataset]
|
||||
B2 --> B3[Single Update<br/>Most Accurate]
|
||||
B3 --> B4["Slow: O(N)"]
|
||||
end
|
||||
|
||||
subgraph "Stochastic Gradient Descent"
|
||||
S1[Single Example] --> S2[Compute Gradient<br/>on One Sample]
|
||||
S2 --> S3[Many Updates<br/>Fast but Noisy]
|
||||
S3 --> S4["Fast: O(1)"]
|
||||
end
|
||||
|
||||
subgraph "Mini-Batch Gradient Descent"
|
||||
M1[Small Batch<br/>32-256 samples] --> M2[Compute Gradient<br/>on Batch]
|
||||
M2 --> M3[Balanced Updates<br/>Good Accuracy]
|
||||
M3 --> M4["Fast: O(batch_size)"]
|
||||
end
|
||||
|
||||
style B3 fill:#ccffcc
|
||||
style S3 fill:#ffcccc
|
||||
style M3 fill:#fff4e1
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 7.4 AdamW Optimizer
|
||||
|
||||
### What is AdamW?
|
||||
|
||||
**AdamW** (Adam with Weight Decay) is an advanced optimizer that combines:
|
||||
- **Adaptive learning rates** (like Adam)
|
||||
- **Weight decay** (regularization)
|
||||
|
||||
**Why AdamW?**
|
||||
- Per-parameter learning rates
|
||||
- Handles sparse gradients well
|
||||
- Works great for transformers
|
||||
|
||||
### How AdamW Works
|
||||
|
||||
**Step 1: Compute Gradient**
|
||||
|
||||
```math
|
||||
g_t = \nabla_\theta L(\theta_t)
|
||||
```
|
||||
|
||||
**Step 2: Update Momentum**
|
||||
|
||||
```math
|
||||
m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t
|
||||
```
|
||||
|
||||
**Where:**
|
||||
- $\beta_1 = 0.9$ (momentum decay)
|
||||
- $m_t$ = first moment estimate
|
||||
|
||||
**Meaning:** Moving average of gradients
|
||||
|
||||
**Step 3: Update Variance**
|
||||
|
||||
```math
|
||||
v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2
|
||||
```
|
||||
|
||||
**Where:**
|
||||
- $\beta_2 = 0.999$ (variance decay)
|
||||
- $v_t$ = second moment estimate
|
||||
|
||||
**Meaning:** Moving average of squared gradients
|
||||
|
||||
**Step 4: Bias Correction**
|
||||
|
||||
```math
|
||||
\hat{m}_t = \frac{m_t}{1 - \beta_1^t}
|
||||
```
|
||||
|
||||
```math
|
||||
\hat{v}_t = \frac{v_t}{1 - \beta_2^t}
|
||||
```
|
||||
|
||||
**Why?** Corrects bias in early iterations
|
||||
|
||||
**Step 5: Update Weights**
|
||||
|
||||
```math
|
||||
\theta_{t+1} = \theta_t - \eta \left( \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda \theta_t \right)
|
||||
```
|
||||
|
||||
**Where:**
|
||||
- $\eta$ = learning rate
|
||||
- $\epsilon = 10^{-8}$ (small constant)
|
||||
- $\lambda$ = weight decay coefficient
|
||||
|
||||
**Key Points:**
|
||||
- $\frac{\hat{m}_t}{\sqrt{\hat{v}_t}}$ = adaptive learning rate per parameter
|
||||
- $\lambda \theta_t$ = weight decay (regularization)
|
||||
|
||||
### AdamW Optimizer Flowchart
|
||||
|
||||
```mermaid
|
||||
graph TB
|
||||
subgraph "AdamW Optimization Process"
|
||||
Start["Start: Initialize<br/>θ₀, m₀=0, v₀=0"] --> Loop["For each iteration t"]
|
||||
|
||||
Loop --> Forward["Forward Pass<br/>Compute Loss L(θₜ)"]
|
||||
Forward --> Grad["Step 1: Compute Gradient<br/>gₜ = ∇L(θₜ)"]
|
||||
|
||||
Grad --> Mom["Step 2: Update Momentum<br/>mₜ = β₁mₜ₋₁ + (1-β₁)gₜ"]
|
||||
Mom --> Var["Step 3: Update Variance<br/>vₜ = β₂vₜ₋₁ + (1-β₂)gₜ²"]
|
||||
|
||||
Var --> Bias["Step 4: Bias Correction<br/>m̂ₜ = mₜ/(1-β₁ᵗ)<br/>v̂ₜ = vₜ/(1-β₂ᵗ)"]
|
||||
|
||||
Bias --> Adapt["Step 5: Adaptive LR<br/>LR = η/(√v̂ₜ + ε)"]
|
||||
|
||||
Adapt --> Decay["Step 6: Weight Decay<br/>λθₜ"]
|
||||
|
||||
Decay --> Update["Step 7: Update Weights<br/>θₜ₊₁ = θₜ - LR×m̂ₜ - λθₜ"]
|
||||
|
||||
Update --> Check{Converged?}
|
||||
Check -->|No| Loop
|
||||
Check -->|Yes| End["Optimal Weights θ*"]
|
||||
end
|
||||
|
||||
subgraph "Key Components"
|
||||
C1["Momentum mₜ<br/>Moving avg of gradients"]
|
||||
C2["Variance vₜ<br/>Moving avg of g²"]
|
||||
C3["Adaptive LR<br/>Per-parameter learning rate"]
|
||||
C4["Weight Decay<br/>Regularization"]
|
||||
end
|
||||
|
||||
Mom -.-> C1
|
||||
Var -.-> C2
|
||||
Adapt -.-> C3
|
||||
Decay -.-> C4
|
||||
|
||||
style Start fill:#e1f5ff
|
||||
style End fill:#e1ffe1
|
||||
style Grad fill:#fff4e1
|
||||
style Adapt fill:#ccffcc
|
||||
style Update fill:#ccffcc
|
||||
style Check fill:#ffe1f5
|
||||
```
|
||||
|
||||
### AdamW Detailed Subgraph
|
||||
|
||||
```mermaid
|
||||
graph LR
|
||||
subgraph "Input"
|
||||
I1["Gradient gₜ"]
|
||||
I2["Previous Momentum mₜ₋₁"]
|
||||
I3["Previous Variance vₜ₋₁"]
|
||||
I4["Current Weights θₜ"]
|
||||
end
|
||||
|
||||
subgraph "Momentum Update"
|
||||
M1["Multiply: β₁mₜ₋₁"] --> M2["Combine: β₁mₜ₋₁ + (1-β₁)gₜ"]
|
||||
I2 --> M1
|
||||
I1 --> M2
|
||||
end
|
||||
|
||||
subgraph "Variance Update"
|
||||
V1["Square: gₜ²"] --> V2["Combine: β₂vₜ₋₁ + (1-β₂)gₜ²"]
|
||||
I3 --> V2
|
||||
I1 --> V1
|
||||
end
|
||||
|
||||
subgraph "Bias Correction"
|
||||
M2 --> BC1["m̂ₜ = mₜ/(1-β₁ᵗ)"]
|
||||
V2 --> BC2["v̂ₜ = vₜ/(1-β₂ᵗ)"]
|
||||
end
|
||||
|
||||
subgraph "Adaptive Learning Rate"
|
||||
BC2 --> ALR["LR = η/(√v̂ₜ + ε)"]
|
||||
end
|
||||
|
||||
subgraph "Weight Update"
|
||||
BC1 --> WU1["Adaptive Step: LR × m̂ₜ"]
|
||||
ALR --> WU1
|
||||
I4 --> WU2["Decay Step: λθₜ"]
|
||||
WU1 --> WU3["Update: θₜ₊₁ = θₜ - LR×m̂ₜ - λθₜ"]
|
||||
WU2 --> WU3
|
||||
end
|
||||
|
||||
style M2 fill:#e1f5ff
|
||||
style V2 fill:#e1f5ff
|
||||
style BC1 fill:#fff4e1
|
||||
style BC2 fill:#fff4e1
|
||||
style ALR fill:#ccffcc
|
||||
style WU3 fill:#ccffcc
|
||||
```
|
||||
|
||||
### Why AdamW is Better
|
||||
|
||||
**Compared to SGD:**
|
||||
|
||||
**SGD:**
|
||||
```
|
||||
Same learning rate for all parameters
|
||||
→ Slow convergence
|
||||
→ Manual tuning needed
|
||||
```
|
||||
|
||||
**AdamW:**
|
||||
```
|
||||
Adaptive learning rate per parameter
|
||||
→ Faster convergence
|
||||
→ Less manual tuning
|
||||
```
|
||||
|
||||
**Benefits:**
|
||||
1. **Adaptive:** Each parameter gets its own learning rate
|
||||
2. **Robust:** Works well with noisy gradients
|
||||
3. **Efficient:** Converges faster than SGD
|
||||
4. **Regularized:** Weight decay prevents overfitting
|
||||
|
||||
### SGD vs AdamW Comparison
|
||||
|
||||
```mermaid
|
||||
graph TB
|
||||
subgraph "Stochastic Gradient Descent"
|
||||
SGD1["Gradient gₜ"] --> SGD2["Fixed Learning Rate η"]
|
||||
SGD2 --> SGD3["Update: θₜ₊₁ = θₜ - ηgₜ"]
|
||||
SGD3 --> SGD4["All params same LR"]
|
||||
SGD4 --> SGD5["Slow Convergence<br/>Manual Tuning"]
|
||||
end
|
||||
|
||||
subgraph "AdamW Optimizer"
|
||||
AD1["Gradient gₜ"] --> AD2["Momentum mₜ"]
|
||||
AD1 --> AD3["Variance vₜ"]
|
||||
AD2 --> AD4[Bias Correction]
|
||||
AD3 --> AD4
|
||||
AD4 --> AD5["Adaptive LR per param"]
|
||||
AD5 --> AD6["Update: θₜ₊₁ = θₜ - LR×m̂ₜ - λθₜ"]
|
||||
AD6 --> AD7["Fast Convergence<br/>Less Tuning"]
|
||||
end
|
||||
|
||||
subgraph "Comparison"
|
||||
Comp1["Same Model<br/>Same Data"]
|
||||
Comp1 --> Comp2["SGD: Loss = 2.5<br/>After 100 epochs"]
|
||||
Comp1 --> Comp3["AdamW: Loss = 1.8<br/>After 100 epochs"]
|
||||
Comp3 --> Comp4[AdamW is Better!]
|
||||
end
|
||||
|
||||
SGD5 -.-> Comp2
|
||||
AD7 -.-> Comp3
|
||||
|
||||
style SGD5 fill:#ffcccc
|
||||
style AD7 fill:#ccffcc
|
||||
style Comp4 fill:#e1ffe1
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 7.5 Why Optimization Matters
|
||||
|
||||
### Reason 1: Without Optimization
|
||||
|
||||
**Random weights:**
|
||||
```
|
||||
Weights: Random values
|
||||
Loss: Very high
|
||||
Predictions: Random
|
||||
Model: Useless
|
||||
```
|
||||
|
||||
### Reason 2: With Optimization
|
||||
|
||||
**Learned weights:**
|
||||
```
|
||||
Weights: Optimized values
|
||||
Loss: Low
|
||||
Predictions: Accurate
|
||||
Model: Useful
|
||||
```
|
||||
|
||||
### Reason 3: Determines Learning Speed
|
||||
|
||||
**Good optimizer:**
|
||||
- Fast convergence
|
||||
- Stable training
|
||||
- Good final performance
|
||||
|
||||
**Poor optimizer:**
|
||||
- Slow convergence
|
||||
- Unstable training
|
||||
- Poor final performance
|
||||
|
||||
### Reason 4: Affects Final Performance
|
||||
|
||||
**Same model, different optimizers:**
|
||||
|
||||
```
|
||||
SGD: Loss = 2.5 (after 100 epochs)
|
||||
AdamW: Loss = 1.8 (after 100 epochs)
|
||||
```
|
||||
|
||||
**Better optimizer = Better model!**
|
||||
|
||||
### Optimization Impact Visualization
|
||||
|
||||
```mermaid
|
||||
graph LR
|
||||
subgraph "Without Optimization"
|
||||
WO1[Random Weights] --> WO2["High Loss<br/>L ≈ 8-10"]
|
||||
WO2 --> WO3[Random Predictions]
|
||||
WO3 --> WO4[Model Useless]
|
||||
end
|
||||
|
||||
subgraph "With Optimization"
|
||||
W1[Random Weights] --> W2[Optimization Loop]
|
||||
W2 --> W3[Update Weights]
|
||||
W3 --> W4["Low Loss<br/>L ≈ 1-2"]
|
||||
W4 --> W5[Accurate Predictions]
|
||||
W5 --> W6[Model Useful]
|
||||
end
|
||||
|
||||
subgraph "Optimizer Quality"
|
||||
O1["Poor Optimizer<br/>SGD, Bad LR"] --> O2["Slow Convergence<br/>Loss = 2.5"]
|
||||
O3["Good Optimizer<br/>AdamW, Proper LR"] --> O4["Fast Convergence<br/>Loss = 1.8"]
|
||||
end
|
||||
|
||||
W2 -.-> O1
|
||||
W2 -.-> O3
|
||||
|
||||
style WO4 fill:#ffcccc
|
||||
style W6 fill:#ccffcc
|
||||
style O2 fill:#ffcccc
|
||||
style O4 fill:#ccffcc
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 7.6 Complete Mathematical Formulation
|
||||
|
||||
### Optimization Problem
|
||||
|
||||
```math
|
||||
\theta^* = \arg\min_{\theta} L(\theta)
|
||||
```
|
||||
|
||||
**Where $\theta^*$ is the optimal set of weights**
|
||||
|
||||
### Gradient Descent Update
|
||||
|
||||
```math
|
||||
\theta_{t+1} = \theta_t - \eta \nabla_\theta L(\theta_t)
|
||||
```
|
||||
|
||||
### AdamW Update (Complete)
|
||||
|
||||
**For each parameter $\theta_i$:**
|
||||
|
||||
**Gradient:**
|
||||
```math
|
||||
g_{t,i} = \frac{\partial L}{\partial \theta_{t,i}}
|
||||
```
|
||||
|
||||
**Momentum:**
|
||||
```math
|
||||
m_{t,i} = \beta_1 m_{t-1,i} + (1 - \beta_1) g_{t,i}
|
||||
```
|
||||
|
||||
**Variance:**
|
||||
```math
|
||||
v_{t,i} = \beta_2 v_{t-1,i} + (1 - \beta_2) g_{t,i}^2
|
||||
```
|
||||
|
||||
**Bias Correction:**
|
||||
```math
|
||||
\hat{m}_{t,i} = \frac{m_{t,i}}{1 - \beta_1^t}
|
||||
```
|
||||
|
||||
```math
|
||||
\hat{v}_{t,i} = \frac{v_{t,i}}{1 - \beta_2^t}
|
||||
```
|
||||
|
||||
**Update:**
|
||||
```math
|
||||
\theta_{t+1,i} = \theta_{t,i} - \eta \left( \frac{\hat{m}_{t,i}}{\sqrt{\hat{v}_{t,i}} + \epsilon} + \lambda \theta_{t,i} \right)
|
||||
```
|
||||
|
||||
**Where:**
|
||||
- $\beta_1 = 0.9$
|
||||
- $\beta_2 = 0.999$
|
||||
- $\epsilon = 10^{-8}$
|
||||
- $\lambda$ = weight decay (e.g., 0.01)
|
||||
|
||||
### Complete Mathematical Flow
|
||||
|
||||
```mermaid
|
||||
graph TB
|
||||
subgraph "Optimization Problem"
|
||||
OP1["Loss Function L(θ)"] --> OP2["Find: θ* = argmin L(θ)"]
|
||||
end
|
||||
|
||||
subgraph "Gradient Computation"
|
||||
GC1[Forward Pass] --> GC2[Compute Loss L]
|
||||
GC2 --> GC3[Backpropagation]
|
||||
GC3 --> GC4["Gradient gᵢ = ∂L/∂θᵢ"]
|
||||
end
|
||||
|
||||
subgraph "AdamW Update Steps"
|
||||
GC4 --> AU1["Momentum: mᵢ = β₁m + (1-β₁)gᵢ"]
|
||||
AU1 --> AU2["Variance: vᵢ = β₂v + (1-β₂)gᵢ²"]
|
||||
AU2 --> AU3["Bias Correction:<br/>m̂ = m/(1-β₁ᵗ), v̂ = v/(1-β₂ᵗ)"]
|
||||
AU3 --> AU4["Adaptive LR: η/(√v̂ + ε)"]
|
||||
AU4 --> AU5["Update: θᵢ = θᵢ - LR×m̂ - λθᵢ"]
|
||||
end
|
||||
|
||||
subgraph "Convergence Check"
|
||||
AU5 --> CC1{Converged?}
|
||||
CC1 -->|No| GC1
|
||||
CC1 -->|Yes| CC2["Optimal Weights θ*"]
|
||||
end
|
||||
|
||||
OP2 -.-> GC1
|
||||
CC2 -.-> OP2
|
||||
|
||||
style OP2 fill:#ffffcc
|
||||
style GC4 fill:#fff4e1
|
||||
style AU5 fill:#ccffcc
|
||||
style CC2 fill:#e1ffe1
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 7.7 Exercise: Optimizer Step-by-Step
|
||||
|
||||
### Problem
|
||||
|
||||
**Given:**
|
||||
- Current weight: $\theta_0 = 2.0$
|
||||
- Loss function: $L(\theta) = (\theta - 1)^2$
|
||||
- Learning rate: $\eta = 0.1$
|
||||
- Use AdamW with $\beta_1 = 0.9$, $\beta_2 = 0.999$, $\lambda = 0.01$
|
||||
- Initial moments: $m_0 = 0$, $v_0 = 0$
|
||||
|
||||
**Calculate the weight update for step 1.**
|
||||
|
||||
### Step-by-Step Solution
|
||||
|
||||
#### Step 1: Compute Gradient
|
||||
|
||||
**Loss function:**
|
||||
```math
|
||||
L(\theta) = (\theta - 1)^2
|
||||
```
|
||||
|
||||
**Gradient:**
|
||||
```math
|
||||
g_1 = \frac{\partial L}{\partial \theta} = 2(\theta - 1)
|
||||
```
|
||||
|
||||
**At $\theta_0 = 2.0$:**
|
||||
```math
|
||||
g_1 = 2(2.0 - 1) = 2(1.0) = 2.0
|
||||
```
|
||||
|
||||
#### Step 2: Update Momentum
|
||||
|
||||
```math
|
||||
m_1 = \beta_1 m_0 + (1 - \beta_1) g_1
|
||||
```
|
||||
|
||||
```math
|
||||
m_1 = 0.9 \times 0 + (1 - 0.9) \times 2.0 = 0 + 0.1 \times 2.0 = 0.2
|
||||
```
|
||||
|
||||
#### Step 3: Update Variance
|
||||
|
||||
```math
|
||||
v_1 = \beta_2 v_0 + (1 - \beta_2) g_1^2
|
||||
```
|
||||
|
||||
```math
|
||||
v_1 = 0.999 \times 0 + (1 - 0.999) \times (2.0)^2 = 0 + 0.001 \times 4.0 = 0.004
|
||||
```
|
||||
|
||||
#### Step 4: Bias Correction
|
||||
|
||||
```math
|
||||
\hat{m}_1 = \frac{m_1}{1 - \beta_1^1} = \frac{0.2}{1 - 0.9} = \frac{0.2}{0.1} = 2.0
|
||||
```
|
||||
|
||||
```math
|
||||
\hat{v}_1 = \frac{v_1}{1 - \beta_2^1} = \frac{0.004}{1 - 0.999} = \frac{0.004}{0.001} = 4.0
|
||||
```
|
||||
|
||||
#### Step 5: Compute Update
|
||||
|
||||
```math
|
||||
\Delta \theta_1 = \eta \left( \frac{\hat{m}_1}{\sqrt{\hat{v}_1} + \epsilon} + \lambda \theta_0 \right)
|
||||
```
|
||||
|
||||
```math
|
||||
\Delta \theta_1 = 0.1 \left( \frac{2.0}{\sqrt{4.0} + 10^{-8}} + 0.01 \times 2.0 \right)
|
||||
```
|
||||
|
||||
```math
|
||||
\Delta \theta_1 = 0.1 \left( \frac{2.0}{2.0 + 0.00000001} + 0.02 \right)
|
||||
```
|
||||
|
||||
```math
|
||||
\Delta \theta_1 = 0.1 \left( \frac{2.0}{2.0} + 0.02 \right) = 0.1 (1.0 + 0.02) = 0.1 \times 1.02 = 0.102
|
||||
```
|
||||
|
||||
#### Step 6: Update Weight
|
||||
|
||||
```math
|
||||
\theta_1 = \theta_0 - \Delta \theta_1 = 2.0 - 0.102 = 1.898
|
||||
```
|
||||
|
||||
### Answer
|
||||
|
||||
**After one step:**
|
||||
- Old weight: $\theta_0 = 2.0$
|
||||
- New weight: $\theta_1 = 1.898$
|
||||
- Update: $\Delta \theta_1 = -0.102$
|
||||
|
||||
**The weight moved closer to the optimal value (1.0)!**
|
||||
|
||||
### Verification
|
||||
|
||||
**Check loss:**
|
||||
- Old loss: $L(2.0) = (2.0 - 1)^2 = 1.0$
|
||||
- New loss: $L(1.898) = (1.898 - 1)^2 = 0.806$
|
||||
|
||||
**Loss decreased! ✓**
|
||||
|
||||
### Exercise Solution Flowchart
|
||||
|
||||
```mermaid
|
||||
graph TB
|
||||
subgraph "Given Values"
|
||||
G1["θ₀ = 2.0"] --> Start
|
||||
G2["m₀ = 0"] --> Start
|
||||
G3["v₀ = 0"] --> Start
|
||||
G4["η = 0.1, β₁ = 0.9<br/>β₂ = 0.999, λ = 0.01"] --> Start
|
||||
end
|
||||
|
||||
Start[Start] --> Step1["Step 1: Compute Gradient<br/>L(θ) = (θ-1)²<br/>g₁ = 2(θ₀-1) = 2.0"]
|
||||
|
||||
Step1 --> Step2["Step 2: Update Momentum<br/>m₁ = 0.9×0 + 0.1×2.0<br/>m₁ = 0.2"]
|
||||
|
||||
Step2 --> Step3["Step 3: Update Variance<br/>v₁ = 0.999×0 + 0.001×4.0<br/>v₁ = 0.004"]
|
||||
|
||||
Step3 --> Step4["Step 4: Bias Correction<br/>m̂₁ = 0.2/(1-0.9) = 2.0<br/>v̂₁ = 0.004/(1-0.999) = 4.0"]
|
||||
|
||||
Step4 --> Step5["Step 5: Compute Update<br/>Δθ₁ = 0.1×(2.0/√4.0 + 0.01×2.0)<br/>Δθ₁ = 0.102"]
|
||||
|
||||
Step5 --> Step6["Step 6: Update Weight<br/>θ₁ = 2.0 - 0.102<br/>θ₁ = 1.898"]
|
||||
|
||||
Step6 --> Verify["Verification:<br/>L(2.0) = 1.0 → L(1.898) = 0.806<br/>Loss Decreased!"]
|
||||
|
||||
Verify --> End["Result: θ₁ = 1.898<br/>Closer to optimum θ* = 1.0"]
|
||||
|
||||
style Start fill:#e1f5ff
|
||||
style Step1 fill:#fff4e1
|
||||
style Step2 fill:#fff4e1
|
||||
style Step3 fill:#fff4e1
|
||||
style Step4 fill:#fff4e1
|
||||
style Step5 fill:#fff4e1
|
||||
style Step6 fill:#ccffcc
|
||||
style Verify fill:#e1ffe1
|
||||
style End fill:#e1ffe1
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 7.8 Key Takeaways
|
||||
|
||||
### Optimization
|
||||
|
||||
✅ **Optimization finds best weights to minimize loss**
|
||||
✅ **Uses gradients to determine update direction**
|
||||
✅ **Iterative process: compute → update → repeat**
|
||||
|
||||
### Gradient Descent
|
||||
|
||||
✅ **Basic algorithm: move opposite to gradient**
|
||||
✅ **Learning rate controls step size**
|
||||
✅ **Can be slow for complex problems**
|
||||
|
||||
### AdamW
|
||||
|
||||
✅ **Advanced optimizer with adaptive learning rates**
|
||||
✅ **Each parameter gets its own learning rate**
|
||||
✅ **Combines momentum and variance estimates**
|
||||
✅ **Includes weight decay for regularization**
|
||||
✅ **Works great for transformers**
|
||||
|
||||
### Why Important
|
||||
|
||||
✅ **Determines how fast model learns**
|
||||
✅ **Affects final model performance**
|
||||
✅ **Essential for training neural networks**
|
||||
|
||||
---
|
||||
|
||||
*This document provides a comprehensive explanation of optimization in neural networks, including gradient descent and AdamW optimizer with mathematical formulations and solved exercises.*
|
||||
|
||||
706
docs/PAIN_POINTS_AND_OPPORTUNITIES.md
Normal file
706
docs/PAIN_POINTS_AND_OPPORTUNITIES.md
Normal file
@@ -0,0 +1,706 @@
|
||||
# LLM Pain Points & Market Opportunities
|
||||
|
||||
A comprehensive analysis of the main challenges in language models and emerging opportunities in the market.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
1. [Main Pain Points](#main-pain-points)
|
||||
2. [Market Opportunities](#market-opportunities)
|
||||
3. [Technical Solutions](#technical-solutions)
|
||||
4. [Market Segments](#market-segments)
|
||||
5. [Future Trends](#future-trends)
|
||||
|
||||
---
|
||||
|
||||
## Main Pain Points
|
||||
|
||||
### 1. Training Costs & Resource Requirements
|
||||
|
||||
**The Problem:**
|
||||
- **Extremely expensive**: Training GPT-3 cost ~$4.6M, GPT-4 likely $100M+
|
||||
- **Massive compute requirements**: Requires thousands of GPUs for months
|
||||
- **High barrier to entry**: Only large corporations can afford training from scratch
|
||||
- **Lengthy development cycles**: Months to years to train and iterate
|
||||
|
||||
**Impact:**
|
||||
```
|
||||
Small Companies: Cannot compete
|
||||
Researchers: Limited access to resources
|
||||
Innovation: Slowed by cost barriers
|
||||
```
|
||||
|
||||
**Numbers:**
|
||||
- GPT-3: 300B tokens, $4.6M training cost
|
||||
- GPT-4: Estimated $100M+ training cost
|
||||
- Training time: 3-6 months on thousands of GPUs
|
||||
- Infrastructure: Data centers with specialized hardware
|
||||
|
||||
### 2. Inference Latency & Speed
|
||||
|
||||
**The Problem:**
|
||||
- **Slow generation**: High-quality models generate 10-50 tokens/second
|
||||
- **High latency**: 500ms-5s response time for queries
|
||||
- **Poor scalability**: Linear scaling with number of users
|
||||
- **Real-time constraints**: Difficult to achieve interactive speeds
|
||||
|
||||
**Impact:**
|
||||
```
|
||||
User Experience: Frustrating delays
|
||||
Applications: Limited to batch processing
|
||||
Real-time Use: Not feasible for many cases
|
||||
Cost: More compute = slower response
|
||||
```
|
||||
|
||||
**Current Performance:**
|
||||
- Standard inference: 10-50 tokens/sec
|
||||
- High-end GPUs: 100-200 tokens/sec
|
||||
- With optimizations: 200-500 tokens/sec
|
||||
- Target for real-time: 1000+ tokens/sec
|
||||
|
||||
### 3. Memory Consumption
|
||||
|
||||
**The Problem:**
|
||||
- **Massive memory requirements**:
|
||||
- GPT-3 175B: ~350GB GPU memory
|
||||
- GPT-4: Estimated ~700GB+ memory
|
||||
- **Inefficient memory usage**: Attention matrices scale quadratically
|
||||
- **Limited device support**: Cannot run on consumer hardware
|
||||
- **High infrastructure costs**: Requires expensive GPUs
|
||||
|
||||
**Impact:**
|
||||
```
|
||||
Deployment: Expensive server infrastructure
|
||||
Accessibility: Limited to cloud providers
|
||||
Edge Devices: Impossible without optimization
|
||||
Cost: High memory = high server costs
|
||||
```
|
||||
|
||||
**Memory Breakdown:**
|
||||
- Model weights: 50-70% of memory
|
||||
- KV cache: 20-30% during inference
|
||||
- Activations: 10-20% during forward pass
|
||||
- Overhead: 5-10% for framework
|
||||
|
||||
### 4. Energy Consumption & Environmental Impact
|
||||
|
||||
**The Problem:**
|
||||
- **Extremely high energy usage**:
|
||||
- GPT-3 training: ~3,287 MWh (~$1.4M electricity)
|
||||
- Continuous inference: High carbon footprint
|
||||
- **Environmental concerns**: Equivalent to significant CO2 emissions
|
||||
- **Sustainability issues**: Unsustainable scaling
|
||||
|
||||
**Impact:**
|
||||
```
|
||||
Environment: Significant carbon footprint
|
||||
Cost: High electricity bills
|
||||
Regulation: Increasing environmental regulations
|
||||
Public Perception: Growing concern about AI's impact
|
||||
```
|
||||
|
||||
**Numbers:**
|
||||
- Training GPT-3: ~552 metric tons CO2 equivalent
|
||||
- Daily inference: Thousands of MWh per day globally
|
||||
- Cost: Electricity is major operational expense
|
||||
|
||||
### 5. Data Dependency & Quality
|
||||
|
||||
**The Problem:**
|
||||
- **Massive data requirements**: Billions of tokens needed
|
||||
- **Data quality issues**: Garbage in, garbage out
|
||||
- **Bias in training data**: Models inherit societal biases
|
||||
- **Copyright concerns**: Training on copyrighted material
|
||||
- **Data scarcity**: High-quality data is limited
|
||||
|
||||
**Impact:**
|
||||
```
|
||||
Quality: Poor data = poor models
|
||||
Bias: Perpetuates existing biases
|
||||
Legal: Copyright and licensing issues
|
||||
Cost: Data acquisition is expensive
|
||||
```
|
||||
|
||||
**Requirements:**
|
||||
- GPT-3: 300B tokens (~45TB of text)
|
||||
- Data cleaning: 70-80% of data preparation time
|
||||
- Quality control: Critical but expensive
|
||||
- Diversity: Need diverse, representative data
|
||||
|
||||
### 6. Hallucination & Reliability
|
||||
|
||||
**The Problem:**
|
||||
- **Factual inaccuracies**: Models generate plausible but false information
|
||||
- **Inconsistent outputs**: Same prompt can give different answers
|
||||
- **Difficulty verifying**: Hard to distinguish truth from hallucination
|
||||
- **Confidence estimation**: Models don't know when they're wrong
|
||||
|
||||
**Impact:**
|
||||
```
|
||||
Trust: Users lose confidence
|
||||
Applications: Cannot use for critical tasks
|
||||
Verification: Requires human oversight
|
||||
Legal: Liability concerns
|
||||
```
|
||||
|
||||
**Examples:**
|
||||
- Medical advice: Could be dangerous
|
||||
- Financial information: Could cause losses
|
||||
- Legal documents: Could have serious consequences
|
||||
- Scientific facts: Could mislead researchers
|
||||
|
||||
### 7. Fine-tuning & Customization Complexity
|
||||
|
||||
**The Problem:**
|
||||
- **Time-consuming**: Days to weeks for fine-tuning
|
||||
- **Expensive**: Requires significant compute resources
|
||||
- **Technical expertise**: Requires deep ML knowledge
|
||||
- **Dataset preparation**: Complex and time-consuming
|
||||
- **Hyperparameter tuning**: Trial and error process
|
||||
|
||||
**Impact:**
|
||||
```
|
||||
Adoption: High barrier for businesses
|
||||
Iteration: Slow feedback loops
|
||||
Cost: Expensive experimentation
|
||||
Expertise: Limited talent pool
|
||||
```
|
||||
|
||||
**Challenges:**
|
||||
- LoRA vs full fine-tuning: Trade-offs unclear
|
||||
- Data requirements: How much data is needed?
|
||||
- Evaluation: How to measure success?
|
||||
- Deployment: Complex integration process
|
||||
|
||||
### 8. Scalability & Infrastructure
|
||||
|
||||
**The Problem:**
|
||||
- **Horizontal scaling**: Difficult to distribute inference
|
||||
- **Load balancing**: Complex for stateful models
|
||||
- **Cost scaling**: Linear cost increase with users
|
||||
- **Infrastructure management**: Requires DevOps expertise
|
||||
- **High availability**: Complex to achieve 99.9%+ uptime
|
||||
|
||||
**Impact:**
|
||||
```
|
||||
Growth: Limits ability to scale
|
||||
Cost: Infrastructure costs grow with usage
|
||||
Reliability: Complex to maintain
|
||||
Engineering: Requires significant resources
|
||||
```
|
||||
|
||||
**Issues:**
|
||||
- State management: KV cache complicates scaling
|
||||
- Batch processing: Inefficient for single requests
|
||||
- Geographic distribution: Latency vs consistency
|
||||
- Cost optimization: Balancing performance and cost
|
||||
|
||||
---
|
||||
|
||||
## Market Opportunities
|
||||
|
||||
### 1. Efficient Training & Fine-tuning Solutions
|
||||
|
||||
**Opportunity:**
|
||||
- **Problem**: Training is too expensive and slow
|
||||
- **Solution**: Efficient training methods, LoRA, quantization
|
||||
- **Market Size**: $2-5B by 2027
|
||||
- **Key Players**: Hugging Face, Cohere, Anthropic
|
||||
|
||||
**Technologies:**
|
||||
- **LoRA (Low-Rank Adaptation)**: 10-100x cheaper fine-tuning
|
||||
- **Quantization**: 4x-8x memory reduction
|
||||
- **Gradient checkpointing**: 2x memory savings
|
||||
- **Distributed training**: Optimize multi-GPU setups
|
||||
|
||||
**Market Segments:**
|
||||
- Enterprise fine-tuning platforms
|
||||
- Training optimization tools
|
||||
- Pre-trained model marketplaces
|
||||
- Model compression services
|
||||
|
||||
**Revenue Models:**
|
||||
- SaaS platforms for fine-tuning
|
||||
- Consulting services
|
||||
- Model licensing
|
||||
- Training infrastructure
|
||||
|
||||
### 2. Inference Optimization & Acceleration
|
||||
|
||||
**Opportunity:**
|
||||
- **Problem**: Inference is too slow and expensive
|
||||
- **Solution**: KV caching, quantization, model pruning
|
||||
- **Market Size**: $5-10B by 2027
|
||||
- **Key Players**: NVIDIA, TensorRT, vLLM
|
||||
|
||||
**Technologies:**
|
||||
- **KV Caching**: 2-5x speedup
|
||||
- **Quantization**: 4x faster inference
|
||||
- **Model pruning**: 2-4x speedup
|
||||
- **Specialized hardware**: TPUs, specialized chips
|
||||
|
||||
**Market Segments:**
|
||||
- Real-time applications
|
||||
- Edge deployment
|
||||
- High-throughput services
|
||||
- Cost-sensitive applications
|
||||
|
||||
**Competitive Advantages:**
|
||||
- Ease of integration
|
||||
- Performance improvements
|
||||
- Cost reduction
|
||||
- Developer experience
|
||||
|
||||
### 3. Edge & Mobile Deployment
|
||||
|
||||
**Opportunity:**
|
||||
- **Problem**: Models too large for edge devices
|
||||
- **Solution**: Model compression, quantization, distillation
|
||||
- **Market Size**: $3-8B by 2027
|
||||
- **Key Players**: Qualcomm, Apple, Google
|
||||
|
||||
**Technologies:**
|
||||
- **Model distillation**: Smaller, faster models
|
||||
- **Quantization**: INT8/INT4 inference
|
||||
- **Pruning**: Remove unnecessary weights
|
||||
- **On-device ML**: Specialized hardware
|
||||
|
||||
**Market Segments:**
|
||||
- Smartphones
|
||||
- IoT devices
|
||||
- Autonomous vehicles
|
||||
- AR/VR devices
|
||||
|
||||
**Applications:**
|
||||
- Voice assistants
|
||||
- Camera processing
|
||||
- Real-time translation
|
||||
- Personalization
|
||||
|
||||
### 4. Domain-Specific Solutions
|
||||
|
||||
**Opportunity:**
|
||||
- **Problem**: General models underperform in specific domains
|
||||
- **Solution**: Specialized models for industries
|
||||
- **Market Size**: $10-20B by 2027
|
||||
- **Key Players**: Industry-specific startups
|
||||
|
||||
**Industries:**
|
||||
- **Healthcare**: Medical diagnosis, drug discovery
|
||||
- **Finance**: Fraud detection, trading algorithms
|
||||
- **Legal**: Contract analysis, legal research
|
||||
- **Education**: Personalized tutoring, content generation
|
||||
- **Customer Service**: Support automation, chatbots
|
||||
|
||||
**Value Propositions:**
|
||||
- Higher accuracy in domain
|
||||
- Regulatory compliance
|
||||
- Custom integrations
|
||||
- Expert knowledge built-in
|
||||
|
||||
**Revenue Models:**
|
||||
- SaaS subscriptions
|
||||
- Per-query pricing
|
||||
- Enterprise licenses
|
||||
- White-label solutions
|
||||
|
||||
### 5. Model Evaluation & Safety Tools
|
||||
|
||||
**Opportunity:**
|
||||
- **Problem**: Hard to evaluate model quality and safety
|
||||
- **Solution**: Comprehensive evaluation frameworks
|
||||
- **Market Size**: $500M-2B by 2027
|
||||
- **Key Players**: OpenAI, Anthropic, startup ecosystem
|
||||
|
||||
**Tools Needed:**
|
||||
- **Evaluation frameworks**: Benchmark suites
|
||||
- **Bias detection**: Identify and measure bias
|
||||
- **Safety testing**: Jailbreak detection, adversarial testing
|
||||
- **Explainability**: Understanding model decisions
|
||||
|
||||
**Market Segments:**
|
||||
- Enterprise model validation
|
||||
- Regulatory compliance
|
||||
- Research institutions
|
||||
- Government agencies
|
||||
|
||||
**Applications:**
|
||||
- Pre-deployment testing
|
||||
- Continuous monitoring
|
||||
- Regulatory reporting
|
||||
- Risk assessment
|
||||
|
||||
### 6. Data & Training Infrastructure
|
||||
|
||||
**Opportunity:**
|
||||
- **Problem**: Data preparation is expensive and time-consuming
|
||||
- **Solution**: Automated data pipelines and quality tools
|
||||
- **Market Size**: $2-5B by 2027
|
||||
- **Key Players**: Scale AI, Labelbox, Label Studio
|
||||
|
||||
**Solutions:**
|
||||
- **Data labeling**: Automated and human-in-the-loop
|
||||
- **Data quality**: Cleaning and validation tools
|
||||
- **Data pipelines**: ETL for ML workflows
|
||||
- **Synthetic data**: Generate training data
|
||||
|
||||
**Market Segments:**
|
||||
- Data labeling services
|
||||
- Quality assurance tools
|
||||
- Data pipeline platforms
|
||||
- Synthetic data generation
|
||||
|
||||
**Value:**
|
||||
- Faster data preparation
|
||||
- Higher quality training data
|
||||
- Reduced costs
|
||||
- Better model performance
|
||||
|
||||
### 7. Cost Optimization & Infrastructure
|
||||
|
||||
**Opportunity:**
|
||||
- **Problem**: Infrastructure costs are prohibitive
|
||||
- **Solution**: Optimized cloud services, cost management
|
||||
- **Market Size**: $5-15B by 2027
|
||||
- **Key Players**: AWS, Google Cloud, Azure, specialized providers
|
||||
|
||||
**Solutions:**
|
||||
- **GPU optimization**: Better utilization
|
||||
- **Model serving**: Efficient inference infrastructure
|
||||
- **Cost monitoring**: Track and optimize spending
|
||||
- **Multi-cloud**: Avoid vendor lock-in
|
||||
|
||||
**Market Segments:**
|
||||
- Cloud providers
|
||||
- Infrastructure optimization
|
||||
- Cost management tools
|
||||
- Managed ML services
|
||||
|
||||
**Value:**
|
||||
- Reduced infrastructure costs
|
||||
- Better performance
|
||||
- Easier scaling
|
||||
- Cost transparency
|
||||
|
||||
### 8. Open Source & Community Models
|
||||
|
||||
**Opportunity:**
|
||||
- **Problem**: Proprietary models lock users in
|
||||
- **Solution**: Open source alternatives
|
||||
- **Market Size**: Growing rapidly
|
||||
- **Key Players**: Hugging Face, Stability AI, Meta
|
||||
|
||||
**Trends:**
|
||||
- **Open source models**: Llama, Mistral, Falcon
|
||||
- **Model sharing**: Hugging Face Hub
|
||||
- **Community contributions**: Faster innovation
|
||||
- **Transparency**: Open weights and training data
|
||||
|
||||
**Market Impact:**
|
||||
- Lower barriers to entry
|
||||
- Faster innovation
|
||||
- More competition
|
||||
- Better accessibility
|
||||
|
||||
**Business Models:**
|
||||
- Open source with premium features
|
||||
- Hosting and infrastructure
|
||||
- Support and consulting
|
||||
- Enterprise editions
|
||||
|
||||
---
|
||||
|
||||
## Technical Solutions
|
||||
|
||||
### Current Solutions Addressing Pain Points
|
||||
|
||||
#### 1. Training Optimization
|
||||
|
||||
**LoRA (Low-Rank Adaptation)**
|
||||
- **Impact**: 10-100x cheaper fine-tuning
|
||||
- **Use Case**: Customizing models for specific tasks
|
||||
- **Adoption**: Widespread in research and industry
|
||||
|
||||
**Quantization**
|
||||
- **Impact**: 4x-8x memory reduction
|
||||
- **Use Case**: Fitting larger models on smaller GPUs
|
||||
- **Adoption**: Growing rapidly
|
||||
|
||||
**Gradient Checkpointing**
|
||||
- **Impact**: 2x memory savings
|
||||
- **Use Case**: Training larger models
|
||||
- **Adoption**: Standard practice
|
||||
|
||||
**Distributed Training**
|
||||
- **Impact**: Faster training, larger models
|
||||
- **Use Case**: Training billion-parameter models
|
||||
- **Adoption**: Required for large models
|
||||
|
||||
#### 2. Inference Optimization
|
||||
|
||||
**KV Caching**
|
||||
- **Impact**: 2-5x speedup
|
||||
- **Use Case**: Autoregressive generation
|
||||
- **Adoption**: Standard in production
|
||||
|
||||
**Quantization**
|
||||
- **Impact**: 4x faster inference
|
||||
- **Use Case**: Production deployment
|
||||
- **Adoption**: Common in production
|
||||
|
||||
**Model Pruning**
|
||||
- **Impact**: 2-4x speedup, smaller models
|
||||
- **Use Case**: Edge deployment
|
||||
- **Adoption**: Growing for edge devices
|
||||
|
||||
**Batch Processing**
|
||||
- **Impact**: Better GPU utilization
|
||||
- **Use Case**: High-throughput scenarios
|
||||
- **Adoption**: Standard practice
|
||||
|
||||
#### 3. Memory Optimization
|
||||
|
||||
**Flash Attention**
|
||||
- **Impact**: 2x memory reduction
|
||||
- **Use Case**: Long sequences
|
||||
- **Adoption**: Standard in new models
|
||||
|
||||
**Gradient Checkpointing**
|
||||
- **Impact**: 2x memory savings
|
||||
- **Use Case**: Training
|
||||
- **Adoption**: Common practice
|
||||
|
||||
**Model Sharding**
|
||||
- **Impact**: Distribute across GPUs
|
||||
- **Use Case**: Large models
|
||||
- **Adoption**: Required for large models
|
||||
|
||||
**Quantization**
|
||||
- **Impact**: 4x-8x memory reduction
|
||||
- **Use Case**: Inference and training
|
||||
- **Adoption**: Increasing rapidly
|
||||
|
||||
---
|
||||
|
||||
## Market Segments
|
||||
|
||||
### 1. Enterprise Software
|
||||
|
||||
**Size**: $10-30B by 2027
|
||||
**Characteristics**:
|
||||
- High willingness to pay
|
||||
- Enterprise features required
|
||||
- Compliance and security critical
|
||||
- Custom integrations needed
|
||||
|
||||
**Key Players**: OpenAI, Anthropic, Google, Microsoft
|
||||
**Opportunities**: Vertical solutions, integrations, compliance
|
||||
|
||||
### 2. Developer Tools & APIs
|
||||
|
||||
**Size**: $5-15B by 2027
|
||||
**Characteristics**:
|
||||
- Developer-friendly APIs
|
||||
- Good documentation
|
||||
- Competitive pricing
|
||||
- Reliability critical
|
||||
|
||||
**Key Players**: OpenAI, Anthropic, Cohere, Hugging Face
|
||||
**Opportunities**: Better APIs, developer experience, pricing
|
||||
|
||||
### 3. Consumer Applications
|
||||
|
||||
**Size**: $5-20B by 2027
|
||||
**Characteristics**:
|
||||
- Price-sensitive
|
||||
- User experience critical
|
||||
- Scale requirements
|
||||
- Privacy concerns
|
||||
|
||||
**Key Players**:
|
||||
- [ChatGPT](https://chat.openai.com) - OpenAI's conversational AI platform
|
||||
- [Claude](https://claude.ai) - Anthropic's AI assistant
|
||||
- [Perplexity](https://www.perplexity.ai) - AI-powered search engine
|
||||
- [Character.AI](https://character.ai) - Conversational AI characters platform
|
||||
|
||||
**Opportunities**: Better UX, lower costs, privacy
|
||||
|
||||
### 4. Research & Academia
|
||||
|
||||
**Size**: $1-3B by 2027
|
||||
**Characteristics**:
|
||||
- Open access preferred
|
||||
- Reproducibility important
|
||||
- Educational pricing
|
||||
- Community support
|
||||
|
||||
**Key Players**: Hugging Face, EleutherAI, Academic institutions
|
||||
**Opportunities**: Open source, educational tools, grants
|
||||
|
||||
### 5. Infrastructure & Cloud
|
||||
|
||||
**Size**: $10-25B by 2027
|
||||
**Characteristics**:
|
||||
- Scale critical
|
||||
- Reliability essential
|
||||
- Cost optimization
|
||||
- Multi-cloud support
|
||||
|
||||
**Key Players**: AWS, Google Cloud, Azure, specialized providers
|
||||
**Opportunities**: Better infrastructure, cost optimization
|
||||
|
||||
---
|
||||
|
||||
## Future Trends
|
||||
|
||||
### 1. Efficiency Improvements
|
||||
|
||||
**Trend**: Continued focus on efficiency
|
||||
- **Smaller models**: Better performance per parameter
|
||||
- **Smarter architectures**: More efficient attention mechanisms
|
||||
- **Hardware optimization**: Specialized chips for LLMs
|
||||
- **Algorithm improvements**: Better training and inference methods
|
||||
|
||||
**Impact**: Lower costs, better accessibility, faster adoption
|
||||
|
||||
### 2. Edge Deployment
|
||||
|
||||
**Trend**: Moving LLMs to edge devices
|
||||
- **Model compression**: Smaller, faster models
|
||||
- **Hardware acceleration**: Specialized mobile chips
|
||||
- **Hybrid approaches**: Cloud + edge combination
|
||||
- **Privacy**: On-device processing
|
||||
|
||||
**Impact**: Better privacy, lower latency, new applications
|
||||
|
||||
### 3. Specialized Models
|
||||
|
||||
**Trend**: Domain-specific models
|
||||
- **Industry focus**: Healthcare, finance, legal, etc.
|
||||
- **Better performance**: Domain expertise built-in
|
||||
- **Regulatory compliance**: Built-in compliance features
|
||||
- **Integration**: Easier integration with existing systems
|
||||
|
||||
**Impact**: Better performance, regulatory compliance, market segmentation
|
||||
|
||||
### 4. Open Source Growth
|
||||
|
||||
**Trend**: Growing open source ecosystem
|
||||
- **More models**: Better open source alternatives
|
||||
- **Community innovation**: Faster development
|
||||
- **Transparency**: Open weights and training data
|
||||
- **Accessibility**: Lower barriers to entry
|
||||
|
||||
**Impact**: More competition, faster innovation, better accessibility
|
||||
|
||||
### 5. Safety & Alignment
|
||||
|
||||
**Trend**: Focus on safety and alignment
|
||||
- **Evaluation frameworks**: Better testing tools
|
||||
- **Safety mechanisms**: Built-in safety features
|
||||
- **Alignment research**: Better understanding of alignment
|
||||
- **Regulation**: Increasing regulatory requirements
|
||||
|
||||
**Impact**: Safer models, regulatory compliance, public trust
|
||||
|
||||
### 6. Multimodal Expansion
|
||||
|
||||
**Trend**: Beyond text to images, audio, video
|
||||
- **Multimodal models**: Text + images + audio
|
||||
- **New applications**: Creative tools, video generation
|
||||
- **Unified models**: Single model for multiple modalities
|
||||
- **Interactions**: Better human-AI interaction
|
||||
|
||||
**Impact**: New applications, larger market, more complexity
|
||||
|
||||
### 7. Personalization
|
||||
|
||||
**Trend**: Highly personalized models
|
||||
- **Fine-tuning**: Easy personalization
|
||||
- **User data**: Learning from user interactions
|
||||
- **Privacy**: Balancing personalization and privacy
|
||||
- **Customization**: User-controlled customization
|
||||
|
||||
**Impact**: Better user experience, privacy challenges, new applications
|
||||
|
||||
### 8. Cost Reduction
|
||||
|
||||
**Trend**: Continued cost reduction
|
||||
- **Efficiency**: Better algorithms and hardware
|
||||
- **Competition**: More providers, lower prices
|
||||
- **Optimization**: Better resource utilization
|
||||
- **Accessibility**: Lower costs enable more use cases
|
||||
|
||||
**Impact**: More adoption, new applications, democratization
|
||||
|
||||
---
|
||||
|
||||
## Summary
|
||||
|
||||
### Key Pain Points
|
||||
|
||||
1. **Training Costs**: Extremely expensive, limiting access
|
||||
2. **Inference Speed**: Too slow for many applications
|
||||
3. **Memory Usage**: Too large for most devices
|
||||
4. **Energy Consumption**: Environmental concerns
|
||||
5. **Data Dependency**: Need massive, high-quality data
|
||||
6. **Hallucination**: Reliability and trust issues
|
||||
7. **Fine-tuning Complexity**: Difficult to customize
|
||||
8. **Scalability**: Infrastructure challenges
|
||||
|
||||
### Major Opportunities
|
||||
|
||||
1. **Efficient Training**: LoRA, quantization, optimization tools
|
||||
2. **Inference Optimization**: KV caching, acceleration, compression
|
||||
3. **Edge Deployment**: Mobile and IoT applications
|
||||
4. **Domain-Specific Solutions**: Industry verticals
|
||||
5. **Evaluation Tools**: Safety and quality frameworks
|
||||
6. **Data Infrastructure**: Automated pipelines and quality tools
|
||||
7. **Cost Optimization**: Infrastructure and cloud services
|
||||
8. **Open Source**: Community-driven innovation
|
||||
|
||||
### Market Size
|
||||
|
||||
**Total Addressable Market**: $50-100B+ by 2027
|
||||
- Enterprise Software: $10-30B
|
||||
- Developer Tools: $5-15B
|
||||
- Consumer Applications: $5-20B
|
||||
- Infrastructure: $10-25B
|
||||
- Research & Academia: $1-3B
|
||||
- Specialized Solutions: $5-10B
|
||||
|
||||
### Competitive Landscape
|
||||
|
||||
**Established Players**: OpenAI, Google, Anthropic, Microsoft
|
||||
**Rising Stars**: Hugging Face, Cohere, Stability AI
|
||||
**Infrastructure**: AWS, Google Cloud, Azure, NVIDIA
|
||||
**Open Source**: Meta, EleutherAI, Community
|
||||
|
||||
### Success Factors
|
||||
|
||||
- **Technical Excellence**: Best performance and efficiency
|
||||
- **Developer Experience**: Easy to use and integrate
|
||||
- **Cost Effectiveness**: Competitive pricing
|
||||
- **Reliability**: Consistent performance
|
||||
- **Innovation**: Continuous improvement
|
||||
- **Community**: Strong ecosystem support
|
||||
|
||||
---
|
||||
|
||||
## Conclusion
|
||||
|
||||
The LLM market presents significant challenges but also enormous opportunities. The main pain points—cost, speed, memory, and reliability—create clear market opportunities for companies that can solve these problems.
|
||||
|
||||
**Key Takeaways:**
|
||||
|
||||
1. **Cost is the primary barrier**: Solutions that reduce training and inference costs will have significant market value
|
||||
2. **Speed matters**: Real-time applications require optimization
|
||||
3. **Efficiency is critical**: Better algorithms and hardware unlock new use cases
|
||||
4. **Specialization wins**: Domain-specific solutions better than general models
|
||||
5. **Open source is growing**: Community-driven innovation is accelerating
|
||||
6. **Infrastructure is key**: Better infrastructure enables adoption
|
||||
|
||||
The market is still early, with huge growth potential. Companies focusing on solving real pain points while building sustainable business models will capture significant value in this rapidly growing market.
|
||||
|
||||
---
|
||||
|
||||
*This document provides a comprehensive overview of the current state of LLMs, their challenges, and the opportunities they present. The market is evolving rapidly, with new solutions and opportunities emerging continuously.*
|
||||
656
docs/REPOSITORY_DOWNLOAD_GUIDE.md
Normal file
656
docs/REPOSITORY_DOWNLOAD_GUIDE.md
Normal file
@@ -0,0 +1,656 @@
|
||||
# Repository Download Guide
|
||||
|
||||
This guide explains how to automatically download GitHub repositories with open licenses for code training using the repository downloader scripts.
|
||||
|
||||
## Overview
|
||||
|
||||
The repository downloader allows you to automatically find and clone GitHub repositories based on:
|
||||
- **Categories**: Neovim configs, Lua repos, Bash scripts, Zsh configs, Python repos, ethical hacking tools, security tools, and all open-license repos
|
||||
- **Languages**: Python, JavaScript, Go, Rust, and 15+ more
|
||||
- **Licenses**: MIT, Apache, BSD, GPL, and other open source licenses
|
||||
- **Quality**: Filter by minimum stars (popularity)
|
||||
- **Size Limits**: Automatic stopping when reaching storage limits (default: 1 TB)
|
||||
|
||||
## Scripts
|
||||
|
||||
There are two scripts available:
|
||||
|
||||
1. **`download_all_repos.py`** - Convenience script to download all common categories at once
|
||||
2. **`download_repos.py`** - Full-featured script with all options and flexibility
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Download All Categories (Recommended)
|
||||
|
||||
The easiest way to download all repository categories:
|
||||
|
||||
```bash
|
||||
python3 download_all_repos.py
|
||||
```
|
||||
|
||||
This will download:
|
||||
- 📦 Neovim configurations and plugins
|
||||
- 📦 Lua programming repositories
|
||||
- 📦 Bash/shell script repositories
|
||||
- 📦 Zsh configuration and plugins
|
||||
- 📦 Python programming repositories
|
||||
- 📦 Ethical hacking and cybersecurity tools
|
||||
|
||||
**Default settings:**
|
||||
- Max repos per category: 50
|
||||
- Min stars: 100
|
||||
- Output directory: `data/repos`
|
||||
- Size limit: 1 TB (1024 GB)
|
||||
- Shallow clones (faster, less disk space)
|
||||
|
||||
### Download Specific Categories
|
||||
|
||||
```bash
|
||||
python3 download_repos.py --categories nvim lua bash zsh python hacking --max-repos 50
|
||||
```
|
||||
|
||||
### Download All Open-License Repos
|
||||
|
||||
Download repositories with any open license (any language):
|
||||
|
||||
```bash
|
||||
python3 download_repos.py --categories all-open --max-repos 1000 --max-size 1024.0
|
||||
```
|
||||
|
||||
### Download by Language
|
||||
|
||||
```bash
|
||||
python3 download_repos.py --language python --max-repos 100
|
||||
```
|
||||
|
||||
## Installation
|
||||
|
||||
No additional dependencies required! The script uses:
|
||||
- Python standard library (`urllib`, `json`, `subprocess`)
|
||||
- `tqdm` (already in requirements.txt)
|
||||
- `git` (should be installed on your system)
|
||||
|
||||
## Available Categories
|
||||
|
||||
### Neovim (`nvim`)
|
||||
Neovim configuration files and plugins written in Lua.
|
||||
|
||||
```bash
|
||||
python3 download_repos.py --categories nvim --max-repos 100
|
||||
```
|
||||
|
||||
**What it searches for:**
|
||||
- `neovim OR nvim-config OR neovim-config`
|
||||
- MIT licensed repositories (default)
|
||||
- 100+ stars minimum (default)
|
||||
|
||||
### Lua (`lua`)
|
||||
Lua programming language repositories.
|
||||
|
||||
```bash
|
||||
python3 download_repos.py --categories lua --max-repos 50
|
||||
```
|
||||
|
||||
**What it searches for:**
|
||||
- Language: Lua
|
||||
- MIT licensed repositories (default)
|
||||
- 100+ stars minimum (default)
|
||||
|
||||
### Bash (`bash`)
|
||||
Bash and shell script repositories.
|
||||
|
||||
```bash
|
||||
python3 download_repos.py --categories bash --max-repos 50
|
||||
```
|
||||
|
||||
**What it searches for:**
|
||||
- Language: Shell
|
||||
- MIT licensed repositories (default)
|
||||
- 100+ stars minimum (default)
|
||||
|
||||
### Zsh (`zsh`)
|
||||
Zsh configuration files and plugins (Oh My Zsh, etc.).
|
||||
|
||||
```bash
|
||||
python3 download_repos.py --categories zsh --max-repos 50
|
||||
```
|
||||
|
||||
**What it searches for:**
|
||||
- `zsh-config OR oh-my-zsh OR zsh-plugin`
|
||||
- MIT licensed repositories (default)
|
||||
- 100+ stars minimum (default)
|
||||
|
||||
### Python (`python`)
|
||||
Python programming language repositories.
|
||||
|
||||
```bash
|
||||
python3 download_repos.py --categories python --max-repos 100
|
||||
```
|
||||
|
||||
**What it searches for:**
|
||||
- Language: Python
|
||||
- MIT licensed repositories (default)
|
||||
- 100+ stars minimum (default)
|
||||
|
||||
### Ethical Hacking (`hacking`)
|
||||
Ethical hacking and cybersecurity tools.
|
||||
|
||||
```bash
|
||||
python3 download_repos.py --categories hacking --max-repos 100
|
||||
```
|
||||
|
||||
**What it searches for:**
|
||||
- `ethical-hacking OR cybersecurity OR penetration-testing OR security-tools OR red-team`
|
||||
- MIT licensed repositories (default)
|
||||
- 100+ stars minimum (default)
|
||||
|
||||
### Security (`security`)
|
||||
General security and cybersecurity repositories.
|
||||
|
||||
```bash
|
||||
python3 download_repos.py --categories security --max-repos 50
|
||||
```
|
||||
|
||||
**What it searches for:**
|
||||
- `security-tools OR cybersecurity OR penetration-testing OR red-team OR blue-team`
|
||||
- MIT licensed repositories (default)
|
||||
- 100+ stars minimum (default)
|
||||
|
||||
### All Open Licenses (`all-open`)
|
||||
All repositories with open licenses, any language. This is useful for downloading a diverse set of repositories.
|
||||
|
||||
```bash
|
||||
python3 download_repos.py --categories all-open --max-repos 1000 --max-size 1024.0
|
||||
```
|
||||
|
||||
**What it searches for:**
|
||||
- Any open-license repository (no language filter)
|
||||
- No specific license filter (searches all open licenses)
|
||||
- 100+ stars minimum (default)
|
||||
|
||||
**Note:** This category searches broadly and may return repositories with various licenses. You can still specify `--license` to filter to a specific license type.
|
||||
|
||||
## Command-Line Options
|
||||
|
||||
### `download_repos.py` Options
|
||||
|
||||
```bash
|
||||
python3 download_repos.py [OPTIONS]
|
||||
```
|
||||
|
||||
**Options:**
|
||||
|
||||
- `--output DIR` - Output directory (default: `data/repos`)
|
||||
- `--categories CAT1 CAT2 ...` - Categories to download: `nvim`, `lua`, `bash`, `zsh`, `python`, `hacking`, `security`, `all-open`
|
||||
- `--language LANG` - Single language to filter by
|
||||
- `--languages LANG1 LANG2 ...` - Multiple languages to download
|
||||
- `--license LICENSE` - License type (default: `mit`)
|
||||
- `--min-stars N` - Minimum stars (default: 100)
|
||||
- `--max-repos N` - Maximum repos per category/language (default: 50)
|
||||
- `--max-size N` - Maximum total size in GB (stops downloading when reached, e.g., `1024.0` for 1 TB)
|
||||
- `--full-clone` - Do full clone instead of shallow (slower but includes full history)
|
||||
|
||||
### `download_all_repos.py` Options
|
||||
|
||||
```bash
|
||||
python3 download_all_repos.py [OPTIONS]
|
||||
```
|
||||
|
||||
**Options:**
|
||||
|
||||
- `--max-repos N` - Maximum repos per category (default: 50)
|
||||
- `--min-stars N` - Minimum stars (default: 100)
|
||||
- `--output DIR` - Output directory (default: `data/repos`)
|
||||
- `--max-size N` - Maximum total size in GB (default: 1024.0 = 1 TB)
|
||||
- `--full-clone` - Do full clone instead of shallow
|
||||
|
||||
**Example:**
|
||||
```bash
|
||||
python3 download_all_repos.py --max-repos 100 --min-stars 200 --max-size 2048.0
|
||||
```
|
||||
|
||||
### Available Licenses
|
||||
|
||||
- `mit` (default)
|
||||
- `apache-2.0`
|
||||
- `bsd-3-clause`
|
||||
- `bsd-2-clause`
|
||||
- `isc`
|
||||
- `unlicense`
|
||||
- `mpl-2.0`
|
||||
- `lgpl-2.1`
|
||||
- `lgpl-3.0`
|
||||
- `gpl-2.0`
|
||||
- `gpl-3.0`
|
||||
|
||||
### Available Languages
|
||||
|
||||
- `python`
|
||||
- `javascript`
|
||||
- `typescript`
|
||||
- `java`
|
||||
- `cpp`
|
||||
- `c`
|
||||
- `go`
|
||||
- `rust`
|
||||
- `ruby`
|
||||
- `php`
|
||||
- `swift`
|
||||
- `kotlin`
|
||||
- `scala`
|
||||
- `r`
|
||||
- `sql`
|
||||
- `lua`
|
||||
- `shell` (for bash/shell scripts)
|
||||
|
||||
## Examples
|
||||
|
||||
### Example 1: Download All Categories (Simple)
|
||||
|
||||
```bash
|
||||
python3 download_all_repos.py
|
||||
```
|
||||
|
||||
Downloads all categories (nvim, lua, bash, zsh, python, hacking) with default settings and 1 TB size limit.
|
||||
|
||||
### Example 2: Download All Categories with Custom Settings
|
||||
|
||||
```bash
|
||||
python3 download_all_repos.py --max-repos 100 --min-stars 200 --max-size 2048.0
|
||||
```
|
||||
|
||||
Downloads all categories with:
|
||||
- 100 repos per category
|
||||
- Minimum 200 stars
|
||||
- 2 TB size limit
|
||||
|
||||
### Example 3: Download Specific Categories
|
||||
|
||||
```bash
|
||||
python3 download_repos.py --categories nvim lua bash zsh python hacking --max-repos 50
|
||||
```
|
||||
|
||||
Downloads specific categories with 50 repos each.
|
||||
|
||||
### Example 4: Download All Open-License Repos with Size Limit
|
||||
|
||||
```bash
|
||||
python3 download_repos.py --categories all-open --max-repos 1000 --max-size 1024.0
|
||||
```
|
||||
|
||||
Downloads up to 1000 repositories with any open license, stopping at 1 TB.
|
||||
|
||||
### Example 5: Download High-Quality Repos
|
||||
|
||||
```bash
|
||||
python3 download_repos.py --categories nvim lua bash zsh python hacking --min-stars 1000 --max-repos 20
|
||||
```
|
||||
|
||||
Downloads only highly popular repositories (1000+ stars).
|
||||
|
||||
### Example 6: Download Multiple Languages
|
||||
|
||||
```bash
|
||||
python3 download_repos.py --languages python javascript go rust --max-repos 50
|
||||
```
|
||||
|
||||
Downloads repositories in multiple programming languages.
|
||||
|
||||
### Example 7: Download with Apache License
|
||||
|
||||
```bash
|
||||
python3 download_repos.py --categories nvim --license apache-2.0 --max-repos 50
|
||||
```
|
||||
|
||||
Downloads Neovim repos with Apache 2.0 license.
|
||||
|
||||
### Example 8: Custom Output Directory
|
||||
|
||||
```bash
|
||||
python3 download_repos.py --categories nvim lua bash zsh python hacking --output /path/to/repos
|
||||
```
|
||||
|
||||
Saves repositories to a custom directory.
|
||||
|
||||
### Example 9: Full Clone (with History)
|
||||
|
||||
```bash
|
||||
python3 download_repos.py --categories nvim --full-clone --max-repos 10
|
||||
```
|
||||
|
||||
Does full clone including full git history (slower but more complete).
|
||||
|
||||
### Example 10: Size-Limited Download
|
||||
|
||||
```bash
|
||||
python3 download_repos.py --categories all-open --max-repos 2000 --max-size 512.0
|
||||
```
|
||||
|
||||
Downloads repositories but stops when reaching 512 GB (0.5 TB).
|
||||
|
||||
## Progress Tracking
|
||||
|
||||
The scripts include visual progress bars showing:
|
||||
|
||||
- **Category progress**: Overall progress across all categories
|
||||
- **Repository progress**: Progress for each category
|
||||
- **Real-time statistics**: Current repo, stars, language, cloned/failed counts
|
||||
- **Size tracking**: Current total size and size limit (when `--max-size` is used)
|
||||
|
||||
**Example output:**
|
||||
|
||||
```text
|
||||
📊 Current directory size: 45.23 GB
|
||||
📊 Size limit: 1024.00 GB
|
||||
📦 Processing 6 categories...
|
||||
Category: nvim: 100%|████████████| 6/6 [15:23<00:00, Size=156.78 GB, Total Cloned=300, Total Failed=2]
|
||||
Cloning nvim: 45%|████████████████▌ | 23/50 [02:15<03:45, Current=awesome-nvim, Stars=5.2k, Lang=Lua, Cloned=22, Failed=1, Size=12.45 GB]
|
||||
```
|
||||
|
||||
**Size limit reached:**
|
||||
|
||||
When the size limit is reached, the script will stop downloading and show:
|
||||
|
||||
```text
|
||||
⚠️ Size limit reached: 1024.00 GB >= 1024.00 GB
|
||||
Stopping all downloads.
|
||||
```
|
||||
|
||||
## GitHub API Rate Limits
|
||||
|
||||
GitHub API has rate limits:
|
||||
- **Unauthenticated**: 60 requests/hour
|
||||
- **Authenticated**: 5,000 requests/hour
|
||||
|
||||
### Using a GitHub Token
|
||||
|
||||
To increase rate limits, set a GitHub Personal Access Token:
|
||||
|
||||
```bash
|
||||
export GITHUB_TOKEN=your_token_here
|
||||
python3 download_repos.py --categories nvim lua bash hacking
|
||||
```
|
||||
|
||||
**How to create a token:**
|
||||
1. Go to GitHub Settings → Developer settings → Personal access tokens
|
||||
2. Generate new token (classic)
|
||||
3. Select scope: `public_repo` (read-only is enough)
|
||||
4. Copy token and set as environment variable
|
||||
|
||||
## Size Limits
|
||||
|
||||
The repository downloader includes automatic size limit checking to prevent running out of disk space.
|
||||
|
||||
### How It Works
|
||||
|
||||
- **Default limit**: 1 TB (1024 GB) for `download_all_repos.py`
|
||||
- **Customizable**: Use `--max-size` to set any limit
|
||||
- **Real-time tracking**: Size is checked before each repository clone
|
||||
- **Automatic stopping**: Downloads stop when limit is reached
|
||||
- **Progress display**: Current size shown in progress bars
|
||||
|
||||
### Setting Size Limits
|
||||
|
||||
**With `download_all_repos.py`:**
|
||||
```bash
|
||||
# Default 1 TB
|
||||
python3 download_all_repos.py
|
||||
|
||||
# Custom limit (2 TB)
|
||||
python3 download_all_repos.py --max-size 2048.0
|
||||
|
||||
# Smaller limit (500 GB)
|
||||
python3 download_all_repos.py --max-size 512.0
|
||||
```
|
||||
|
||||
**With `download_repos.py`:**
|
||||
```bash
|
||||
# No limit (downloads until max-repos reached)
|
||||
python3 download_repos.py --categories nvim --max-repos 100
|
||||
|
||||
# With 1 TB limit
|
||||
python3 download_repos.py --categories nvim --max-repos 1000 --max-size 1024.0
|
||||
```
|
||||
|
||||
### Size Calculation
|
||||
|
||||
The script calculates total size by:
|
||||
- Scanning all files in the output directory (`data/repos` by default)
|
||||
- Summing file sizes recursively
|
||||
- Checking before each new repository clone
|
||||
- Displaying human-readable sizes (B, KB, MB, GB, TB)
|
||||
|
||||
**Note:** Size checking happens before cloning, so the actual size may be slightly less than the limit when stopping.
|
||||
|
||||
## Cache and Resuming
|
||||
|
||||
The scripts automatically:
|
||||
|
||||
- **Skips existing repos**: If a repository already exists, it's skipped (no re-download)
|
||||
- **Resumes downloads**: You can run the script multiple times safely
|
||||
- **Progress tracking**: Shows what's already downloaded
|
||||
- **Size awareness**: Accounts for existing repositories when checking size limits
|
||||
|
||||
After downloading repositories, they're automatically processed during training:
|
||||
|
||||
```bash
|
||||
# Download repos
|
||||
python3 download_all_repos.py
|
||||
|
||||
# Train with all data (text + code)
|
||||
python3 train.py --data data/ --config config.json --device cuda
|
||||
```
|
||||
|
||||
The training script will:
|
||||
1. Process all your text data (Wiki, Books, Amazon reviews, etc.)
|
||||
2. Process all code repositories
|
||||
3. Combine everything into training data
|
||||
|
||||
## Supported File Types
|
||||
|
||||
The data processor automatically handles code files from repositories:
|
||||
|
||||
- **Text files**: `.txt`, `.md`, `.rst`, `.log`, `.csv`, `.json`, `.jsonl`, `.xml`, `.html`, `.htm`
|
||||
- **Code files**: `.py`, `.js`, `.ts`, `.java`, `.cpp`, `.c`, `.go`, `.rs`, `.rb`, `.php`, `.swift`, `.lua`, `.sh`, and 30+ more
|
||||
- **PDF files**: `.pdf` (if pdfplumber is installed)
|
||||
- **Images**: `.png`, `.jpg`, etc. (if OCR is set up)
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Rate Limit Exceeded
|
||||
|
||||
**Error:** `Rate limit exceeded`
|
||||
|
||||
**Solution:**
|
||||
1. Wait a few minutes and try again
|
||||
2. Use a GitHub token: `export GITHUB_TOKEN=your_token`
|
||||
3. Reduce `--max-repos` to download fewer repos per run
|
||||
|
||||
### Repository Clone Fails
|
||||
|
||||
**Error:** `Failed to clone repository`
|
||||
|
||||
**Possible causes:**
|
||||
- Repository was deleted or made private
|
||||
- Network issues
|
||||
- Repository is too large (timeout)
|
||||
|
||||
**Solution:**
|
||||
- The script continues with other repos
|
||||
- Failed repos are counted and reported at the end
|
||||
- You can re-run the script to retry failed repos
|
||||
|
||||
### No Repositories Found
|
||||
|
||||
**Error:** `No repositories found`
|
||||
|
||||
**Possible causes:**
|
||||
- Search query too restrictive
|
||||
- License filter too narrow
|
||||
- Minimum stars too high
|
||||
|
||||
**Solution:**
|
||||
- Lower `--min-stars` threshold
|
||||
- Try different `--license` options
|
||||
- Check if category name is correct
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Start Small
|
||||
|
||||
Test with a small number first:
|
||||
|
||||
```bash
|
||||
python3 download_repos.py --categories nvim --max-repos 10
|
||||
```
|
||||
|
||||
### 2. Use Size Limits
|
||||
|
||||
Always set a size limit to prevent running out of disk space:
|
||||
|
||||
```bash
|
||||
# Recommended: 1 TB limit
|
||||
python3 download_all_repos.py --max-size 1024.0
|
||||
|
||||
# Or custom limit based on available space
|
||||
python3 download_repos.py --categories all-open --max-size 512.0
|
||||
```
|
||||
|
||||
### 3. Use Shallow Clones
|
||||
|
||||
Shallow clones are faster and use less disk space:
|
||||
|
||||
```bash
|
||||
# Default (shallow clone)
|
||||
python3 download_repos.py --categories nvim
|
||||
|
||||
# Full clone (only if you need history)
|
||||
python3 download_repos.py --categories nvim --full-clone
|
||||
```
|
||||
|
||||
### 4. Filter by Quality
|
||||
|
||||
Use `--min-stars` to get quality repositories:
|
||||
|
||||
```bash
|
||||
python3 download_repos.py --categories nvim --min-stars 500 --max-repos 50
|
||||
```
|
||||
|
||||
### 5. Use GitHub Token
|
||||
|
||||
For large downloads, use a GitHub token:
|
||||
|
||||
```bash
|
||||
export GITHUB_TOKEN=your_token_here
|
||||
python3 download_all_repos.py --max-repos 100
|
||||
```
|
||||
|
||||
### 6. Monitor Disk Space
|
||||
|
||||
Check available disk space before starting:
|
||||
|
||||
```bash
|
||||
df -h data/repos
|
||||
```
|
||||
|
||||
### 7. Use `all-open` Category Wisely
|
||||
|
||||
The `all-open` category downloads broadly. Consider:
|
||||
- Setting a reasonable `--max-repos` limit
|
||||
- Using `--min-stars` to filter quality
|
||||
- Setting `--max-size` to prevent excessive downloads
|
||||
|
||||
```bash
|
||||
python3 download_repos.py --categories all-open --max-repos 500 --min-stars 200 --max-size 1024.0
|
||||
```
|
||||
|
||||
## Storage Considerations
|
||||
|
||||
### Size Limits and Disk Space Management
|
||||
|
||||
- **Default**: 1 TB (1024 GB) for `download_all_repos.py`
|
||||
- **Recommended**: Set based on available disk space
|
||||
- **Monitoring**: Script shows current size vs limit in progress bars
|
||||
|
||||
### Shallow vs Full Clones
|
||||
|
||||
**Shallow clones (default):**
|
||||
- Faster download
|
||||
- Less disk space (~10-50% of full clone)
|
||||
- No git history
|
||||
- Good for training data
|
||||
|
||||
**Full clones:**
|
||||
- Slower download
|
||||
- More disk space (includes full history)
|
||||
- Includes full git history
|
||||
- Useful if you need version history
|
||||
|
||||
**Typical sizes (shallow clones):**
|
||||
- Small repo: 1-10 MB
|
||||
- Medium repo: 10-100 MB
|
||||
- Large repo: 100 MB - 1 GB
|
||||
- Very large repo: 1-10 GB+
|
||||
|
||||
**Example:** Downloading 300 repositories with shallow clones typically uses 5-30 GB, depending on repository sizes.
|
||||
|
||||
### Estimating Storage Needs
|
||||
|
||||
To estimate how many repositories you can download:
|
||||
|
||||
1. **Check current size:**
|
||||
```bash
|
||||
du -sh data/repos
|
||||
```
|
||||
|
||||
2. **Calculate average repo size:**
|
||||
- Small repos: ~5 MB average
|
||||
- Medium repos: ~50 MB average
|
||||
- Large repos: ~500 MB average
|
||||
|
||||
3. **Estimate:**
|
||||
- 100 small repos: ~500 MB
|
||||
- 100 medium repos: ~5 GB
|
||||
- 100 large repos: ~50 GB
|
||||
- 1000 mixed repos: ~50-200 GB
|
||||
|
||||
4. **Set appropriate limit:**
|
||||
```bash
|
||||
# For 1 TB available space, use 900 GB limit (leave buffer)
|
||||
python3 download_all_repos.py --max-size 900.0
|
||||
```
|
||||
|
||||
## Summary
|
||||
|
||||
The repository downloader makes it easy to:
|
||||
- ✅ Automatically find high-quality open-source repositories
|
||||
- ✅ Filter by category, language, license, and popularity
|
||||
- ✅ Download with progress tracking and size monitoring
|
||||
- ✅ Set size limits to prevent running out of disk space
|
||||
- ✅ Integrate seamlessly with training pipeline
|
||||
- ✅ Resume interrupted downloads
|
||||
|
||||
**Available categories:**
|
||||
- `nvim` - Neovim configurations and plugins
|
||||
- `lua` - Lua programming repositories
|
||||
- `bash` - Bash/shell script repositories
|
||||
- `zsh` - Zsh configuration and plugins
|
||||
- `python` - Python programming repositories
|
||||
- `hacking` - Ethical hacking and cybersecurity tools
|
||||
- `security` - Security and cybersecurity repositories
|
||||
- `all-open` - All repositories with open licenses (any language)
|
||||
|
||||
**Quick commands to get started:**
|
||||
|
||||
```bash
|
||||
# Download all categories with 1 TB limit (recommended)
|
||||
python3 download_all_repos.py
|
||||
|
||||
# Download specific categories
|
||||
python3 download_repos.py --categories nvim lua bash zsh python hacking --max-repos 50
|
||||
|
||||
# Download all open-license repos with size limit
|
||||
python3 download_repos.py --categories all-open --max-repos 1000 --max-size 1024.0
|
||||
```
|
||||
|
||||
This downloads repositories and prepares them for training!
|
||||
143
docs/RETRAINING_GUIDE.md
Normal file
143
docs/RETRAINING_GUIDE.md
Normal file
@@ -0,0 +1,143 @@
|
||||
# Retraining Guide for Better Model Performance
|
||||
|
||||
## Current Issues
|
||||
|
||||
Your model only trained for **40 global steps** across 10 epochs, which means:
|
||||
- Very little training data (~4 batches per epoch)
|
||||
- Model hasn't learned language patterns
|
||||
- Model just repeats input and stops
|
||||
|
||||
## Retraining Recommendations
|
||||
|
||||
### 1. **Increase Training Data**
|
||||
|
||||
The model needs much more data. Check your current data:
|
||||
|
||||
```bash
|
||||
# Check how much data you have
|
||||
wc -l data/*.txt
|
||||
```
|
||||
|
||||
**Recommendations:**
|
||||
- **Minimum**: 10,000+ text samples
|
||||
- **Good**: 100,000+ text samples
|
||||
- **Better**: 1,000,000+ text samples
|
||||
|
||||
### 2. **Update Training Configuration**
|
||||
|
||||
Edit `config.json` for better training:
|
||||
|
||||
```json
|
||||
{
|
||||
"training": {
|
||||
"batch_size": 32,
|
||||
"max_epochs": 50, // Increase from 10 to 50+
|
||||
"learning_rate": 1e-4,
|
||||
"weight_decay": 0.01,
|
||||
"warmup_steps": 1000,
|
||||
"max_grad_norm": 1.0,
|
||||
"gradient_accumulation_steps": 4, // Increase to simulate larger batches
|
||||
"use_amp": true,
|
||||
"save_dir": "./checkpoints",
|
||||
"log_interval": 10, // More frequent logging
|
||||
"eval_interval": 500 // More frequent evaluation
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 3. **Add Validation Set**
|
||||
|
||||
Split your data for validation:
|
||||
|
||||
```python
|
||||
# In train.py, add validation split
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
train_texts, val_texts = train_test_split(texts, test_size=0.1, random_state=42)
|
||||
```
|
||||
|
||||
### 4. **Improve Training Data Quality**
|
||||
|
||||
Ensure your training data:
|
||||
- ✅ Contains complete sentences/paragraphs
|
||||
- ✅ Has diverse topics and styles
|
||||
- ✅ Doesn't have excessive padding
|
||||
- ✅ Uses proper text formatting
|
||||
|
||||
### 5. **Monitor Training**
|
||||
|
||||
Watch for:
|
||||
- **Loss decreasing**: Should trend downward
|
||||
- **Perplexity**: Should decrease (lower is better)
|
||||
- **Generation quality**: Test periodically during training
|
||||
|
||||
### 6. **Training Command**
|
||||
|
||||
```bash
|
||||
# Train with more data
|
||||
python3 train.py \
|
||||
--data data/your_training_data.txt \
|
||||
--config config.json \
|
||||
--output ./checkpoints \
|
||||
--device cpu # or cuda/mps
|
||||
```
|
||||
|
||||
### 7. **Check Training Progress**
|
||||
|
||||
During training, you should see:
|
||||
```
|
||||
Epoch 1: Train Loss = 8.5 → Epoch 10: Train Loss = 6.0 → Epoch 50: Train Loss = 3.5
|
||||
```
|
||||
|
||||
If loss stops decreasing, the model has converged.
|
||||
|
||||
### 8. **Early Stopping**
|
||||
|
||||
Consider adding early stopping if validation loss plateaus:
|
||||
- Stop if validation loss doesn't improve for 5 epochs
|
||||
- Save the best model based on validation loss
|
||||
|
||||
### 9. **Test During Training**
|
||||
|
||||
After each epoch, test generation:
|
||||
|
||||
```bash
|
||||
python3 inference.py \
|
||||
--checkpoint checkpoints/checkpoint_epoch_X.pt \
|
||||
--prompt "The future of" \
|
||||
--optimized
|
||||
```
|
||||
|
||||
Good training should show:
|
||||
- ✅ Model generates coherent text
|
||||
- ✅ Model continues beyond input prompt
|
||||
- ✅ Model doesn't immediately generate padding tokens
|
||||
|
||||
## Quick Start Retraining
|
||||
|
||||
1. **Get more training data** (most important!)
|
||||
2. **Update config.json** with more epochs
|
||||
3. **Start training**:
|
||||
```bash
|
||||
python3 train.py --data data/your_data.txt --config config.json
|
||||
```
|
||||
4. **Monitor loss** - should decrease over time
|
||||
5. **Test periodically** - check if generation improves
|
||||
|
||||
## Expected Results
|
||||
|
||||
After proper training:
|
||||
- Loss should decrease from ~8-10 to ~2-4
|
||||
- Perplexity should decrease from ~3000 to ~10-50
|
||||
- Model should generate 50+ tokens before stopping
|
||||
- Generated text should be coherent and diverse
|
||||
|
||||
## Next Steps
|
||||
|
||||
1. ✅ Early stopping is now fixed (prevents padding tokens)
|
||||
2. ⏳ **Retrain with more data and epochs**
|
||||
3. ⏳ Monitor training metrics
|
||||
4. ⏳ Test generation quality during training
|
||||
|
||||
Good luck with retraining! 🚀
|
||||
|
||||
618
docs/SCHEDULING_EXPLAINED.md
Normal file
618
docs/SCHEDULING_EXPLAINED.md
Normal file
@@ -0,0 +1,618 @@
|
||||
# What is Scheduling? Step-by-Step Explanation
|
||||
|
||||
Complete step-by-step explanation of learning rate scheduling: how scheduling adjusts learning rates during training to improve convergence.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
1. [What is Scheduling?](#81-what-is-scheduling)
|
||||
2. [Why Do We Need Scheduling?](#82-why-do-we-need-scheduling)
|
||||
3. [Fixed Learning Rate](#83-fixed-learning-rate)
|
||||
4. [Cosine Annealing](#84-cosine-annealing)
|
||||
5. [Other Scheduling Strategies](#85-other-scheduling-strategies)
|
||||
6. [Why Scheduling Matters](#86-why-scheduling-matters)
|
||||
7. [Complete Mathematical Formulation](#87-complete-mathematical-formulation)
|
||||
8. [Exercise: Schedule Calculation](#88-exercise-schedule-calculation)
|
||||
9. [Key Takeaways](#89-key-takeaways)
|
||||
|
||||
---
|
||||
|
||||
## 8.1 What is Scheduling?
|
||||
|
||||
### Simple Definition
|
||||
|
||||
**Scheduling** (learning rate scheduling) is the process of adjusting the learning rate during training to improve convergence and final model performance.
|
||||
|
||||
### Visual Analogy
|
||||
|
||||
**Think of scheduling like adjusting speed while driving:**
|
||||
|
||||
```
|
||||
Fixed Learning Rate:
|
||||
┌──────────────────────────┐
|
||||
│ Speed: 60 mph (constant) │
|
||||
└──────────────────────────┘
|
||||
→ Hard to stop precisely!
|
||||
|
||||
Scheduled Learning Rate:
|
||||
┌──────────────────────────┐
|
||||
│ Speed: 60 → 40 → 20 → 10 │
|
||||
└──────────────────────────┘
|
||||
→ Smooth deceleration!
|
||||
```
|
||||
|
||||
**Scheduling adjusts speed (learning rate) as you approach the destination (convergence)!**
|
||||
|
||||
### What Scheduling Does
|
||||
|
||||
**Scheduling:**
|
||||
1. **Starts** with higher learning rate (fast learning)
|
||||
2. **Gradually reduces** learning rate (precise fine-tuning)
|
||||
3. **Converges** to optimal solution
|
||||
|
||||
**Result:** Better convergence and performance!
|
||||
|
||||
---
|
||||
|
||||
## 8.2 Why Do We Need Scheduling?
|
||||
|
||||
### The Problem with Fixed Learning Rate
|
||||
|
||||
**High Learning Rate:**
|
||||
```
|
||||
Learning Rate: 0.001 (constant)
|
||||
→ Fast initial learning ✓
|
||||
→ But overshoots minimum ✗
|
||||
→ Bounces around ✗
|
||||
→ Poor convergence ✗
|
||||
```
|
||||
|
||||
**Low Learning Rate:**
|
||||
```
|
||||
Learning Rate: 0.0001 (constant)
|
||||
→ Stable convergence ✓
|
||||
→ But very slow learning ✗
|
||||
→ Takes forever to converge ✗
|
||||
```
|
||||
|
||||
**Can't have both!**
|
||||
|
||||
### The Solution: Scheduling
|
||||
|
||||
**Adaptive Learning Rate:**
|
||||
```
|
||||
Start: 0.001 (fast learning)
|
||||
Middle: 0.0005 (moderate)
|
||||
End: 0.0001 (fine-tuning)
|
||||
→ Fast initial learning ✓
|
||||
→ Stable convergence ✓
|
||||
→ Best of both worlds!
|
||||
```
|
||||
|
||||
### Benefits of Scheduling
|
||||
|
||||
**1. Faster Convergence**
|
||||
- High initial rate = Fast progress
|
||||
- Lower later rate = Precise convergence
|
||||
|
||||
**2. Better Final Performance**
|
||||
- Fine-tuning at end = Better solution
|
||||
- Avoids overshooting = More stable
|
||||
|
||||
**3. More Stable Training**
|
||||
- Gradual reduction = Smooth optimization
|
||||
- Less oscillation = More reliable
|
||||
|
||||
---
|
||||
|
||||
## 8.3 Fixed Learning Rate
|
||||
|
||||
### What is Fixed Learning Rate?
|
||||
|
||||
**Learning rate stays constant throughout training:**
|
||||
|
||||
```math
|
||||
\eta_t = \eta_0 \quad \text{for all } t
|
||||
```
|
||||
|
||||
**Where:**
|
||||
- $\eta_0$ = initial learning rate
|
||||
- $t$ = training step
|
||||
|
||||
### Example
|
||||
|
||||
**Fixed Rate:**
|
||||
```
|
||||
Step 0: η = 0.001
|
||||
Step 100: η = 0.001
|
||||
Step 1000: η = 0.001
|
||||
Step 10000: η = 0.001
|
||||
```
|
||||
|
||||
**Constant throughout!**
|
||||
|
||||
### Visualization
|
||||
|
||||
```
|
||||
Learning Rate
|
||||
│
|
||||
0.001│─────────────────────────────────────
|
||||
│
|
||||
│
|
||||
│
|
||||
│
|
||||
└───────────────────────────────────── Steps
|
||||
```
|
||||
|
||||
### Problems
|
||||
|
||||
**1. Too High:**
|
||||
- Overshoots minimum
|
||||
- Oscillates around solution
|
||||
- Never converges precisely
|
||||
|
||||
**2. Too Low:**
|
||||
- Very slow training
|
||||
- Takes forever to converge
|
||||
- May get stuck
|
||||
|
||||
**Solution:** Use scheduling!
|
||||
|
||||
---
|
||||
|
||||
## 8.4 Cosine Annealing
|
||||
|
||||
### What is Cosine Annealing?
|
||||
|
||||
**Cosine Annealing** reduces the learning rate following a cosine curve from maximum to minimum.
|
||||
|
||||
### Formula
|
||||
|
||||
```math
|
||||
\eta_t = \eta_{min} + (\eta_{max} - \eta_{min}) \times \frac{1 + \cos\left(\frac{\pi t}{T_{max}}\right)}{2}
|
||||
```
|
||||
|
||||
**Where:**
|
||||
- $\eta_t$ = learning rate at step $t$
|
||||
- $\eta_{min}$ = minimum learning rate (default: 0)
|
||||
- $\eta_{max}$ = initial/maximum learning rate
|
||||
- $T_{max}$ = total number of steps
|
||||
- $t$ = current step
|
||||
|
||||
### How It Works
|
||||
|
||||
**Step 1: Calculate Cosine Value**
|
||||
```math
|
||||
\cos\left(\frac{\pi t}{T_{max}}\right)
|
||||
```
|
||||
|
||||
**Step 2: Shift to [0, 1] Range**
|
||||
```math
|
||||
\frac{1 + \cos\left(\frac{\pi t}{T_{max}}\right)}{2}
|
||||
```
|
||||
|
||||
**Step 3: Scale to Learning Rate Range**
|
||||
```math
|
||||
\eta_{min} + (\eta_{max} - \eta_{min}) \times \text{scale}
|
||||
```
|
||||
|
||||
### Example Calculation
|
||||
|
||||
**Given:**
|
||||
- $\eta_{max} = 0.001$
|
||||
- $\eta_{min} = 0$
|
||||
- $T_{max} = 10000$
|
||||
|
||||
**At step $t = 0$:**
|
||||
```math
|
||||
\eta_0 = 0 + (0.001 - 0) \times \frac{1 + \cos(0)}{2} = 0.001 \times 1 = 0.001
|
||||
```
|
||||
|
||||
**At step $t = 2500$:**
|
||||
```math
|
||||
\eta_{2500} = 0 + 0.001 \times \frac{1 + \cos(\pi/4)}{2} = 0.001 \times \frac{1 + 0.707}{2} \approx 0.000854
|
||||
```
|
||||
|
||||
**At step $t = 5000$:**
|
||||
```math
|
||||
\eta_{5000} = 0 + 0.001 \times \frac{1 + \cos(\pi/2)}{2} = 0.001 \times \frac{1 + 0}{2} = 0.0005
|
||||
```
|
||||
|
||||
**At step $t = 7500$:**
|
||||
```math
|
||||
\eta_{7500} = 0 + 0.001 \times \frac{1 + \cos(3\pi/4)}{2} = 0.001 \times \frac{1 + (-0.707)}{2} \approx 0.000146
|
||||
```
|
||||
|
||||
**At step $t = 10000$:**
|
||||
```math
|
||||
\eta_{10000} = 0 + 0.001 \times \frac{1 + \cos(\pi)}{2} = 0.001 \times \frac{1 + (-1)}{2} = 0
|
||||
```
|
||||
|
||||
### Visualization
|
||||
|
||||
```
|
||||
Learning Rate
|
||||
│
|
||||
0.001 │●───────────────\
|
||||
│ \
|
||||
│ \
|
||||
0.0005│ \
|
||||
│ \
|
||||
│ \
|
||||
│ \
|
||||
│ \
|
||||
│ \
|
||||
│ \
|
||||
0│ ●─────
|
||||
└───────────────────────────────────── Steps
|
||||
0 2500 5000 7500 10000
|
||||
```
|
||||
|
||||
**Smooth cosine curve!**
|
||||
|
||||
### Why Cosine Annealing?
|
||||
|
||||
**Benefits:**
|
||||
1. **Smooth decay:** No abrupt changes
|
||||
2. **Gradual reduction:** Better fine-tuning
|
||||
3. **Works well:** Commonly used in practice
|
||||
4. **High initial rate:** Fast learning
|
||||
5. **Low final rate:** Precise convergence
|
||||
|
||||
---
|
||||
|
||||
## 8.5 Other Scheduling Strategies
|
||||
|
||||
### 1. Step Decay
|
||||
|
||||
**Reduce learning rate at fixed intervals:**
|
||||
|
||||
```math
|
||||
\eta_t = \eta_0 \times \gamma^{\lfloor t / s \rfloor}
|
||||
```
|
||||
|
||||
**Where:**
|
||||
- $\gamma$ = decay factor (e.g., 0.1)
|
||||
- $s$ = step size (e.g., every 1000 steps)
|
||||
|
||||
**Example:**
|
||||
```
|
||||
Step 0-999: η = 0.001
|
||||
Step 1000-1999: η = 0.0001 (×0.1)
|
||||
Step 2000-2999: η = 0.00001 (×0.1)
|
||||
```
|
||||
|
||||
**Visualization:**
|
||||
```
|
||||
Learning Rate
|
||||
│
|
||||
0.001 │───────┐
|
||||
│ │
|
||||
│ └───────┐
|
||||
0.0001│ │
|
||||
│ └───────┐
|
||||
│ │
|
||||
└───────────────────────── Steps
|
||||
```
|
||||
|
||||
### 2. Exponential Decay
|
||||
|
||||
**Continuous exponential reduction:**
|
||||
|
||||
```math
|
||||
\eta_t = \eta_0 \times \gamma^t
|
||||
```
|
||||
|
||||
**Where:**
|
||||
- $\gamma$ = decay rate (e.g., 0.9995)
|
||||
|
||||
**Visualization:**
|
||||
```
|
||||
Learning Rate
|
||||
│
|
||||
0.001│●──────────────\
|
||||
│ \
|
||||
│ \
|
||||
│ \
|
||||
│ \
|
||||
│ \
|
||||
│ \
|
||||
│ \
|
||||
└──────────────────────── Steps
|
||||
```
|
||||
|
||||
### 3. Warmup Scheduling
|
||||
|
||||
**Start with low rate, increase, then decrease:**
|
||||
|
||||
**Warmup Phase:**
|
||||
```math
|
||||
\eta_t = \eta_{max} \times \frac{t}{T_{warmup}}
|
||||
```
|
||||
|
||||
**After Warmup:**
|
||||
```math
|
||||
\eta_t = \text{Cosine Annealing or other schedule}
|
||||
```
|
||||
|
||||
**Visualization:**
|
||||
```
|
||||
Learning Rate
|
||||
│
|
||||
0.001│ ╱───────\
|
||||
│ ╱ \
|
||||
│ ╱ \
|
||||
│ ╱ \
|
||||
│ ╱ \
|
||||
│ ╱ \
|
||||
│╱ \
|
||||
└───────────────────── Steps
|
||||
```
|
||||
|
||||
### 4. One Cycle Learning Rate
|
||||
|
||||
**One cycle: increase then decrease:**
|
||||
|
||||
```math
|
||||
\eta_t = \begin{cases}
|
||||
\eta_{min} + (\eta_{max} - \eta_{min}) \times \frac{t}{T_1} & t \leq T_1 \\
|
||||
\eta_{max} - (\eta_{max} - \eta_{min}) \times \frac{t - T_1}{T_2} & t > T_1
|
||||
\end{cases}
|
||||
```
|
||||
|
||||
**Visualization:**
|
||||
```
|
||||
Learning Rate
|
||||
│
|
||||
0.001│ ╱─────\
|
||||
│ ╱ \
|
||||
│ ╱ \
|
||||
│ ╱ \
|
||||
│ ╱ \
|
||||
│ ╱ \
|
||||
│╱ \
|
||||
└─────────────────── Steps
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 8.6 Why Scheduling Matters
|
||||
|
||||
### Benefit 1: Better Convergence
|
||||
|
||||
**Without Scheduling:**
|
||||
```
|
||||
Loss: 3.0 → 2.5 → 2.3 → 2.2 → 2.15 → 2.12 → ...
|
||||
(slow convergence at end)
|
||||
```
|
||||
|
||||
**With Scheduling:**
|
||||
```
|
||||
Loss: 3.0 → 2.5 → 2.3 → 2.2 → 2.1 → 2.05 → ...
|
||||
(faster convergence, better final loss)
|
||||
```
|
||||
|
||||
### Benefit 2: More Stable Training
|
||||
|
||||
**Fixed High Rate:**
|
||||
```
|
||||
Loss: 3.0 → 2.5 → 2.3 → 2.4 → 2.3 → 2.4 → ...
|
||||
(oscillating, unstable)
|
||||
```
|
||||
|
||||
**Scheduled Rate:**
|
||||
```
|
||||
Loss: 3.0 → 2.5 → 2.3 → 2.2 → 2.15 → 2.12 → ...
|
||||
(smooth, stable)
|
||||
```
|
||||
|
||||
### Benefit 3: Better Final Performance
|
||||
|
||||
**Comparison:**
|
||||
```
|
||||
Fixed LR: Final Loss = 2.15
|
||||
Scheduled LR: Final Loss = 2.05
|
||||
|
||||
→ 5% improvement!
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 8.7 Complete Mathematical Formulation
|
||||
|
||||
### General Scheduling Formula
|
||||
|
||||
```math
|
||||
\eta_t = f(t, \eta_0, \eta_{min}, T_{max}, ...)
|
||||
```
|
||||
|
||||
**Where $f$ is the scheduling function**
|
||||
|
||||
### Cosine Annealing (Complete)
|
||||
|
||||
```math
|
||||
\eta_t = \eta_{min} + (\eta_{max} - \eta_{min}) \times \frac{1 + \cos\left(\frac{\pi t}{T_{max}}\right)}{2}
|
||||
```
|
||||
|
||||
**Boundary Conditions:**
|
||||
- At $t = 0$: $\eta_0 = \eta_{max}$
|
||||
- At $t = T_{max}$: $\eta_{T_{max}} = \eta_{min}$
|
||||
|
||||
### Step Decay
|
||||
|
||||
```math
|
||||
\eta_t = \eta_0 \times \gamma^{\lfloor t / s \rfloor}
|
||||
```
|
||||
|
||||
### Exponential Decay
|
||||
|
||||
```math
|
||||
\eta_t = \eta_0 \times \gamma^t
|
||||
```
|
||||
|
||||
### Warmup + Cosine Annealing
|
||||
|
||||
**Warmup Phase ($t \leq T_{warmup}$):**
|
||||
```math
|
||||
\eta_t = \eta_{max} \times \frac{t}{T_{warmup}}
|
||||
```
|
||||
|
||||
**Annealing Phase ($t > T_{warmup}$):**
|
||||
```math
|
||||
\eta_t = \eta_{min} + (\eta_{max} - \eta_{min}) \times \frac{1 + \cos\left(\frac{\pi (t - T_{warmup})}{T_{max} - T_{warmup}}\right)}{2}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 8.8 Exercise: Schedule Calculation
|
||||
|
||||
### Problem
|
||||
|
||||
**Given Cosine Annealing schedule:**
|
||||
|
||||
- $\eta_{max} = 0.002$
|
||||
- $\eta_{min} = 0.0001$
|
||||
- $T_{max} = 5000$ steps
|
||||
|
||||
**Calculate the learning rate at:**
|
||||
1. Step $t = 0$
|
||||
2. Step $t = 1250$
|
||||
3. Step $t = 2500$
|
||||
4. Step $t = 3750$
|
||||
5. Step $t = 5000$
|
||||
|
||||
### Step-by-Step Solution
|
||||
|
||||
#### General Formula
|
||||
|
||||
```math
|
||||
\eta_t = \eta_{min} + (\eta_{max} - \eta_{min}) \times \frac{1 + \cos\left(\frac{\pi t}{T_{max}}\right)}{2}
|
||||
```
|
||||
|
||||
**Substitute values:**
|
||||
```math
|
||||
\eta_t = 0.0001 + (0.002 - 0.0001) \times \frac{1 + \cos\left(\frac{\pi t}{5000}\right)}{2}
|
||||
```
|
||||
|
||||
```math
|
||||
\eta_t = 0.0001 + 0.0019 \times \frac{1 + \cos\left(\frac{\pi t}{5000}\right)}{2}
|
||||
```
|
||||
|
||||
#### Step 1: t = 0
|
||||
|
||||
```math
|
||||
\eta_0 = 0.0001 + 0.0019 \times \frac{1 + \cos(0)}{2}
|
||||
```
|
||||
|
||||
```math
|
||||
\eta_0 = 0.0001 + 0.0019 \times \frac{1 + 1}{2}
|
||||
```
|
||||
|
||||
```math
|
||||
\eta_0 = 0.0001 + 0.0019 \times 1 = 0.0001 + 0.0019 = 0.002
|
||||
```
|
||||
|
||||
**Answer:** $\eta_0 = 0.002$
|
||||
|
||||
#### Step 2: t = 1250
|
||||
|
||||
```math
|
||||
\eta_{1250} = 0.0001 + 0.0019 \times \frac{1 + \cos(\pi/4)}{2}
|
||||
```
|
||||
|
||||
```math
|
||||
\eta_{1250} = 0.0001 + 0.0019 \times \frac{1 + 0.707}{2}
|
||||
```
|
||||
|
||||
```math
|
||||
\eta_{1250} = 0.0001 + 0.0019 \times 0.8535 = 0.0001 + 0.001621 = 0.001721
|
||||
```
|
||||
|
||||
**Answer:** $\eta_{1250} \approx 0.001721$
|
||||
|
||||
#### Step 3: t = 2500
|
||||
|
||||
```math
|
||||
\eta_{2500} = 0.0001 + 0.0019 \times \frac{1 + \cos(\pi/2)}{2}
|
||||
```
|
||||
|
||||
```math
|
||||
\eta_{2500} = 0.0001 + 0.0019 \times \frac{1 + 0}{2}
|
||||
```
|
||||
|
||||
```math
|
||||
\eta_{2500} = 0.0001 + 0.0019 \times 0.5 = 0.0001 + 0.00095 = 0.00105
|
||||
```
|
||||
|
||||
**Answer:** $\eta_{2500} = 0.00105$
|
||||
|
||||
#### Step 4: t = 3750
|
||||
|
||||
```math
|
||||
\eta_{3750} = 0.0001 + 0.0019 \times \frac{1 + \cos(3\pi/4)}{2}
|
||||
```
|
||||
|
||||
```math
|
||||
\eta_{3750} = 0.0001 + 0.0019 \times \frac{1 + (-0.707)}{2}
|
||||
```
|
||||
|
||||
```math
|
||||
\eta_{3750} = 0.0001 + 0.0019 \times 0.1465 = 0.0001 + 0.000278 = 0.000378
|
||||
```
|
||||
|
||||
**Answer:** $\eta_{3750} \approx 0.000378$
|
||||
|
||||
#### Step 5: t = 5000
|
||||
|
||||
```math
|
||||
\eta_{5000} = 0.0001 + 0.0019 \times \frac{1 + \cos(\pi)}{2}
|
||||
```
|
||||
|
||||
```math
|
||||
\eta_{5000} = 0.0001 + 0.0019 \times \frac{1 + (-1)}{2}
|
||||
```
|
||||
|
||||
```math
|
||||
\eta_{5000} = 0.0001 + 0.0019 \times 0 = 0.0001 + 0 = 0.0001
|
||||
```
|
||||
|
||||
**Answer:** $\eta_{5000} = 0.0001$
|
||||
|
||||
### Summary Table
|
||||
|
||||
| Step | Cosine Value | Scale Factor | Learning Rate |
|
||||
|------|--------------|--------------|---------------|
|
||||
| 0 | 1.0 | 1.0 | 0.002 |
|
||||
| 1250 | 0.707 | 0.854 | 0.001721 |
|
||||
| 2500 | 0.0 | 0.5 | 0.00105 |
|
||||
| 3750 | -0.707 | 0.146 | 0.000378 |
|
||||
| 5000 | -1.0 | 0.0 | 0.0001 |
|
||||
|
||||
**Smooth decay from 0.002 to 0.0001!**
|
||||
|
||||
---
|
||||
|
||||
## 8.9 Key Takeaways
|
||||
|
||||
### Scheduling
|
||||
|
||||
✅ **Scheduling adjusts learning rate during training**
|
||||
✅ **Starts high (fast learning), ends low (fine-tuning)**
|
||||
✅ **Improves convergence and final performance**
|
||||
|
||||
### Cosine Annealing
|
||||
|
||||
✅ **Smooth cosine-based decay**
|
||||
✅ **Gradual reduction from max to min**
|
||||
✅ **Works well for transformers**
|
||||
|
||||
### Why Important
|
||||
|
||||
✅ **Faster convergence**
|
||||
✅ **More stable training**
|
||||
✅ **Better final performance**
|
||||
✅ **Essential for optimal training**
|
||||
|
||||
---
|
||||
|
||||
*This document provides a comprehensive explanation of learning rate scheduling, including cosine annealing and other strategies with mathematical formulations and solved exercises.*
|
||||
|
||||
978
docs/TOKENIZATION_EXPLAINED.md
Normal file
978
docs/TOKENIZATION_EXPLAINED.md
Normal file
@@ -0,0 +1,978 @@
|
||||
# Tokenization Explained - Mathematical Formulation
|
||||
|
||||
Complete mathematical derivation and step-by-step explanation of tokenization, Byte Pair Encoding (BPE), and UTF-8 encoding used in the SheepOp Language Model.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
1. [Introduction to Tokenization](#1-introduction-to-tokenization)
|
||||
2. [UTF-8 Encoding](#2-utf-8-encoding)
|
||||
3. [Byte Pair Encoding Algorithm](#3-byte-pair-encoding-algorithm)
|
||||
4. [Vocabulary Construction](#4-vocabulary-construction)
|
||||
5. [Encoding Process](#5-encoding-process)
|
||||
6. [Decoding Process](#6-decoding-process)
|
||||
7. [Regex Pattern Splitting](#7-regex-pattern-splitting)
|
||||
8. [Special Tokens](#8-special-tokens)
|
||||
9. [Complete Tokenization Pipeline](#9-complete-tokenization-pipeline)
|
||||
10. [Tokenization Challenges and Solutions](#10-tokenization-challenges-and-solutions)
|
||||
|
||||
---
|
||||
|
||||
## 1. Introduction to Tokenization
|
||||
|
||||
### 1.1 What is Tokenization?
|
||||
|
||||
**Tokenization** is the process of converting raw text into a sequence of discrete tokens (integers) that can be processed by neural networks.
|
||||
|
||||
**Mathematical Definition:**
|
||||
|
||||
Given a text string $s \in \Sigma^*$ where $\Sigma$ is the character alphabet, tokenization maps:
|
||||
|
||||
```math
|
||||
\mathcal{T}: \Sigma^* \rightarrow \mathbb{N}^*
|
||||
```
|
||||
|
||||
```math
|
||||
s = c_1 c_2 \ldots c_n \mapsto \mathbf{t} = [t_1, t_2, \ldots, t_m]
|
||||
```
|
||||
|
||||
where:
|
||||
|
||||
- $s$ = input text string
|
||||
- $c_i$ = individual characters
|
||||
- $\mathbf{t}$ = sequence of token IDs
|
||||
- $n$ = number of characters
|
||||
- $m$ = number of tokens (typically $m \leq n$)
|
||||
|
||||
### 1.2 Why Tokenization?
|
||||
|
||||
**Problem:** Neural networks require numerical inputs, not raw text.
|
||||
|
||||
**Solution:** Convert text to token sequences:
|
||||
|
||||
```mermaid
|
||||
graph LR
|
||||
A["Raw Text<br/>'Hello world'"] --> B[Tokenizer]
|
||||
B --> C["Token IDs<br/>[15496, 1917]"]
|
||||
C --> D[Embedding Layer]
|
||||
D --> E["Embeddings<br/>ℝ^n×d"]
|
||||
|
||||
style A fill:#e1f5ff
|
||||
style B fill:#fff4e1
|
||||
style C fill:#e1ffe1
|
||||
style E fill:#ffe1f5
|
||||
```
|
||||
|
||||
### 1.3 Tokenization Approaches
|
||||
|
||||
**Three Main Approaches:**
|
||||
|
||||
1. **Character-Level**: Each character becomes a token
|
||||
- Vocabulary size: ~100-200
|
||||
- Sequence length: Very long
|
||||
|
||||
2. **Word-Level**: Each word becomes a token
|
||||
- Vocabulary size: ~50,000-100,000
|
||||
- Sequence length: Moderate
|
||||
|
||||
3. **Subword-Level (BPE)**: Sequences of bytes/characters become tokens
|
||||
- Vocabulary size: ~30,000-100,000
|
||||
- Sequence length: Efficient
|
||||
|
||||
```mermaid
|
||||
graph TB
|
||||
subgraph "Character-Level"
|
||||
C1["'Hello'"] --> C2["['H','e','l','l','o']"]
|
||||
C2 --> C3["5 tokens"]
|
||||
end
|
||||
|
||||
subgraph "Word-Level"
|
||||
W1["'Hello world'"] --> W2["['Hello', 'world']"]
|
||||
W2 --> W3["2 tokens"]
|
||||
end
|
||||
|
||||
subgraph "Subword-Level (BPE)"
|
||||
B1["'Hello world'"] --> B2["['Hello', ' world']"]
|
||||
B2 --> B3["2 tokens"]
|
||||
end
|
||||
|
||||
style C3 fill:#ffcccc
|
||||
style W3 fill:#ccffcc
|
||||
style B3 fill:#ccffcc
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 2. UTF-8 Encoding
|
||||
|
||||
### 2.1 Unicode Code Points
|
||||
|
||||
**Unicode** defines a mapping from characters to integers (code points):
|
||||
|
||||
```math
|
||||
f: \mathcal{C} \rightarrow \{0, 1, \ldots, 0x10FFFF\}
|
||||
```
|
||||
|
||||
where $\mathcal{C}$ is the set of all Unicode characters.
|
||||
|
||||
**Example:**
|
||||
|
||||
```math
|
||||
f('H') = 72, \quad f('ε') = 949, \quad f('中') = 20013
|
||||
```
|
||||
|
||||
### 2.2 UTF-8 Encoding Function
|
||||
|
||||
**UTF-8** encodes Unicode code points into variable-length byte sequences:
|
||||
|
||||
```math
|
||||
\text{UTF-8}: \{0, \ldots, 0x10FFFF\} \rightarrow \{0, \ldots, 255\}^*
|
||||
```
|
||||
|
||||
**Encoding Rules:**
|
||||
|
||||
For a code point $c$:
|
||||
|
||||
```math
|
||||
\text{UTF-8}(c) = \begin{cases}
|
||||
[b_0] & \text{if } c < 128 \\
|
||||
[b_0, b_1] & \text{if } c < 2048 \\
|
||||
[b_0, b_1, b_2] & \text{if } c < 65536 \\
|
||||
[b_0, b_1, b_2, b_3] & \text{if } c < 0x10FFFF
|
||||
\end{cases}
|
||||
```
|
||||
|
||||
where $b_i \in \{0, \ldots, 255\}$ are bytes.
|
||||
|
||||
### 2.3 UTF-8 Encoding Process
|
||||
|
||||
```mermaid
|
||||
graph TB
|
||||
A["Unicode String<br/>'Hello 世界'"] --> B[Extract Code Points]
|
||||
B --> C["[72, 101, 108, 108, 111, 32, 19990, 30028]"]
|
||||
C --> D[UTF-8 Encode]
|
||||
D --> E["Bytes<br/>[72, 101, 108, 108, 111, 32, 228, 184, 150, 231, 149, 140]"]
|
||||
|
||||
style A fill:#e1f5ff
|
||||
style E fill:#e1ffe1
|
||||
```
|
||||
|
||||
**Mathematical Formulation:**
|
||||
|
||||
For a string $s = c_1 c_2 \ldots c_n$:
|
||||
|
||||
```math
|
||||
\text{bytes}(s) = \bigoplus_{i=1}^n \text{UTF-8}(f(c_i))
|
||||
```
|
||||
|
||||
where $\bigoplus$ denotes byte concatenation.
|
||||
|
||||
**Example:**
|
||||
|
||||
```math
|
||||
\text{bytes}("Hi") = \text{UTF-8}(72) \oplus \text{UTF-8}(105) = [72] \oplus [105] = [72, 105]
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 3. Byte Pair Encoding Algorithm
|
||||
|
||||
### 3.1 BPE Overview
|
||||
|
||||
**Byte Pair Encoding (BPE)** is a data compression algorithm that iteratively merges the most frequent consecutive pairs.
|
||||
|
||||
**Goal:** Create an efficient vocabulary by merging frequent byte/character pairs.
|
||||
|
||||
### 3.2 BPE Training Algorithm
|
||||
|
||||
**Initial State:**
|
||||
|
||||
```math
|
||||
V^{(0)} = \{0, 1, \ldots, 255\} \quad \text{(all bytes)}
|
||||
```
|
||||
|
||||
```math
|
||||
\text{tokens}^{(0)} = \text{bytes}(s) = [b_1, b_2, \ldots, b_n]
|
||||
```
|
||||
|
||||
**Iterative Merging:**
|
||||
|
||||
For iteration $k = 1, 2, \ldots, K$:
|
||||
|
||||
**Step 1: Calculate Pair Frequencies**
|
||||
|
||||
```math
|
||||
\text{stats}^{(k)} = \text{CountPairs}(\text{tokens}^{(k-1)})
|
||||
```
|
||||
|
||||
```math
|
||||
\text{stats}^{(k)}(i, j) = |\{p : \text{tokens}^{(k-1)}_p = i \land \text{tokens}^{(k-1)}_{p+1} = j\}|
|
||||
```
|
||||
|
||||
**Step 2: Find Most Frequent Pair**
|
||||
|
||||
```math
|
||||
(i^*, j^*) = \arg\max_{(i,j)} \text{stats}^{(k)}(i, j)
|
||||
```
|
||||
|
||||
**Step 3: Create New Token**
|
||||
|
||||
```math
|
||||
V^{(k)} = V^{(k-1)} \cup \{256 + k - 1\}
|
||||
```
|
||||
|
||||
```math
|
||||
\text{merges}^{(k)} = \text{merges}^{(k-1)} \cup \{(i^*, j^*) \mapsto 256 + k - 1\}
|
||||
```
|
||||
|
||||
**Step 4: Apply Merge**
|
||||
|
||||
```math
|
||||
\text{tokens}^{(k)} = \text{Merge}(\text{tokens}^{(k-1)}, (i^*, j^*), 256 + k - 1)
|
||||
```
|
||||
|
||||
where:
|
||||
|
||||
```math
|
||||
\text{Merge}(T, (a, b), \text{new\_id})_i = \begin{cases}
|
||||
\text{new\_id} & \text{if } T_i = a \land T_{i+1} = b \\
|
||||
T_i & \text{otherwise}
|
||||
\end{cases}
|
||||
```
|
||||
|
||||
### 3.3 BPE Training Flowchart
|
||||
|
||||
```mermaid
|
||||
graph TB
|
||||
Start[Start Training] --> Init["Initialize<br/>V⁽⁰⁾ = {0...255}<br/>tokens⁽⁰⁾ = bytes(text)"]
|
||||
Init --> Iter{More Merges?}
|
||||
|
||||
Iter -->|Yes| Stats["Calculate Pair Frequencies<br/>stats⁽ᵏ⁾(i,j)"]
|
||||
Stats --> Find["Find Most Frequent Pair<br/>(i*, j*) = argmax stats"]
|
||||
Find --> Create["Create New Token<br/>id = 256 + k"]
|
||||
Create --> Merge["Merge All Occurrences<br/>tokens⁽ᵏ⁾ = Merge(tokens⁽ᵏ⁻¹⁾, (i*,j*), id)"]
|
||||
Merge --> Update["Update Vocabulary<br/>V⁽ᵏ⁾ = V⁽ᵏ⁻¹⁾ ∪ {id}"]
|
||||
Update --> Iter
|
||||
|
||||
Iter -->|No| End[Training Complete]
|
||||
|
||||
style Start fill:#e1f5ff
|
||||
style End fill:#e1ffe1
|
||||
style Stats fill:#fff4e1
|
||||
style Merge fill:#ffe1f5
|
||||
```
|
||||
|
||||
### 3.4 BPE Example
|
||||
|
||||
**Example:** Training on text "aaab"
|
||||
|
||||
**Iteration 0:**
|
||||
|
||||
```math
|
||||
V^{(0)} = \{0, 1, \ldots, 255\}
|
||||
```
|
||||
|
||||
```math
|
||||
\text{tokens}^{(0)} = [97, 97, 97, 98] \quad \text{(bytes for 'aaab')}
|
||||
```
|
||||
|
||||
**Calculate frequencies:**
|
||||
|
||||
```math
|
||||
\text{stats}^{(0)} = \{(97, 97): 2, (97, 98): 1\}
|
||||
```
|
||||
|
||||
**Iteration 1:**
|
||||
|
||||
```math
|
||||
(i^*, j^*) = (97, 97), \quad \text{new\_id} = 256
|
||||
```
|
||||
|
||||
```math
|
||||
\text{tokens}^{(1)} = [256, 97, 98]
|
||||
```
|
||||
|
||||
```math
|
||||
V^{(1)} = V^{(0)} \cup \{256\}
|
||||
```
|
||||
|
||||
**Iteration 2:**
|
||||
|
||||
```math
|
||||
\text{stats}^{(1)} = \{(256, 97): 1, (97, 98): 1\}
|
||||
```
|
||||
|
||||
**Choose one:** \((256, 97)\), \(\text{new_id} = 257\)
|
||||
|
||||
```math
|
||||
\text{tokens}^{(2)} = [257, 98]
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 4. Vocabulary Construction
|
||||
|
||||
### 4.1 Vocabulary Structure
|
||||
|
||||
**Final Vocabulary:**
|
||||
|
||||
```math
|
||||
V = V_{\text{bytes}} \cup V_{\text{merges}} \cup V_{\text{special}}
|
||||
```
|
||||
|
||||
where:
|
||||
|
||||
- $V_{\text{bytes}} = \{0, 1, \ldots, 255\}$ (256 tokens)
|
||||
- $V_{\text{merges}} = \{256, 257, \ldots, 256 + K - 1\}$ (K merged tokens)
|
||||
- $V_{\text{special}} = \{\text{<pad>}, \text{<unk>}, \text{<bos>}, \text{<eos>}\}$
|
||||
|
||||
**Vocabulary Size:**
|
||||
|
||||
```math
|
||||
|V| = 256 + K + |V_{\text{special}}|
|
||||
```
|
||||
|
||||
### 4.2 Merge Dictionary
|
||||
|
||||
**Merge Dictionary:**
|
||||
|
||||
```math
|
||||
M: \mathbb{N} \times \mathbb{N} \rightarrow \mathbb{N}
|
||||
```
|
||||
|
||||
```math
|
||||
M(i, j) = \begin{cases}
|
||||
\text{merged\_id} & \text{if } (i, j) \text{ was merged} \\
|
||||
\text{undefined} & \text{otherwise}
|
||||
\end{cases}
|
||||
```
|
||||
|
||||
**Vocabulary Mapping:**
|
||||
|
||||
```math
|
||||
\text{vocab}: \mathbb{N} \rightarrow \{0, \ldots, 255\}^*
|
||||
```
|
||||
|
||||
```math
|
||||
\text{vocab}(\text{id}) = \begin{cases}
|
||||
[\text{id}] & \text{if } \text{id} < 256 \\
|
||||
\text{vocab}(i) \oplus \text{vocab}(j) & \text{if } M(i, j) = \text{id}
|
||||
\end{cases}
|
||||
```
|
||||
|
||||
### 4.3 Vocabulary Construction Diagram
|
||||
|
||||
```mermaid
|
||||
graph TB
|
||||
subgraph "Initial Vocabulary"
|
||||
A1["256 Byte Tokens<br/>{0, 1, ..., 255}"]
|
||||
end
|
||||
|
||||
subgraph "BPE Merges"
|
||||
B1["Merge 1: (101, 32) → 256"]
|
||||
B2["Merge 2: (256, 108) → 257"]
|
||||
B3["Merge K: ... → 256+K-1"]
|
||||
end
|
||||
|
||||
subgraph "Special Tokens"
|
||||
C1["<pad> → 50256"]
|
||||
C2["<unk> → 50257"]
|
||||
C3["<bos> → 50258"]
|
||||
C4["<eos> → 50259"]
|
||||
end
|
||||
|
||||
A1 --> D[Final Vocabulary]
|
||||
B1 --> D
|
||||
B2 --> D
|
||||
B3 --> D
|
||||
C1 --> D
|
||||
C2 --> D
|
||||
C3 --> D
|
||||
C4 --> D
|
||||
|
||||
D --> E["|V| = 256 + K + 4"]
|
||||
|
||||
style A1 fill:#e1f5ff
|
||||
style D fill:#e1ffe1
|
||||
style E fill:#fff4e1
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 5. Encoding Process
|
||||
|
||||
### 5.1 Encoding Algorithm
|
||||
|
||||
**Input:** Text string $s$
|
||||
|
||||
**Step 1: Regex Splitting**
|
||||
|
||||
```math
|
||||
\text{chunks} = \text{RegexSplit}(s, \text{pattern})
|
||||
```
|
||||
|
||||
**Step 2: Convert to Bytes**
|
||||
|
||||
For each chunk $c \in \text{chunks}$:
|
||||
|
||||
```math
|
||||
\text{bytes}(c) = \text{UTF-8}(c) = [b_1, b_2, \ldots, b_n]
|
||||
```
|
||||
|
||||
**Step 3: Apply BPE Merges**
|
||||
|
||||
Initialize: $\text{tokens} = \text{bytes}(c)$
|
||||
|
||||
While merges possible:
|
||||
|
||||
```math
|
||||
\text{Find earliest merge: } (i^*, j^*) = \arg\min_{(i,j) \in M} \text{merge\_index}(M(i, j))
|
||||
```
|
||||
|
||||
```math
|
||||
\text{Apply merge: } \text{tokens} = \text{Merge}(\text{tokens}, (i^*, j^*), M(i^*, j^*))
|
||||
```
|
||||
|
||||
**Step 4: Combine Results**
|
||||
|
||||
```math
|
||||
\text{token\_ids} = \bigoplus_{c \in \text{chunks}} \text{BPE}(\text{bytes}(c))
|
||||
```
|
||||
|
||||
### 5.2 Encoding Function
|
||||
|
||||
**Mathematical Definition:**
|
||||
|
||||
```math
|
||||
\text{encode}(s) = \bigoplus_{c \in \text{RegexSplit}(s)} \text{BPE}(\text{UTF-8}(c))
|
||||
```
|
||||
|
||||
where $\bigoplus$ denotes token sequence concatenation.
|
||||
|
||||
### 5.3 Encoding Flowchart
|
||||
|
||||
```mermaid
|
||||
graph TB
|
||||
A["Input Text<br/>'Hello world'"] --> B[Regex Split]
|
||||
B --> C["Chunks<br/>['Hello', ' world']"]
|
||||
|
||||
C --> D[For Each Chunk]
|
||||
D --> E["UTF-8 Encode<br/>bytes = [72, 101, 108, 108, 111]"]
|
||||
E --> F["Initialize Tokens<br/>tokens = bytes"]
|
||||
|
||||
F --> G{Merge Possible?}
|
||||
G -->|Yes| H["Find Earliest Merge<br/>(i*, j*)"]
|
||||
H --> I["Apply Merge<br/>tokens = Merge(tokens, (i*,j*), id)"]
|
||||
I --> G
|
||||
|
||||
G -->|No| J[Add to Result]
|
||||
J --> K{More Chunks?}
|
||||
K -->|Yes| D
|
||||
K -->|No| L["Final Token IDs<br/>[15496, 1917]"]
|
||||
|
||||
style A fill:#e1f5ff
|
||||
style L fill:#e1ffe1
|
||||
style G fill:#ffe1f5
|
||||
```
|
||||
|
||||
### 5.4 Encoding Example
|
||||
|
||||
**Input:** "Hello world"
|
||||
|
||||
**Step 1: Regex Split**
|
||||
|
||||
```math
|
||||
\text{chunks} = ["Hello", " world"]
|
||||
```
|
||||
|
||||
**Step 2: UTF-8 Encoding**
|
||||
|
||||
```math
|
||||
\text{UTF-8}("Hello") = [72, 101, 108, 108, 111]
|
||||
```
|
||||
|
||||
```math
|
||||
\text{UTF-8}(" world") = [32, 119, 111, 114, 108, 100]
|
||||
```
|
||||
|
||||
**Step 3: BPE Merging**
|
||||
|
||||
Assume merges: $M(72, 101) = 256, M(256, 108) = 257, M(257, 108) = 258, M(258, 111) = 259$
|
||||
|
||||
```math
|
||||
[72, 101, 108, 108, 111] \xrightarrow{M(72,101)} [256, 108, 108, 111]
|
||||
```
|
||||
|
||||
```math
|
||||
\xrightarrow{M(256,108)} [257, 108, 111] \xrightarrow{M(257,108)} [258, 111]
|
||||
```
|
||||
|
||||
```math
|
||||
\xrightarrow{M(258,111)} [259]
|
||||
```
|
||||
|
||||
**Final:** $[259, 1917]$
|
||||
|
||||
---
|
||||
|
||||
## 6. Decoding Process
|
||||
|
||||
### 6.1 Decoding Algorithm
|
||||
|
||||
**Input:** Token IDs $\mathbf{t} = [t_1, t_2, \ldots, t_n]$
|
||||
|
||||
**Step 1: Handle Special Tokens**
|
||||
|
||||
```math
|
||||
\text{Stop at EOS: } \mathbf{t}' = \mathbf{t}[:i] \text{ where } t_i = \text{<eos>}
|
||||
```
|
||||
|
||||
**Step 2: Lookup Bytes**
|
||||
|
||||
For each token $t_i$:
|
||||
|
||||
```math
|
||||
\text{bytes}_i = \text{vocab}(t_i)
|
||||
```
|
||||
|
||||
**Step 3: Concatenate Bytes**
|
||||
|
||||
```math
|
||||
\text{all\_bytes} = \bigoplus_{i=1}^n \text{bytes}_i
|
||||
```
|
||||
|
||||
**Step 4: UTF-8 Decode**
|
||||
|
||||
```math
|
||||
\text{text} = \text{UTF-8}^{-1}(\text{all\_bytes})
|
||||
```
|
||||
|
||||
### 6.2 Decoding Function
|
||||
|
||||
**Mathematical Definition:**
|
||||
|
||||
```math
|
||||
\text{decode}(\mathbf{t}) = \text{UTF-8}^{-1}\left(\bigoplus_{i=1}^n \text{vocab}(t_i)\right)
|
||||
```
|
||||
|
||||
### 6.3 Decoding Flowchart
|
||||
|
||||
```mermaid
|
||||
graph TB
|
||||
A["Token IDs<br/>[15496, 1917]"] --> B{Special Tokens?}
|
||||
B -->|EOS Found| C[Stop at EOS]
|
||||
B -->|No EOS| D[Process All Tokens]
|
||||
C --> D
|
||||
|
||||
D --> E[For Each Token ID]
|
||||
E --> F["Lookup Bytes<br/>vocab[t_i]"]
|
||||
F --> G["Concatenate<br/>all_bytes = bytes₁ ⊕ bytes₂ ⊕ ..."]
|
||||
|
||||
G --> H["UTF-8 Decode<br/>text = decode(all_bytes)"]
|
||||
H --> I["Output Text<br/>'Hello world'"]
|
||||
|
||||
style A fill:#e1f5ff
|
||||
style I fill:#e1ffe1
|
||||
style F fill:#fff4e1
|
||||
```
|
||||
|
||||
### 6.4 Decoding Example
|
||||
|
||||
**Input:** $[259, 1917]$
|
||||
|
||||
**Step 1: Lookup**
|
||||
|
||||
```math
|
||||
\text{vocab}(259) = \text{vocab}(258) \oplus \text{vocab}(111)
|
||||
```
|
||||
|
||||
```math
|
||||
= (\text{vocab}(257) \oplus \text{vocab}(108)) \oplus [111]
|
||||
```
|
||||
|
||||
```math
|
||||
= ((\text{vocab}(256) \oplus \text{vocab}(108)) \oplus [108]) \oplus [111]
|
||||
```
|
||||
|
||||
```math
|
||||
= (((\text{vocab}(72) \oplus \text{vocab}(101)) \oplus [108]) \oplus [108]) \oplus [111]
|
||||
```
|
||||
|
||||
```math
|
||||
= [[72] \oplus [101]] \oplus [108] \oplus [108] \oplus [111]
|
||||
```
|
||||
|
||||
```math
|
||||
= [72, 101, 108, 108, 111]
|
||||
```
|
||||
|
||||
```math
|
||||
\text{vocab}(1917) = [32, 119, 111, 114, 108, 100]
|
||||
```
|
||||
|
||||
**Step 2: Concatenate**
|
||||
|
||||
```math
|
||||
\text{all\_bytes} = [72, 101, 108, 108, 111] \oplus [32, 119, 111, 114, 108, 100]
|
||||
```
|
||||
|
||||
```math
|
||||
= [72, 101, 108, 108, 111, 32, 119, 111, 114, 108, 100]
|
||||
```
|
||||
|
||||
**Step 3: Decode**
|
||||
|
||||
```math
|
||||
\text{decode}([72, 101, 108, 108, 111, 32, 119, 111, 114, 108, 100]) = "Hello world"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 7. Regex Pattern Splitting
|
||||
|
||||
### 7.1 GPT-4 Style Pattern
|
||||
|
||||
**Pattern Components:**
|
||||
|
||||
```math
|
||||
P = P_{\text{contractions}} \cup P_{\text{letters}} \cup P_{\text{numbers}} \cup P_{\text{punctuation}} \cup P_{\text{whitespace}}
|
||||
```
|
||||
|
||||
**Pattern Definition:**
|
||||
|
||||
\[
|
||||
P = \text{'(?i:[sdmt]|ll|ve|re)} \cup \text{[^\r\n\p{L}\p{N}]?+\p{L}+} \cup \text{\p{N}{1,3}} \cup \text{?[^\s\p{L}\p{N}]++} \cup \text{\r?\n} \cup \text{\s+}
|
||||
\]
|
||||
|
||||
**Components:**
|
||||
|
||||
1. **Contractions:** \( P\_{\text{contractions}} = \text{'(?i:[sdmt]|ll|ve|re)} \)
|
||||
- Matches: `'s`, `'t`, `'ll`, `'ve`, `'re` (case-insensitive)
|
||||
|
||||
2. **Letters:** \( P\_{\text{letters}} = \text{[^\r\n\p{L}\p{N}]?+\p{L}+} \)
|
||||
- Optional space + one or more letters
|
||||
|
||||
3. **Numbers:** \( P\_{\text{numbers}} = \text{\p{N}{1,3}} \)
|
||||
- **Limit:** 1-3 digits only (prevents long number tokens)
|
||||
|
||||
4. **Punctuation:** \( P\_{\text{punctuation}} = \text{?[^\s\p{L}\p{N}]++} \)
|
||||
- Optional space + punctuation
|
||||
|
||||
5. **Whitespace:** \( P\_{\text{whitespace}} = \text{\r?\n} \cup \text{\s+} \)
|
||||
- Newlines and multiple spaces
|
||||
|
||||
### 7.2 Regex Splitting Function
|
||||
|
||||
```math
|
||||
\text{RegexSplit}(s, P) = \{m_1, m_2, \ldots, m_k : m_i \in \text{Match}(s, P)\}
|
||||
```
|
||||
|
||||
where matches are found left-to-right, non-overlapping.
|
||||
|
||||
### 7.3 Regex Splitting Diagram
|
||||
|
||||
```mermaid
|
||||
graph TB
|
||||
A["Input Text<br/>'Hello world 123'"] --> B[Apply Regex Pattern]
|
||||
|
||||
B --> C1["Chunk 1: 'Hello'<br/>P_letters"]
|
||||
B --> C2["Chunk 2: ' world'<br/>P_whitespace + P_letters"]
|
||||
B --> C3["Chunk 3: ' 123'<br/>P_whitespace + P_numbers"]
|
||||
|
||||
C1 --> D["Chunks List<br/>['Hello', ' world', ' 123']"]
|
||||
C2 --> D
|
||||
C3 --> D
|
||||
|
||||
style A fill:#e1f5ff
|
||||
style D fill:#e1ffe1
|
||||
style B fill:#fff4e1
|
||||
```
|
||||
|
||||
**Example:**
|
||||
|
||||
```math
|
||||
\text{RegexSplit}("Hello world 123", P) = ["Hello", " world", " 123"]
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 8. Special Tokens
|
||||
|
||||
### 8.1 Special Token Set
|
||||
|
||||
**Special Tokens:**
|
||||
|
||||
```math
|
||||
V_{\text{special}} = \{\text{<pad>}, \text{<unk>}, \text{<bos>}, \text{<eos>}\}
|
||||
```
|
||||
|
||||
**Token IDs:**
|
||||
|
||||
```math
|
||||
\text{id}(\text{<pad>}) = 0, \quad \text{id}(\text{<unk>}) = 1, \quad \text{id}(\text{<bos>}) = 2, \quad \text{id}(\text{<eos>}) = 3
|
||||
```
|
||||
|
||||
### 8.2 Special Token Functions
|
||||
|
||||
**Padding:**
|
||||
|
||||
```math
|
||||
\text{pad}(\mathbf{t}, \text{max\_length}) = \mathbf{t} \oplus [\text{<pad>}]^{\max(\text{max\_length} - |\mathbf{t}|, 0)}
|
||||
```
|
||||
|
||||
**Unknown Token:**
|
||||
|
||||
```math
|
||||
\text{encode}(s) = \begin{cases}
|
||||
[\text{id}(c)] & \text{if } c \in V \\
|
||||
[\text{<unk>}] & \text{if } c \notin V
|
||||
\end{cases}
|
||||
```
|
||||
|
||||
**EOS Handling:**
|
||||
|
||||
```math
|
||||
\text{decode}(\mathbf{t}) = \text{decode}(\mathbf{t}[:i]) \text{ where } t_i = \text{<eos>}
|
||||
```
|
||||
|
||||
### 8.3 Special Token Flowchart
|
||||
|
||||
```mermaid
|
||||
graph TB
|
||||
subgraph "Special Tokens"
|
||||
A1["<pad> → 0<br/>Padding"]
|
||||
A2["<unk> → 1<br/>Unknown"]
|
||||
A3["<bos> → 2<br/>Beginning"]
|
||||
A4["<eos> → 3<br/>End"]
|
||||
end
|
||||
|
||||
subgraph "Usage"
|
||||
B1["Padding:<br/>[1,2,3] → [1,2,3,0,0]"]
|
||||
B2["Unknown:<br/>unknown_char → 1"]
|
||||
B3["EOS Stop:<br/>[1,2,3] → stop at 3"]
|
||||
end
|
||||
|
||||
A1 --> B1
|
||||
A2 --> B2
|
||||
A4 --> B3
|
||||
|
||||
style A1 fill:#e1f5ff
|
||||
style A2 fill:#ffe1f5
|
||||
style A3 fill:#e1ffe1
|
||||
style A4 fill:#fff4e1
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 9. Complete Tokenization Pipeline
|
||||
|
||||
### 9.1 Full Pipeline
|
||||
|
||||
**Complete Tokenization Process:**
|
||||
|
||||
```mermaid
|
||||
graph TB
|
||||
Start[Input Text] --> Split[Regex Split]
|
||||
Split --> UTF8[UTF-8 Encode]
|
||||
UTF8 --> BPE[BPE Merge]
|
||||
BPE --> Special{Special Tokens?}
|
||||
Special -->|Yes| Handle[Handle Special]
|
||||
Special -->|No| Combine[Combine Tokens]
|
||||
Handle --> Combine
|
||||
Combine --> Output[Token IDs]
|
||||
|
||||
Output --> Embed[Embedding Layer]
|
||||
|
||||
style Start fill:#e1f5ff
|
||||
style Output fill:#e1ffe1
|
||||
style Embed fill:#ffe1f5
|
||||
```
|
||||
|
||||
### 9.2 Mathematical Formulation
|
||||
|
||||
**Complete Encoding:**
|
||||
|
||||
```math
|
||||
\text{Tokenize}(s) = \text{HandleSpecial}\left(\bigoplus_{c \in \text{RegexSplit}(s)} \text{BPE}(\text{UTF-8}(c))\right)
|
||||
```
|
||||
|
||||
**Complete Decoding:**
|
||||
|
||||
```math
|
||||
\text{Detokenize}(\mathbf{t}) = \text{UTF-8}^{-1}\left(\bigoplus_{i=1}^n \text{vocab}(\text{RemoveSpecial}(t_i))\right)
|
||||
```
|
||||
|
||||
### 9.3 Pipeline Example
|
||||
|
||||
**Input:** "Hello world!"
|
||||
|
||||
**Step 1: Regex Split**
|
||||
|
||||
```math
|
||||
\text{chunks} = ["Hello", " world", "!"]
|
||||
```
|
||||
|
||||
**Step 2: UTF-8 Encode**
|
||||
|
||||
```math
|
||||
\text{bytes} = [[72,101,108,108,111], [32,119,111,114,108,100], [33]]
|
||||
```
|
||||
|
||||
**Step 3: BPE Merge**
|
||||
|
||||
```math
|
||||
\text{tokens} = [[15496], [1917], [0]]
|
||||
```
|
||||
|
||||
**Step 4: Combine**
|
||||
|
||||
```math
|
||||
\text{final} = [15496, 1917, 0]
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 10. Tokenization Challenges and Solutions
|
||||
|
||||
### 10.1 Challenge: Long Tokens
|
||||
|
||||
**Problem:** Some tokens become very long (e.g., "defaultstyle" as single token).
|
||||
|
||||
**Mathematical Impact:**
|
||||
|
||||
```math
|
||||
\text{LongToken}(s) = \begin{cases}
|
||||
1 \text{ token} & \text{if } |s| > 10 \text{ chars} \\
|
||||
\text{multiple tokens} & \text{otherwise}
|
||||
\end{cases}
|
||||
```
|
||||
|
||||
**Solution:** Better regex splitting prevents over-merging.
|
||||
|
||||
### 10.2 Challenge: Number Tokenization
|
||||
|
||||
**Problem:** Numbers tokenized arbitrarily (sometimes 1 token, sometimes 2-3).
|
||||
|
||||
**Solution:** Limit number merging to 1-3 digits:
|
||||
|
||||
```math
|
||||
P_{\text{numbers}} = \text{\p{N}{1,3}}
|
||||
```
|
||||
|
||||
**Impact:**
|
||||
|
||||
```math
|
||||
\text{TokenCount}(n) = \begin{cases}
|
||||
1 & \text{if } n < 1000 \\
|
||||
2-3 & \text{if } n \geq 1000
|
||||
\end{cases}
|
||||
```
|
||||
|
||||
### 10.3 Challenge: Python Code Efficiency
|
||||
|
||||
**Problem:** Each space is separate token (GPT-2 issue).
|
||||
|
||||
**Solution:** Merge multiple spaces:
|
||||
|
||||
```math
|
||||
P_{\text{whitespace}} = \text{\s+} \quad \text{(matches multiple spaces)}
|
||||
```
|
||||
|
||||
**Efficiency Gain:**
|
||||
|
||||
```math
|
||||
\text{TokensBefore} = |\text{spaces}| \quad \text{(one token per space)}
|
||||
```
|
||||
|
||||
```math
|
||||
\text{TokensAfter} = \lceil |\text{spaces}| / 4 \rceil \quad \text{(grouped spaces)}
|
||||
```
|
||||
|
||||
### 10.4 Challenge: Trailing Whitespace
|
||||
|
||||
**Problem:** Trailing spaces cause poor tokenization.
|
||||
|
||||
**Detection:**
|
||||
|
||||
```math
|
||||
\text{HasTrailingSpace}(s) = \begin{cases}
|
||||
\text{True} & \text{if } s[-1] = ' ' \\
|
||||
\text{False} & \text{otherwise}
|
||||
\end{cases}
|
||||
```
|
||||
|
||||
**Warning:**
|
||||
|
||||
```math
|
||||
\text{Tokenize}(s) = \begin{cases}
|
||||
\text{encode}(s) + \text{warning} & \text{if } \text{HasTrailingSpace}(s) \\
|
||||
\text{encode}(s) & \text{otherwise}
|
||||
\end{cases}
|
||||
```
|
||||
|
||||
### 10.5 Challenge: Multilingual Support
|
||||
|
||||
**Problem:** Non-English languages tokenize inefficiently.
|
||||
|
||||
**Solution:** UTF-8 byte-level encoding handles all languages:
|
||||
|
||||
```math
|
||||
\text{Tokenize}(s) = \text{BPE}(\text{UTF-8}(s)) \quad \forall s \in \Sigma^*
|
||||
```
|
||||
|
||||
**Efficiency:**
|
||||
|
||||
```math
|
||||
\text{TokenRatio} = \frac{|\text{Tokenize}(s_{\text{non-english}})|}{|\text{Tokenize}(s_{\text{english}})|} \approx 1.5-2.0
|
||||
```
|
||||
|
||||
### 10.6 Challenge: Untrained Tokens
|
||||
|
||||
**Problem:** Some tokens never appear in training data.
|
||||
|
||||
**Solution:** Fallback handling:
|
||||
|
||||
```math
|
||||
\text{Decode}(t) = \begin{cases}
|
||||
\text{vocab}(t) & \text{if } t \in V \\
|
||||
\text{<unk>} & \text{if } t \notin V \\
|
||||
\text{fallback\_bytes} & \text{if } t < 256
|
||||
\end{cases}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Summary
|
||||
|
||||
### Key Formulas
|
||||
|
||||
**Encoding:**
|
||||
|
||||
```math
|
||||
\text{encode}(s) = \text{HandleSpecial}\left(\bigoplus_{c \in \text{RegexSplit}(s)} \text{BPE}(\text{UTF-8}(c))\right)
|
||||
```
|
||||
|
||||
**Decoding:**
|
||||
|
||||
```math
|
||||
\text{decode}(\mathbf{t}) = \text{UTF-8}^{-1}\left(\bigoplus_{i=1}^n \text{vocab}(t_i)\right)
|
||||
```
|
||||
|
||||
**BPE Merge:**
|
||||
|
||||
```math
|
||||
\text{tokens}^{(k)} = \text{Merge}(\text{tokens}^{(k-1)}, \arg\max_{(i,j)} \text{stats}^{(k)}(i,j), 256+k-1)
|
||||
```
|
||||
|
||||
**Vocabulary Size:**
|
||||
|
||||
```math
|
||||
|V| = 256 + K + |V_{\text{special}}|
|
||||
```
|
||||
|
||||
### Key Takeaways
|
||||
|
||||
1. **UTF-8 Encoding**: Handles all Unicode characters consistently
|
||||
2. **BPE Algorithm**: Creates efficient vocabulary through iterative merging
|
||||
3. **Regex Splitting**: Prevents over-merging with pattern-based chunking
|
||||
4. **Special Tokens**: Control flow and handle edge cases
|
||||
5. **Byte-Level**: Works at byte level for universal language support
|
||||
165
docs/TOKENIZER_IMPROVEMENTS.md
Normal file
165
docs/TOKENIZER_IMPROVEMENTS.md
Normal file
@@ -0,0 +1,165 @@
|
||||
# Tokenizer Improvements
|
||||
|
||||
## Overview
|
||||
|
||||
Based on the tokenization challenges discussed in the transcript, we've implemented an improved BPE (Byte Pair Encoding) tokenizer that addresses common issues with tokenization in large language models.
|
||||
|
||||
## Key Improvements
|
||||
|
||||
### 1. UTF-8 Byte-Level Encoding
|
||||
- **What**: Tokenizer works at the byte level (0-255) instead of character level
|
||||
- **Why**: Handles all Unicode characters consistently, regardless of language
|
||||
- **Benefit**: Better support for non-English languages and special characters
|
||||
|
||||
### 2. GPT-4 Style Regex Pattern
|
||||
- **What**: Improved regex pattern for splitting text into chunks
|
||||
- **Improvements**:
|
||||
- Case-insensitive matching for contractions (`'s`, `'t`, `'ll`, etc.)
|
||||
- Better whitespace handling (groups multiple spaces for Python code)
|
||||
- Limits number merging to 1-3 digits (prevents overly long number tokens)
|
||||
- **Benefit**: More consistent tokenization, especially for code and numbers
|
||||
|
||||
### 3. BPE Training Algorithm
|
||||
- **What**: Full Byte Pair Encoding implementation
|
||||
- **Features**:
|
||||
- `get_stats()`: Counts consecutive token pairs
|
||||
- `merge()`: Replaces frequent pairs with single tokens
|
||||
- Iterative training process
|
||||
- **Benefit**: Creates efficient vocabulary that compresses common sequences
|
||||
|
||||
### 4. Special Token Handling
|
||||
- **What**: Proper handling of special tokens (`<pad>`, `<unk>`, `<bos>`, `<eos>`)
|
||||
- **Features**:
|
||||
- Special tokens are excluded from BPE merging
|
||||
- EOS token stops decoding
|
||||
- Configurable special token set
|
||||
- **Benefit**: Clean separation between data tokens and control tokens
|
||||
|
||||
### 5. Trailing Whitespace Detection
|
||||
- **What**: Warns when text ends with trailing whitespace
|
||||
- **Why**: Trailing spaces can cause poor tokenization (as seen in GPT-3.5 playground warnings)
|
||||
- **Benefit**: Helps developers avoid tokenization issues
|
||||
|
||||
### 6. Better Python Code Handling
|
||||
- **What**: Improved whitespace merging for indentation
|
||||
- **Why**: GPT-2 had issues with Python code because each space was a separate token
|
||||
- **Benefit**: More efficient tokenization of Python code (fewer tokens per file)
|
||||
|
||||
### 7. Number Tokenization Limits
|
||||
- **What**: Limits number merging to 1-3 digits
|
||||
- **Why**: Prevents creating tokens for very long number sequences
|
||||
- **Benefit**: Better arithmetic performance (numbers are more consistently tokenized)
|
||||
|
||||
## Usage
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```python
|
||||
from data.example import SimpleTokenizer
|
||||
|
||||
# Create tokenizer (uses BPE by default)
|
||||
tokenizer = SimpleTokenizer(use_bpe=True, vocab_size=50257)
|
||||
|
||||
# Encode text
|
||||
tokens = tokenizer.encode("Hello world!")
|
||||
print(tokens) # [15496, 1917, 0]
|
||||
|
||||
# Decode tokens
|
||||
text = tokenizer.decode(tokens)
|
||||
print(text) # "Hello world!"
|
||||
```
|
||||
|
||||
### Training a Custom Tokenizer
|
||||
|
||||
```python
|
||||
from data.example import BPETokenizer
|
||||
|
||||
# Create tokenizer
|
||||
tokenizer = BPETokenizer(vocab_size=50257)
|
||||
|
||||
# Train on your corpus
|
||||
texts = [
|
||||
"Your training text here...",
|
||||
"More training text...",
|
||||
]
|
||||
|
||||
tokenizer.train(texts, num_merges=50000, verbose=True)
|
||||
|
||||
# Save trained tokenizer
|
||||
tokenizer.save("merges.json", "vocab.json")
|
||||
```
|
||||
|
||||
### Loading a Pre-trained Tokenizer
|
||||
|
||||
```python
|
||||
from data.example import BPETokenizer
|
||||
|
||||
# Load saved tokenizer
|
||||
tokenizer = BPETokenizer()
|
||||
tokenizer.load("merges.json", "vocab.json")
|
||||
|
||||
# Use it
|
||||
tokens = tokenizer.encode("Hello world!")
|
||||
```
|
||||
|
||||
## Addressing Common Issues
|
||||
|
||||
### Issue: "Can't spell words well"
|
||||
- **Cause**: Long tokens (like "defaultstyle" as single token)
|
||||
- **Fix**: Better regex splitting prevents over-merging
|
||||
|
||||
### Issue: "Bad at arithmetic"
|
||||
- **Cause**: Arbitrary number tokenization (sometimes 1 token, sometimes 2-3)
|
||||
- **Fix**: Limits number merging to 1-3 digits for consistency
|
||||
|
||||
### Issue: "Python code inefficient"
|
||||
- **Cause**: Each space is separate token (GPT-2 issue)
|
||||
- **Fix**: Multiple spaces merge into single tokens
|
||||
|
||||
### Issue: "Non-English languages worse"
|
||||
- **Cause**: Tokenizer trained primarily on English
|
||||
- **Fix**: UTF-8 byte-level encoding handles all languages consistently
|
||||
|
||||
### Issue: "Trailing whitespace warning"
|
||||
- **Cause**: Models see very few examples of trailing spaces
|
||||
- **Fix**: Warning helps developers detect and fix the issue
|
||||
|
||||
### Issue: "Solid gold Magikarp" (untrained tokens)
|
||||
- **Cause**: Tokenizer creates tokens for strings not in training data
|
||||
- **Fix**: Proper validation and fallback handling for unknown tokens
|
||||
|
||||
## Backward Compatibility
|
||||
|
||||
The `SimpleTokenizer` class maintains backward compatibility:
|
||||
- If `use_bpe=False`, uses character-level tokenization (old behavior)
|
||||
- If `use_bpe=True` (default), uses new BPE tokenizer
|
||||
- All existing code continues to work without changes
|
||||
|
||||
## Technical Details
|
||||
|
||||
### BPE Algorithm
|
||||
1. Start with byte-level vocabulary (256 tokens)
|
||||
2. Count consecutive token pairs in training data
|
||||
3. Find most frequent pair
|
||||
4. Merge pair into new token
|
||||
5. Repeat until target vocabulary size reached
|
||||
|
||||
### Encoding Process
|
||||
1. Split text using regex pattern
|
||||
2. Convert each chunk to UTF-8 bytes
|
||||
3. Apply BPE merges (greedy, left-to-right)
|
||||
4. Return token IDs
|
||||
|
||||
### Decoding Process
|
||||
1. Look up token IDs in vocabulary
|
||||
2. Convert bytes back to UTF-8
|
||||
3. Handle special tokens (EOS stops decoding)
|
||||
4. Return decoded text
|
||||
|
||||
## References
|
||||
|
||||
Based on improvements discussed in:
|
||||
- GPT-2 paper tokenization section
|
||||
- GPT-4 tokenizer improvements
|
||||
- Common tokenization challenges and solutions
|
||||
|
||||
1329
docs/TRAINING_EXPLAINED.md
Normal file
1329
docs/TRAINING_EXPLAINED.md
Normal file
File diff suppressed because it is too large
Load Diff
BIN
docs/images/loss_by_epoch.png
Normal file
BIN
docs/images/loss_by_epoch.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 26 KiB |
BIN
docs/images/training_curve.png
Normal file
BIN
docs/images/training_curve.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 73 KiB |
Reference in New Issue
Block a user