commit 3d2da94ce23c17fc73261b824d00a089cf9de185 Author: Carlos Gutierrez Date: Thu Nov 6 22:07:41 2025 -0500 Initial commit: SheepOp LLM - Transformer-based language model implementation - Complete transformer implementation from scratch - Training pipeline with gradient accumulation and mixed precision - Optimized inference with KV caching - Multi-format data processing (PDFs, images, code, text) - Comprehensive documentation - Apache 2.0 license - Example training plots included in docs/images/ diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c01b932 --- /dev/null +++ b/.gitignore @@ -0,0 +1,209 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +Pipfile.lock + +# poetry +poetry.lock + +# pdm +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582 +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +.idea/ + +# VS Code +.vscode/ + +# Custom entries +.cursor +papers/ +# Ignore data files but keep the data/ directory structure +data/*.txt +data/*.json +data/*.csv +data/*.db +data/*.sqlite +data/*.sqlite3 +# But NOT data/__init__.py (needed for Python module) +!data/__init__.py +# Ignore storage symlinks +data_storage +checkpoints_storage + +# OS-specific +.DS_Store +.DS_Store? +._* +.Spotlight-V100 +.Trashes +ehthumbs.db +Thumbs.db + +# Checkpoints (if you don't want to track them) +checkpoints/ +checkpoints_test/ + +# Training artifacts (discovered knowledge) +# Ignore all images except those in docs/images +*.png +*.jpg +*.jpeg +*.svg +!docs/images/ +!docs/images/**/*.png +!docs/images/**/*.jpg +!docs/images/**/*.jpeg + +# Training outputs (exclude from root, but allow in docs/images) +/training_curve.png +/loss_by_epoch.png +training_logs/ +*.log + +# Model checkpoints and weights +*.pt +*.pth +*.ckpt +*.safetensors + +# Training metrics +metrics.json +training_metrics.json +wandb/ +tensorboard_logs/ + +# Data (already covered but ensure) +data/ +data_storage/ +*.db +*.sqlite +*.sqlite3 diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..12e537f --- /dev/null +++ b/LICENSE @@ -0,0 +1,17 @@ +Apache License +Version 2.0, January 2004 +http://www.apache.org/licenses/ + +Copyright (c) 2024 Carlos Gutierrez + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 0000000..22ef520 --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,202 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (which shall not include Communications that are clearly marked or + otherwise designated in writing by the copyright owner as "Not a Work"). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied. See the License for the specific language governing + permissions and limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000..99217bd --- /dev/null +++ b/README.md @@ -0,0 +1,368 @@ +# SheepOp LLM šŸ‘āž”ļøšŸ¤– + +**Author:** Carlos Gutierrez +**Email:** carlos.gutierrez@carg.dev +**License:** Apache 2.0 + +A modern language model implementation from scratch, incorporating insights from recent research papers. + +--- + +## Purpose of the Project + +SheepOp LLM is a comprehensive transformer-based language model implementation designed for: + +- **Research & Education**: Understanding how large language models work from the ground up +- **Custom Training**: Training models on domain-specific data (PDFs, code, text files) +- **Production Deployment**: Optimized inference with KV caching and efficient attention mechanisms +- **Multi-Format Data Processing**: Support for various data types including PDFs, images (OCR), code files, and text + +The project provides a complete toolkit for building, training, and deploying transformer language models with modern best practices. + +--- + +## Documentation Index + +All detailed documentation is available in the [`docs/`](docs/) folder: + +### Core Concepts + +- **[Complete Guide](docs/COMPLETE_GUIDE.md)** - Full project documentation with mathematical foundations, architecture, and usage +- **[Architecture](docs/ARCHITECTURE.md)** - System architecture and design patterns +- **[Mathematics](docs/MATHEMATICS.md)** - Complete mathematical derivations for all components + +### Component Explanations + +- **[Embeddings](docs/EMBEDDINGS_EXPLAINED.md)** - What are embeddings and how they work +- **[Attention](docs/ATTENTION_EXPLAINED.md)** - Attention mechanisms explained step-by-step +- **[Feed-Forward](docs/FEED_FORWARD_EXPLAINED.md)** - Feed-forward networks explained +- **[Normalization](docs/NORMALIZATION_EXPLAINED.md)** - Layer normalization explained +- **[Neural Networks](docs/NEURAL_NETWORK_EXPLAINED.md)** - Neural networks, neurons, and weights explained + +### Training & Optimization + +- **[Training](docs/TRAINING_EXPLAINED.md)** - What is training, why we need data, why more data is better, and how to interpret training metrics +- **[Optimization](docs/OPTIMIZATION_EXPLAINED.md)** - Optimizers (AdamW, gradient descent) explained +- **[Scheduling](docs/SCHEDULING_EXPLAINED.md)** - Learning rate scheduling explained +- **[Generation](docs/GENERATION_EXPLAINED.md)** - Text generation and sampling strategies + +### Data & Processing + +- **[Data Processing](docs/DATA_PROCESSING_EXPLAINED.md)** - How data processing works step-by-step +- **[Multi-Format Data Guide](docs/MULTI_FORMAT_DATA_GUIDE.md)** - Working with PDFs, images, code files +- **[Data Guide](docs/DATA_GUIDE.md)** - General data handling guide +- **[Database Extraction Guide](docs/DATABASE_EXTRACTION_GUIDE.md)** - Extracting data from databases +- **[Repository Download Guide](docs/REPOSITORY_DOWNLOAD_GUIDE.md)** - Automatically downloading GitHub repositories for code training + +### Advanced Topics + +- **[Control System Model](docs/CONTROL_SYSTEM_MODEL.md)** - Mathematical control system formulation +- **[Optimizations](docs/OPTIMIZATIONS.md)** - Performance optimizations +- **[Retraining Guide](docs/RETRAINING_GUIDE.md)** - How to retrain models + +--- + +## Common Questions + +### Getting Started + +**Q: How do I get started with this project?** +**A:** See [Complete Guide](docs/COMPLETE_GUIDE.md) - Quick Start section + +**Q: What do I need to install?** +**A:** See [Complete Guide](docs/COMPLETE_GUIDE.md) - Installation section + +**Q: How do I train my first model?** +**A:** See [Complete Guide](docs/COMPLETE_GUIDE.md) - Usage section + +### Understanding Concepts + +**Q: What are embeddings?** +**A:** See [Embeddings Explained](docs/EMBEDDINGS_EXPLAINED.md) + +**Q: How does attention work?** +**A:** See [Attention Explained](docs/ATTENTION_EXPLAINED.md) + +**Q: What is a feed-forward network?** +**A:** See [Feed-Forward Explained](docs/FEED_FORWARD_EXPLAINED.md) + +**Q: Why do we need normalization?** +**A:** See [Normalization Explained](docs/NORMALIZATION_EXPLAINED.md) + +**Q: How do neural networks work?** +**A:** See [Neural Network Explained](docs/NEURAL_NETWORK_EXPLAINED.md) + +**Q: What is a neuron and what are weights?** +**A:** See [Neural Network Explained](docs/NEURAL_NETWORK_EXPLAINED.md) + +### Training Questions + +**Q: What is training and why do we need it?** +**A:** See [Training Explained](docs/TRAINING_EXPLAINED.md) + +**Q: Why do we need data for training?** +**A:** See [Training Explained](docs/TRAINING_EXPLAINED.md) - Why Do We Need Data section + +**Q: Why is more data better?** +**A:** See [Training Explained](docs/TRAINING_EXPLAINED.md) - Why More Data is Better section + +**Q: How does the optimizer work?** +**A:** See [Optimization Explained](docs/OPTIMIZATION_EXPLAINED.md) + +**Q: What is learning rate scheduling?** +**A:** See [Scheduling Explained](docs/SCHEDULING_EXPLAINED.md) + +### Data Questions + +**Q: How does data processing work?** +**A:** See [Data Processing Explained](docs/DATA_PROCESSING_EXPLAINED.md) + +**Q: Can I train on PDFs?** +**A:** See [Multi-Format Data Guide](docs/MULTI_FORMAT_DATA_GUIDE.md) + +**Q: Can I train on images?** +**A:** See [Multi-Format Data Guide](docs/MULTI_FORMAT_DATA_GUIDE.md) + +**Q: How do I process different file types?** +**A:** See [Data Processing Explained](docs/DATA_PROCESSING_EXPLAINED.md) + +**Q: How do I download code repositories automatically?** +**A:** See [Repository Download Guide](docs/REPOSITORY_DOWNLOAD_GUIDE.md) + +### Generation Questions + +**Q: How does text generation work?** +**A:** See [Generation Explained](docs/GENERATION_EXPLAINED.md) + +**Q: What is temperature in generation?** +**A:** See [Generation Explained](docs/GENERATION_EXPLAINED.md) - Temperature section + +**Q: What is top-k and top-p sampling?** +**A:** See [Generation Explained](docs/GENERATION_EXPLAINED.md) - Top-k and Top-p sections + +### Mathematical Questions + +**Q: What are the mathematical foundations?** +**A:** See [Mathematics](docs/MATHEMATICS.md) or [Complete Guide](docs/COMPLETE_GUIDE.md) - Mathematical Foundations section + +**Q: How do I understand the complete mathematical model?** +**A:** See [Mathematics](docs/MATHEMATICS.md) for step-by-step derivations + +**Q: Is there a control system perspective?** +**A:** See [Control System Model](docs/CONTROL_SYSTEM_MODEL.md) + +### Architecture Questions + +**Q: How is the architecture designed?** +**A:** See [Architecture](docs/ARCHITECTURE.md) + +**Q: What is the complete system flow?** +**A:** See [Complete Guide](docs/COMPLETE_GUIDE.md) - Architecture Explained section + +### Advanced Questions + +**Q: How do I optimize inference?** +**A:** See [Optimizations](docs/OPTIMIZATIONS.md) + +**Q: How do I retrain a model?** +**A:** See [Retraining Guide](docs/RETRAINING_GUIDE.md) + +**Q: How do I extract data from databases?** +**A:** See [Database Extraction Guide](docs/DATABASE_EXTRACTION_GUIDE.md) + +**Q: How do I download GitHub repositories for code training?** +**A:** See [Repository Download Guide](docs/REPOSITORY_DOWNLOAD_GUIDE.md) + +--- + +## Glossary + +### A + +**AdamW** - Advanced optimizer combining adaptive learning rates with weight decay. See [Optimization Explained](docs/OPTIMIZATION_EXPLAINED.md) + +**Attention** - Mechanism that determines how much each word should consider other words. See [Attention Explained](docs/ATTENTION_EXPLAINED.md) + +**Autoregressive** - Generation method where the model uses its own previous outputs as inputs. See [Generation Explained](docs/GENERATION_EXPLAINED.md) + +### B + +**Batch** - Small group of examples processed together during training. See [Training Explained](docs/TRAINING_EXPLAINED.md) + +**Bias** - Constant added to weighted sum in neural networks. See [Neural Network Explained](docs/NEURAL_NETWORK_EXPLAINED.md) + +**Backpropagation** - Algorithm for computing gradients through the network. See [Training Explained](docs/TRAINING_EXPLAINED.md) + +### C + +**Causal Masking** - Prevents tokens from attending to future tokens. See [Complete Guide](docs/COMPLETE_GUIDE.md) + +**Cosine Annealing** - Learning rate schedule that follows a cosine curve. See [Scheduling Explained](docs/SCHEDULING_EXPLAINED.md) + +**Cross-Entropy Loss** - Loss function for classification tasks. See [Mathematics](docs/MATHEMATICS.md) + +### D + +**Data Processing** - Transformation of raw files into training-ready text. See [Data Processing Explained](docs/DATA_PROCESSING_EXPLAINED.md) + +**Dropout** - Regularization technique that randomly sets activations to zero. See [Complete Guide](docs/COMPLETE_GUIDE.md) + +**Decoder** - Part of transformer that generates output. See [Architecture](docs/ARCHITECTURE.md) + +### E + +**Embedding** - Numerical representation of words/tokens. See [Embeddings Explained](docs/EMBEDDINGS_EXPLAINED.md) + +**Epoch** - One complete pass through the training data. See [Training Explained](docs/TRAINING_EXPLAINED.md) + +**Evaluation** - Process of measuring model performance. See [Training Explained](docs/TRAINING_EXPLAINED.md) + +### F + +**Feed-Forward Network (FFN)** - Two-layer neural network that transforms features. See [Feed-Forward Explained](docs/FEED_FORWARD_EXPLAINED.md) + +**Forward Pass** - Computing predictions from inputs through the model. See [Neural Network Explained](docs/NEURAL_NETWORK_EXPLAINED.md) + +### G + +**GELU** - Gaussian Error Linear Unit activation function. See [Feed-Forward Explained](docs/FEED_FORWARD_EXPLAINED.md) + +**Generation** - Process of creating new text from a trained model. See [Generation Explained](docs/GENERATION_EXPLAINED.md) + +**Gradient** - Derivative of loss with respect to parameters. See [Optimization Explained](docs/OPTIMIZATION_EXPLAINED.md) + +**Gradient Clipping** - Technique to prevent exploding gradients. See [Complete Guide](docs/COMPLETE_GUIDE.md) + +**Gradient Descent** - Basic optimization algorithm. See [Optimization Explained](docs/OPTIMIZATION_EXPLAINED.md) + +### H + +**Hidden State** - Intermediate representation in the model. See [Architecture](docs/ARCHITECTURE.md) + +### L + +**Layer Normalization** - Normalization technique applied per layer. See [Normalization Explained](docs/NORMALIZATION_EXPLAINED.md) + +**Learning Rate** - Step size for weight updates. See [Optimization Explained](docs/OPTIMIZATION_EXPLAINED.md) + +**Logits** - Raw scores before applying softmax. See [Generation Explained](docs/GENERATION_EXPLAINED.md) + +**Loss** - Measure of prediction error. See [Training Explained](docs/TRAINING_EXPLAINED.md) + +### M + +**Multi-Head Attention** - Attention mechanism with multiple parallel heads. See [Attention Explained](docs/ATTENTION_EXPLAINED.md) + +**Momentum** - Technique to accelerate gradient descent. See [Optimization Explained](docs/OPTIMIZATION_EXPLAINED.md) + +### N + +**Neural Network** - Computational model inspired by biological neurons. See [Neural Network Explained](docs/NEURAL_NETWORK_EXPLAINED.md) + +**Neuron** - Basic processing unit in neural networks. See [Neural Network Explained](docs/NEURAL_NETWORK_EXPLAINED.md) + +**Normalization** - Technique to standardize activations. See [Normalization Explained](docs/NORMALIZATION_EXPLAINED.md) + +**Nucleus Sampling (Top-p)** - Sampling strategy keeping tokens with cumulative probability ≄ p. See [Generation Explained](docs/GENERATION_EXPLAINED.md) + +### O + +**Optimization** - Process of finding optimal weights. See [Optimization Explained](docs/OPTIMIZATION_EXPLAINED.md) + +**Optimizer** - Algorithm that updates model weights. See [Optimization Explained](docs/OPTIMIZATION_EXPLAINED.md) + +**Overfitting** - Model memorizes training data but doesn't generalize. See [Training Explained](docs/TRAINING_EXPLAINED.md) + +### P + +**Perplexity** - Measure of model uncertainty (exp(loss)). See [Mathematics](docs/MATHEMATICS.md) + +**Positional Encoding** - Adds position information to embeddings. See [Complete Guide](docs/COMPLETE_GUIDE.md) + +**Pre-norm** - Architecture where normalization comes before sublayers. See [Architecture](docs/ARCHITECTURE.md) + +**Probability Distribution** - Distribution over possible next tokens. See [Generation Explained](docs/GENERATION_EXPLAINED.md) + +### Q + +**Query (Q)** - One of three representations in attention (what am I looking for?). See [Attention Explained](docs/ATTENTION_EXPLAINED.md) + +### R + +**Residual Connection** - Skip connection that adds input to output. See [Architecture](docs/ARCHITECTURE.md) + +### S + +**Sampling** - Process of selecting a token from probability distribution. See [Generation Explained](docs/GENERATION_EXPLAINED.md) + +**Scheduling** - Adjusting learning rate during training. See [Scheduling Explained](docs/SCHEDULING_EXPLAINED.md) + +**Self-Attention** - Attention mechanism where queries, keys, and values come from same input. See [Attention Explained](docs/ATTENTION_EXPLAINED.md) + +**Softmax** - Function that converts logits to probabilities. See [Generation Explained](docs/GENERATION_EXPLAINED.md) + +### T + +**Temperature** - Parameter controlling randomness in sampling. See [Generation Explained](docs/GENERATION_EXPLAINED.md) + +**Token** - Basic unit of text (word or character). See [Neural Network Explained](docs/NEURAL_NETWORK_EXPLAINED.md) + +**Tokenization** - Process of converting text to tokens. See [Data Processing Explained](docs/DATA_PROCESSING_EXPLAINED.md) + +**Top-k Sampling** - Sampling strategy keeping only top k tokens. See [Generation Explained](docs/GENERATION_EXPLAINED.md) + +**Top-p Sampling** - Another name for nucleus sampling. See [Generation Explained](docs/GENERATION_EXPLAINED.md) + +**Transformer** - Neural network architecture based on attention. See [Architecture](docs/ARCHITECTURE.md) + +**Training** - Process of teaching model to make predictions. See [Training Explained](docs/TRAINING_EXPLAINED.md) + +### V + +**Value (V)** - One of three representations in attention (what information do I contain?). See [Attention Explained](docs/ATTENTION_EXPLAINED.md) + +**Vocabulary** - Set of all possible tokens. See [Embeddings Explained](docs/EMBEDDINGS_EXPLAINED.md) + +### W + +**Weight** - Parameter in neural network that controls connection strength. See [Neural Network Explained](docs/NEURAL_NETWORK_EXPLAINED.md) + +**Weight Decay** - Regularization technique that penalizes large weights. See [Optimization Explained](docs/OPTIMIZATION_EXPLAINED.md) + +**Weight Matrix** - Matrix containing all weights for a layer. See [Neural Network Explained](docs/NEURAL_NETWORK_EXPLAINED.md) + +--- + +## Quick Links + +- **Complete Documentation**: [docs/COMPLETE_GUIDE.md](docs/COMPLETE_GUIDE.md) +- **Mathematical Foundations**: [docs/MATHEMATICS.md](docs/MATHEMATICS.md) +- **System Architecture**: [docs/ARCHITECTURE.md](docs/ARCHITECTURE.md) +- **Control System Model**: [docs/CONTROL_SYSTEM_MODEL.md](docs/CONTROL_SYSTEM_MODEL.md) + +--- + +## License + +This project is licensed under the **Apache License 2.0**. + +See [LICENSE](LICENSE) or [LICENSE.txt](LICENSE.txt) for the full license text. + +**Summary:** +- āœ… Free to use, modify, and distribute +- āœ… Commercial use allowed +- āœ… Patent grant included +- āœ… Private use allowed +- āš ļø Must include license and copyright notice +- āš ļø Must state changes if modifying + +--- + +## Contact + +**Carlos Gutierrez** +Email: carlos.gutierrez@carg.dev + +--- + +*This README serves as an index to the comprehensive documentation available in the `docs/` folder.* diff --git a/benchmark_batch.py b/benchmark_batch.py new file mode 100755 index 0000000..ad8848d --- /dev/null +++ b/benchmark_batch.py @@ -0,0 +1,178 @@ +#!/usr/bin/env python3 +""" +Batch benchmark script for running multiple prompts and creating trends. +Collects data across multiple prompts for research analysis. +""" +import subprocess +import argparse +import json +from pathlib import Path +import time +from typing import List + + +def run_benchmark( + checkpoint: str, + prompt: str, + max_length: int = 100, + temperature: float = 1.0, + device: str = 'cuda', + benchmark_dir: str = './inference_benchmarks', + extra_args: List[str] = None, +): + """Run a single benchmark.""" + cmd = [ + 'python', 'inference.py', + '--checkpoint', checkpoint, + '--prompt', prompt, + '--max-length', str(max_length), + '--temperature', str(temperature), + '--device', device, + '--benchmark', + '--benchmark-dir', benchmark_dir, + ] + + if extra_args: + cmd.extend(extra_args) + + print(f"\n{'='*70}") + print(f"Running benchmark: {prompt[:50]}...") + print(f"{'='*70}") + + result = subprocess.run(cmd, capture_output=True, text=True) + + if result.returncode != 0: + print(f"āŒ Error running benchmark:") + print(result.stderr) + return False + + print(result.stdout) + return True + + +def load_prompts_from_file(prompt_file: str) -> List[str]: + """Load prompts from a text file (one prompt per line).""" + with open(prompt_file, 'r', encoding='utf-8') as f: + prompts = [line.strip() for line in f if line.strip()] + return prompts + + +def main(): + parser = argparse.ArgumentParser( + description='Run multiple benchmarks with different prompts to create trends' + ) + parser.add_argument( + '--checkpoint', type=str, required=True, + help='Path to model checkpoint' + ) + parser.add_argument( + '--prompts', type=str, nargs='+', + help='List of prompts to benchmark' + ) + parser.add_argument( + '--prompt-file', type=str, + help='File containing prompts (one per line)' + ) + parser.add_argument( + '--max-length', type=int, default=100, + help='Maximum generation length' + ) + parser.add_argument( + '--temperature', type=float, default=1.0, + help='Sampling temperature' + ) + parser.add_argument( + '--device', type=str, default='cuda', + help='Device to use' + ) + parser.add_argument( + '--benchmark-dir', type=str, default='./inference_benchmarks', + help='Directory to save benchmark results' + ) + parser.add_argument( + '--delay', type=float, default=1.0, + help='Delay between benchmarks (seconds)' + ) + + args = parser.parse_args() + + # Collect prompts + prompts = [] + + if args.prompt_file: + prompts.extend(load_prompts_from_file(args.prompt_file)) + print(f"šŸ“ Loaded {len(prompts)} prompts from {args.prompt_file}") + + if args.prompts: + prompts.extend(args.prompts) + print(f"šŸ“ Added {len(args.prompts)} prompts from command line") + + if not prompts: + print("āŒ No prompts provided! Use --prompts or --prompt-file") + return + + print(f"\nāœ… Total prompts to benchmark: {len(prompts)}") + print(f"šŸ“Š Results will be saved to: {args.benchmark_dir}") + print(f"ā±ļø Delay between runs: {args.delay}s\n") + + # Run benchmarks + successful = 0 + failed = 0 + + for i, prompt in enumerate(prompts, 1): + print(f"\n[{i}/{len(prompts)}] Processing prompt...") + + success = run_benchmark( + checkpoint=args.checkpoint, + prompt=prompt, + max_length=args.max_length, + temperature=args.temperature, + device=args.device, + benchmark_dir=args.benchmark_dir, + ) + + if success: + successful += 1 + else: + failed += 1 + + # Delay between runs + if i < len(prompts): + time.sleep(args.delay) + + # Summary + print(f"\n{'='*70}") + print("BATCH BENCHMARK SUMMARY") + print(f"{'='*70}") + print(f"āœ… Successful: {successful}/{len(prompts)}") + print(f"āŒ Failed: {failed}/{len(prompts)}") + print(f"\nšŸ“Š All data saved to: {args.benchmark_dir}") + print(f" - JSON metrics: {args.benchmark_dir}/inference_metrics.json") + print(f" - CSV export: {args.benchmark_dir}/inference_metrics.csv") + print(f" - Comparison plots: {args.benchmark_dir}/optimization_comparison.png") + print(f" - Trend plot: {args.benchmark_dir}/performance_over_time.png") + + # Load and show summary + metrics_file = Path(args.benchmark_dir) / 'inference_metrics.json' + if metrics_file.exists(): + with open(metrics_file, 'r') as f: + metrics = json.load(f) + + runs = metrics.get('runs', []) + optimized_runs = [r for r in runs if r['optimized']] + non_optimized_runs = [r for r in runs if not r['optimized']] + + if optimized_runs and non_optimized_runs: + avg_optimized = sum(r['tokens_per_second'] for r in optimized_runs) / len(optimized_runs) + avg_non_optimized = sum(r['tokens_per_second'] for r in non_optimized_runs) / len(non_optimized_runs) + speedup = avg_optimized / avg_non_optimized if avg_non_optimized > 0 else 0 + + print(f"\nšŸ“ˆ OVERALL PERFORMANCE:") + print(f" Average Optimized: {avg_optimized:.2f} tokens/sec") + print(f" Average Non-Optimized: {avg_non_optimized:.2f} tokens/sec") + print(f" Overall Speedup: {speedup:.2f}x") + + +if __name__ == '__main__': + main() + diff --git a/checkpoints b/checkpoints new file mode 120000 index 0000000..c6c244f --- /dev/null +++ b/checkpoints @@ -0,0 +1 @@ +/mnt/storage/sheepOp/checkpoints \ No newline at end of file diff --git a/config.json b/config.json new file mode 100644 index 0000000..e4fa2b0 --- /dev/null +++ b/config.json @@ -0,0 +1,36 @@ +{ + "model": { + "vocab_size": 50257, + "d_model": 512, + "num_layers": 6, + "num_heads": 8, + "d_ff": 2048, + "max_seq_len": 512, + "dropout": 0.1, + "activation": "gelu", + "layer_norm_eps": 1e-5, + "bias": false, + "tie_weights": true + }, + "training": { + "batch_size": 8, + "max_epochs": 50, + "learning_rate": 1e-4, + "weight_decay": 0.01, + "warmup_steps": 1000, + "max_grad_norm": 1.0, + "gradient_accumulation_steps": 16, + "use_amp": true, + "save_dir": "./checkpoints", + "log_interval": 50, + "eval_interval": 500 + }, + "data": { + "data_dir": "./data", + "max_length": 512, + "stride": null, + "num_workers": 12 + }, + "device": "cuda", + "seed": 42 +} diff --git a/config.py b/config.py new file mode 100644 index 0000000..0f4b8a7 --- /dev/null +++ b/config.py @@ -0,0 +1,96 @@ +""" +Configuration management +""" +import json +from pathlib import Path +from dataclasses import dataclass, asdict +from typing import Optional + + +@dataclass +class ModelConfig: + """Model configuration.""" + vocab_size: int = 50257 + d_model: int = 512 + num_layers: int = 6 + num_heads: int = 8 + d_ff: int = 2048 + max_seq_len: int = 512 + dropout: float = 0.1 + activation: str = 'gelu' + layer_norm_eps: float = 1e-5 + bias: bool = False + tie_weights: bool = True + + +@dataclass +class TrainingConfig: + """Training configuration.""" + batch_size: int = 32 + max_epochs: int = 10 + learning_rate: float = 1e-4 + weight_decay: float = 0.01 + warmup_steps: int = 1000 + max_grad_norm: float = 1.0 + gradient_accumulation_steps: int = 1 + use_amp: bool = True + save_dir: str = './checkpoints' + log_interval: int = 100 + eval_interval: int = 1000 + + +@dataclass +class DataConfig: + """Data configuration.""" + data_dir: str = './data' + max_length: int = 512 + stride: Optional[int] = None + num_workers: int = 0 + + +@dataclass +class Config: + """Complete configuration.""" + model: ModelConfig + training: TrainingConfig + data: DataConfig + device: str = 'cuda' + seed: int = 42 + + @classmethod + def from_json(cls, config_path: str) -> 'Config': + """Load configuration from JSON file.""" + with open(config_path, 'r') as f: + config_dict = json.load(f) + + return cls( + model=ModelConfig(**config_dict.get('model', {})), + training=TrainingConfig(**config_dict.get('training', {})), + data=DataConfig(**config_dict.get('data', {})), + device=config_dict.get('device', 'cuda'), + seed=config_dict.get('seed', 42), + ) + + def to_json(self, config_path: str): + """Save configuration to JSON file.""" + config_dict = { + 'model': asdict(self.model), + 'training': asdict(self.training), + 'data': asdict(self.data), + 'device': self.device, + 'seed': self.seed, + } + + with open(config_path, 'w') as f: + json.dump(config_dict, f, indent=2) + + +def get_default_config() -> Config: + """Get default configuration.""" + return Config( + model=ModelConfig(), + training=TrainingConfig(), + data=DataConfig(), + ) + + diff --git a/config_cuda_8gb.json b/config_cuda_8gb.json new file mode 100644 index 0000000..166b19e --- /dev/null +++ b/config_cuda_8gb.json @@ -0,0 +1,37 @@ +{ + "model": { + "vocab_size": 50257, + "d_model": 512, + "num_layers": 6, + "num_heads": 8, + "d_ff": 2048, + "max_seq_len": 512, + "dropout": 0.1, + "activation": "gelu", + "layer_norm_eps": 1e-5, + "bias": false, + "tie_weights": true + }, + "training": { + "batch_size": 8, + "max_epochs": 50, + "learning_rate": 1e-4, + "weight_decay": 0.01, + "warmup_steps": 1000, + "max_grad_norm": 1.0, + "gradient_accumulation_steps": 16, + "use_amp": true, + "save_dir": "./checkpoints", + "log_interval": 50, + "eval_interval": 500 + }, + "data": { + "data_dir": "./data", + "max_length": 384, + "stride": null, + "num_workers": 0 + }, + "device": "cuda", + "seed": 42 +} + diff --git a/data b/data new file mode 120000 index 0000000..0394143 --- /dev/null +++ b/data @@ -0,0 +1 @@ +/mnt/storage/sheepOp/data \ No newline at end of file diff --git a/data.example/__init__.py b/data.example/__init__.py new file mode 100644 index 0000000..9d2df87 --- /dev/null +++ b/data.example/__init__.py @@ -0,0 +1,924 @@ +""" +Data loading and preprocessing utilities +Includes comprehensive data processor for multiple file types (PDFs, images, code, text, etc.) +""" +import torch +from torch.utils.data import Dataset, DataLoader +from typing import List, Dict, Optional, Iterator +import json +from pathlib import Path +import logging +from tqdm import tqdm +import hashlib +import pickle +import os +import sys + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Ensure logging output is unbuffered +for handler in logger.handlers: + handler.flush() +# Also ensure root logger handlers are unbuffered +for handler in logging.root.handlers: + handler.flush() + + +class TextDataset(Dataset): + """ + Dataset for text data. + """ + + def __init__( + self, + texts: List[str], + tokenizer, + max_length: int = 512, + stride: Optional[int] = None, + ): + """ + Args: + texts: List of text strings + tokenizer: Tokenizer instance + max_length: Maximum sequence length + stride: Stride for sliding window (if None, no overlap) + """ + self.texts = texts + self.tokenizer = tokenizer + self.max_length = max_length + self.stride = stride if stride is not None else max_length + + # Tokenize all texts + self.sequences = self._prepare_sequences() + + def _prepare_sequences(self) -> List[torch.Tensor]: + """Tokenize and chunk sequences.""" + sequences = [] + + for text in self.texts: + # Tokenize text + tokens = self.tokenizer.encode(text) + + # Chunk into sequences of max_length + for i in range(0, len(tokens), self.stride): + chunk = tokens[i:i + self.max_length] + + # Pad if necessary + if len(chunk) < self.max_length: + chunk = chunk + [self.tokenizer.pad_token_id] * (self.max_length - len(chunk)) + + sequences.append(torch.tensor(chunk, dtype=torch.long)) + + # Stop if we've covered the entire sequence + if i + self.max_length >= len(tokens): + break + + return sequences + + def __len__(self) -> int: + return len(self.sequences) + + def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: + sequence = self.sequences[idx] + + # Input is all tokens except the last one + input_ids = sequence[:-1] + # Labels are all tokens except the first one (shifted by 1) + labels = sequence[1:] + + return { + 'input_ids': input_ids, + 'labels': labels, + } + + +class SimpleTokenizer: + """ + Simple character-level tokenizer (for backward compatibility). + Uses BPE tokenizer by default if available, falls back to character-level. + """ + + def __init__( + self, + vocab_file: Optional[str] = None, + use_bpe: bool = True, + vocab_size: int = 50257, + ): + """ + Args: + vocab_file: Optional path to vocabulary file + use_bpe: Whether to use BPE tokenizer (default: True) + vocab_size: Vocabulary size for BPE tokenizer (default: 50257) + """ + self.use_bpe = use_bpe + + # Try to use BPE tokenizer if available + if use_bpe: + try: + from .bpe_tokenizer import BPETokenizer + self.bpe_tokenizer = BPETokenizer(vocab_size=vocab_size) + self._use_bpe = True + + # Map BPE tokenizer attributes + self.pad_token_id = self.bpe_tokenizer.pad_token_id + self.unk_token_id = self.bpe_tokenizer.unk_token_id + self.bos_token_id = self.bpe_tokenizer.bos_token_id + self.eos_token_id = self.bpe_tokenizer.eos_token_id + self.vocab_size = self.bpe_tokenizer.vocab_size + self.vocab = {i: self.bpe_tokenizer.vocab.get(i, bytes([i])).decode('utf-8', errors='replace') + for i in range(256)} # Limited vocab view + self.inv_vocab = {v: k for k, v in self.vocab.items()} + return + except ImportError: + logger.warning("BPE tokenizer not available, falling back to character-level") + self._use_bpe = False + + # Fallback to character-level tokenizer + self._use_bpe = False + + if vocab_file and Path(vocab_file).exists(): + with open(vocab_file, 'r') as f: + vocab = json.load(f) + self.vocab = vocab + self.inv_vocab = {v: k for k, v in vocab.items()} + else: + # Default: character-level vocabulary + self.vocab = { + '': 0, + '': 1, + '': 2, + '': 3, + } + # Add printable ASCII characters + for i in range(32, 127): + self.vocab[chr(i)] = len(self.vocab) + + self.inv_vocab = {v: k for k, v in self.vocab.items()} + + self.pad_token_id = self.vocab.get('', 0) + self.unk_token_id = self.vocab.get('', 1) + self.bos_token_id = self.vocab.get('', 2) + self.eos_token_id = self.vocab.get('', 3) + + self.vocab_size = len(self.vocab) + + def encode(self, text: str) -> List[int]: + """Encode text to token IDs.""" + if self._use_bpe: + return self.bpe_tokenizer.encode(text) + + # Character-level encoding + tokens = [] + for char in text: + tokens.append(self.vocab.get(char, self.unk_token_id)) + return tokens + + def decode(self, token_ids: List[int]) -> str: + """Decode token IDs to text.""" + if self._use_bpe: + return self.bpe_tokenizer.decode(token_ids) + + # Character-level decoding + chars = [] + for tid in token_ids: + if tid in self.inv_vocab: + char = self.inv_vocab[tid] + if char not in ['', '', '']: + chars.append(char) + return ''.join(chars) + + def save_vocab(self, vocab_file: str): + """Save vocabulary to file.""" + if self._use_bpe: + # Save BPE tokenizer + merges_file = str(vocab_file).replace('.json', '_merges.json') + self.bpe_tokenizer.save(merges_file, vocab_file) + else: + # Save character-level vocab + with open(vocab_file, 'w') as f: + json.dump(self.vocab, f, indent=2) + + def train(self, texts: List[str], num_merges: Optional[int] = None, verbose: bool = False): + """Train the tokenizer on texts (BPE only).""" + if self._use_bpe: + self.bpe_tokenizer.train(texts, num_merges=num_merges, verbose=verbose) + # Update vocab size + self.vocab_size = self.bpe_tokenizer.vocab_size + else: + logger.warning("Training not supported for character-level tokenizer") + + +def create_dataloader( + texts: List[str], + tokenizer, + batch_size: int = 32, + max_length: int = 512, + shuffle: bool = True, + num_workers: int = 0, +) -> DataLoader: + """ + Create a DataLoader for text data. + + Args: + texts: List of text strings + tokenizer: Tokenizer instance + batch_size: Batch size + max_length: Maximum sequence length + shuffle: Whether to shuffle data + num_workers: Number of data loading workers + + Returns: + DataLoader instance + """ + dataset = TextDataset( + texts=texts, + tokenizer=tokenizer, + max_length=max_length, + ) + + def collate_fn(batch): + """Collate function for batching.""" + input_ids = torch.stack([item['input_ids'] for item in batch]) + labels = torch.stack([item['labels'] for item in batch]) + + return { + 'input_ids': input_ids, + 'labels': labels, + } + + return DataLoader( + dataset, + batch_size=batch_size, + shuffle=shuffle, + num_workers=num_workers, + collate_fn=collate_fn, + pin_memory=torch.cuda.is_available(), + ) + + +# ============================================================================ +# Data Processor for Multiple File Types +# ============================================================================ + +class DataProcessor: + """ + Process various file types and extract text for training. + Supports: PDFs, images (OCR), code files, text files, and more. + """ + + # Supported file extensions + TEXT_EXTENSIONS = {'.txt', '.md', '.rst', '.log', '.csv', '.json', '.jsonl', '.xml', '.html', '.htm'} + CODE_EXTENSIONS = { + '.py', '.js', '.ts', '.jsx', '.tsx', '.java', '.cpp', '.c', '.h', '.hpp', + '.cs', '.go', '.rs', '.rb', '.php', '.swift', '.kt', '.scala', '.r', + '.sql', '.sh', '.bash', '.zsh', '.fish', '.yaml', '.yml', '.toml', + '.ini', '.cfg', '.conf', '.vue', '.svelte', '.dart', '.lua', '.pl', + '.hs', '.ml', '.mli', '.elm', '.ex', '.exs', '.jl', '.clj', '.cljs' + } + PDF_EXTENSIONS = {'.pdf'} + IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff', '.tif', '.webp'} + + def __init__(self, use_ocr: bool = True, use_pdf_extraction: bool = True, cache_dir: Optional[Path] = None): + """ + Initialize data processor. + + Args: + use_ocr: Whether to use OCR for images (requires pytesseract) + use_pdf_extraction: Whether to extract text from PDFs (requires PyPDF2 or pdfplumber) + cache_dir: Directory to store cache files (default: .cache in data directory) + """ + self.use_ocr = use_ocr + self.use_pdf_extraction = use_pdf_extraction + self.cache_dir = Path(cache_dir) if cache_dir else None + self._check_dependencies() + + def _get_cache_dir(self, directory: Path) -> Path: + """Get cache directory for a given data directory.""" + if self.cache_dir: + return self.cache_dir + # Default: .cache in the data directory + cache_dir = directory / '.cache' + cache_dir.mkdir(parents=True, exist_ok=True) + return cache_dir + + def _compute_directory_hash(self, directory: Path, recursive: bool = True) -> str: + """ + Compute a hash of directory contents to detect changes. + Uses file paths and modification times. + """ + directory = Path(directory) + file_info = [] + + pattern = '**/*' if recursive else '*' + scanned_count = 0 + + try: + for file_path in directory.glob(pattern): + scanned_count += 1 + + # Progress feedback every 5000 files (hash computation can be slow) + if scanned_count % 5000 == 0: + logger.info(f"Computing directory hash: scanned {scanned_count:,} paths...") + sys.stderr.flush() + + try: + if file_path.is_file(): + stat = file_path.stat() + file_info.append(f"{file_path.relative_to(directory)}:{stat.st_mtime}:{stat.st_size}") + except (OSError, PermissionError): + continue + except KeyboardInterrupt: + logger.warning(f"Directory hash computation interrupted after scanning {scanned_count:,} paths") + raise + except KeyboardInterrupt: + # Re-raise to allow graceful handling upstream + logger.warning("Directory hash computation interrupted. Will skip cache and do fresh scan.") + raise + + # Sort for consistent hashing + file_info.sort() + content = '\n'.join(file_info) + return hashlib.md5(content.encode()).hexdigest() + + def _get_cache_path(self, directory: Path, cache_type: str = 'files') -> Path: + """Get cache file path for a directory.""" + cache_dir = self._get_cache_dir(directory) + # Create a safe filename from directory path + dir_hash = hashlib.md5(str(directory.absolute()).encode()).hexdigest()[:8] + return cache_dir / f"{cache_type}_{dir_hash}.pkl" + + def _load_cache(self, cache_path: Path) -> Optional[Dict]: + """Load cache from file.""" + if not cache_path.exists(): + return None + try: + with open(cache_path, 'rb') as f: + return pickle.load(f) + except Exception as e: + logger.warning(f"Failed to load cache from {cache_path}: {e}") + return None + + def _save_cache(self, cache_path: Path, data: Dict): + """Save cache to file.""" + try: + cache_path.parent.mkdir(parents=True, exist_ok=True) + with open(cache_path, 'wb') as f: + pickle.dump(data, f) + except Exception as e: + logger.warning(f"Failed to save cache to {cache_path}: {e}") + + def clear_cache(self, directory: Path): + """ + Clear cache for a directory. + + Args: + directory: Directory path to clear cache for + """ + cache_path = self._get_cache_path(directory, 'files') + if cache_path.exists(): + try: + cache_path.unlink() + logger.info(f"āœ… Cleared cache for {directory}") + except Exception as e: + logger.warning(f"Failed to clear cache: {e}") + else: + logger.info(f"No cache found for {directory}") + + def _check_dependencies(self): + """Check if required dependencies are available.""" + if self.use_ocr: + try: + import pytesseract + from PIL import Image + self._ocr_available = True + except ImportError: + logger.warning("pytesseract or PIL not available. OCR disabled.") + self._ocr_available = False + self.use_ocr = False + + if self.use_pdf_extraction: + try: + import PyPDF2 + self._pypdf2_available = True + except ImportError: + try: + import pdfplumber + self._pdfplumber_available = True + self._pypdf2_available = False + except ImportError: + logger.warning("PyPDF2 or pdfplumber not available. PDF extraction disabled.") + self._pdfplumber_available = False + self.use_pdf_extraction = False + + def process_file(self, file_path: Path) -> Iterator[str]: + """ + Process a single file and yield text lines. + + Args: + file_path: Path to the file + + Yields: + Text lines extracted from the file + """ + file_path = Path(file_path) + + if not file_path.exists(): + logger.warning(f"File not found: {file_path}") + return + + suffix = file_path.suffix.lower() + + try: + if suffix in self.TEXT_EXTENSIONS: + yield from self._process_text_file(file_path) + elif suffix in self.CODE_EXTENSIONS: + yield from self._process_code_file(file_path) + elif suffix in self.PDF_EXTENSIONS: + yield from self._process_pdf(file_path) + elif suffix in self.IMAGE_EXTENSIONS: + yield from self._process_image(file_path) + else: + # Try to process as text file as fallback (many file types can be read as text) + # Only log at debug level to avoid spam + logger.debug(f"Unsupported file type: {file_path} (extension: {suffix}), attempting as text...") + try: + yield from self._process_text_file(file_path) + except KeyboardInterrupt: + raise + except Exception as e: + logger.debug(f"Failed to process {file_path} as text: {e}") + except KeyboardInterrupt: + # Re-raise KeyboardInterrupt to allow graceful shutdown + logger.warning(f"Interrupted while processing {file_path}") + raise + except Exception as e: + logger.error(f"Error processing {file_path}: {e}") + + def _process_text_file(self, file_path: Path) -> Iterator[str]: + """Process a text file.""" + try: + with open(file_path, 'r', encoding='utf-8', errors='ignore') as f: + for line in f: + line = line.strip() + if line: + yield line + except KeyboardInterrupt: + # Re-raise KeyboardInterrupt to allow graceful shutdown + logger.warning(f"Interrupted while reading {file_path}") + raise + except UnicodeDecodeError: + # Try with different encoding + try: + with open(file_path, 'r', encoding='latin-1', errors='ignore') as f: + for line in f: + line = line.strip() + if line: + yield line + except KeyboardInterrupt: + logger.warning(f"Interrupted while reading {file_path}") + raise + except Exception as e: + logger.error(f"Failed to read {file_path}: {e}") + + def _process_code_file(self, file_path: Path) -> Iterator[str]: + """Process a code file.""" + # Code files are processed as text, but we can add syntax-aware processing + # For now, just extract text lines + yield from self._process_text_file(file_path) + + def _process_pdf(self, file_path: Path) -> Iterator[str]: + """Extract text from PDF file.""" + if not self.use_pdf_extraction: + logger.warning(f"PDF extraction disabled. Skipping {file_path}") + return + + try: + if self._pypdf2_available: + import PyPDF2 + with open(file_path, 'rb') as f: + pdf_reader = PyPDF2.PdfReader(f) + for page_num, page in enumerate(pdf_reader.pages): + try: + text = page.extract_text() + if text: + # Split into sentences/lines + for line in text.split('\n'): + line = line.strip() + if line and len(line) > 5: # Filter very short lines + yield line + except KeyboardInterrupt: + logger.warning(f"Interrupted while processing PDF page {page_num} from {file_path}") + raise + except Exception as e: + logger.warning(f"Error extracting page {page_num} from {file_path}: {e}") + + elif self._pdfplumber_available: + import pdfplumber + with pdfplumber.open(file_path) as pdf: + for page_num, page in enumerate(pdf.pages): + try: + text = page.extract_text() + if text: + for line in text.split('\n'): + line = line.strip() + if line and len(line) > 5: + yield line + except KeyboardInterrupt: + logger.warning(f"Interrupted while processing PDF page {page_num} from {file_path}") + raise + except Exception as e: + logger.warning(f"Error extracting page {page_num} from {file_path}: {e}") + + except KeyboardInterrupt: + logger.warning(f"Interrupted while processing PDF {file_path}") + raise + except Exception as e: + logger.error(f"Failed to extract text from PDF {file_path}: {e}") + + def _process_image(self, file_path: Path) -> Iterator[str]: + """Extract text from image using OCR.""" + if not self.use_ocr or not self._ocr_available: + logger.warning(f"OCR disabled or unavailable. Skipping {file_path}") + return + + try: + import pytesseract + from PIL import Image + + # Open and process image + img = Image.open(file_path) + + # Perform OCR + text = pytesseract.image_to_string(img) + + if text: + # Split into lines + for line in text.split('\n'): + line = line.strip() + if line and len(line) > 3: # Filter very short lines + yield line + + except KeyboardInterrupt: + logger.warning(f"Interrupted while processing image {file_path}") + raise + except Exception as e: + logger.error(f"Failed to extract text from image {file_path}: {e}") + + def process_directory( + self, + directory: Path, + recursive: bool = True, + include_patterns: Optional[List[str]] = None, + exclude_patterns: Optional[List[str]] = None, + min_length: int = 10, + ) -> Iterator[str]: + """ + Process all files in a directory. + + Args: + directory: Directory path + recursive: Whether to process subdirectories + include_patterns: Optional list of glob patterns to include + exclude_patterns: Optional list of glob patterns to exclude + min_length: Minimum length for extracted text lines + + Yields: + Text lines from all processed files + """ + directory = Path(directory) + + if not directory.exists(): + logger.error(f"Directory not found: {directory}") + return + + # Try to load cached file list + cache_path = self._get_cache_path(directory, 'files') + + # Compute directory hash (may be interrupted) + try: + logger.info("Computing directory hash for cache validation...") + logger.info("(This may take a while for large directories. Press Ctrl+C to skip cache and do fresh scan)") + sys.stderr.flush() + current_hash = self._compute_directory_hash(directory, recursive) + cached_data = self._load_cache(cache_path) + except KeyboardInterrupt: + logger.warning("\nāš ļø Directory hash computation interrupted.") + logger.warning(" Skipping cache validation and doing fresh directory scan...") + logger.warning(" (Press Ctrl+C again to stop completely)") + sys.stderr.flush() + current_hash = None # Force cache miss + cached_data = None + # Don't re-raise - allow user to continue with fresh scan + # If they want to stop completely, they can press Ctrl+C again during scanning + + files_to_process = [] + scanned_count = 0 + skipped_count = 0 + + # Check if cache is valid + if current_hash and cached_data and cached_data.get('hash') == current_hash: + files_to_process = [Path(f) for f in cached_data.get('files', [])] + logger.info(f"āœ… Loaded {len(files_to_process):,} files from cache (skipping directory scan)") + else: + # Cache miss or invalid - scan directory + logger.info("Scanning directory (cache miss or invalid)...") + + # Collect all supported file extensions + all_supported_extensions = ( + self.TEXT_EXTENSIONS | + self.CODE_EXTENSIONS | + self.PDF_EXTENSIONS | + self.IMAGE_EXTENSIONS + ) + + if recursive: + pattern = '**/*' + else: + pattern = '*' + + # Default exclude patterns for common directories that don't contain training data + default_exclude_patterns = [ + '**/.git/**', + '**/__pycache__/**', + '**/node_modules/**', + '**/.venv/**', + '**/venv/**', + '**/.env/**', + '**/.pytest_cache/**', + '**/.mypy_cache/**', + '**/.tox/**', + '**/.coverage/**', + '**/dist/**', + '**/build/**', + '**/*.pyc', + '**/.DS_Store', + ] + + # Merge user exclude patterns with defaults + all_exclude_patterns = default_exclude_patterns.copy() + if exclude_patterns: + # Convert any Path objects to strings + all_exclude_patterns.extend(str(p) if isinstance(p, Path) else p for p in exclude_patterns) + + # Ensure all patterns are strings (not Path objects) + all_exclude_patterns = [str(p) for p in all_exclude_patterns] + + # Convert include_patterns to strings as well + if include_patterns: + include_patterns = [str(p) if isinstance(p, Path) else p for p in include_patterns] + + logger.info(f"Scanning directory: {directory} (recursive={recursive})...") + logger.info("This may take several minutes for large directories. Please wait...") + sys.stderr.flush() # Force flush to show message immediately + + try: + for file_path in directory.glob(pattern): + scanned_count += 1 + + # Progress reporting every 1000 files scanned + if scanned_count % 1000 == 0: + logger.info(f"Scanned {scanned_count:,} paths, found {len(files_to_process):,} files to process...") + sys.stderr.flush() # Force flush to show progress immediately + + # Skip if not a file (handles symlinks, directories, etc. gracefully) + try: + if not file_path.is_file(): + continue + except (OSError, PermissionError) as e: + # Skip inaccessible files (broken symlinks, permission denied, etc.) + skipped_count += 1 + if skipped_count <= 10: # Only log first 10 to avoid spam + logger.debug(f"Skipping inaccessible path: {file_path} ({e})") + continue + + # Early filtering by extension to avoid checking unsupported files + suffix = file_path.suffix.lower() + if suffix not in all_supported_extensions: + continue + + # Check include/exclude patterns + if include_patterns: + if not any(file_path.match(pattern) for pattern in include_patterns): + continue + + if all_exclude_patterns: + if any(file_path.match(pattern) for pattern in all_exclude_patterns): + continue + + files_to_process.append(file_path) + + except KeyboardInterrupt: + logger.warning(f"Directory scanning interrupted. Found {len(files_to_process)} files so far.") + raise + except Exception as e: + logger.error(f"Error during directory scanning: {e}") + logger.info(f"Continuing with {len(files_to_process)} files found so far...") + + if skipped_count > 10: + logger.info(f"Skipped {skipped_count} inaccessible paths") + + logger.info(f"Found {len(files_to_process):,} files to process (scanned {scanned_count:,} paths)") + sys.stderr.flush() # Force flush + + # Save file list to cache + cache_data = { + 'hash': current_hash, + 'files': [str(f.absolute()) for f in files_to_process], + 'recursive': recursive, + } + self._save_cache(cache_path, cache_data) + logger.info(f"šŸ’¾ Cached file list ({len(files_to_process):,} files) for future use") + sys.stderr.flush() # Force flush + + # Process each file + processed_count = 0 + skipped_count = 0 + error_count = 0 + total_lines = 0 + total_files = len(files_to_process) + + if total_files == 0: + logger.warning("No files found to process!") + return + + logger.info(f"Starting to process {total_files} files with progress bar...") + + # Create progress bar + pbar = tqdm( + total=total_files, + desc="Processing files", + unit="file", + ncols=120, + mininterval=0.1, # Update at least every 0.1 seconds + maxinterval=1.0, # Force update at least once per second + file=sys.stderr, # Write to stderr to avoid buffering issues + dynamic_ncols=True, # Auto-adjust to terminal width + disable=False, # Explicitly enable + ) + + try: + for idx, file_path in enumerate(files_to_process, 1): + try: + file_lines = list(self.process_file(file_path)) + if file_lines: + processed_count += 1 + for line in file_lines: + if len(line) >= min_length: + yield line + total_lines += 1 + else: + skipped_count += 1 + + # Update progress bar with statistics + pbar.set_postfix({ + 'Processed': processed_count, + 'Skipped': skipped_count, + 'Errors': error_count, + 'Lines': f"{total_lines:,}" + }) + pbar.update(1) # Advance progress bar + pbar.refresh() # Force immediate refresh + sys.stderr.flush() # Force flush stderr to ensure progress bar displays + + except KeyboardInterrupt: + pbar.close() + logger.warning( + f"Processing interrupted. " + f"Files: {idx}/{total_files}, Processed: {processed_count}, " + f"Skipped: {skipped_count}, Errors: {error_count}, " + f"Lines extracted: {total_lines:,}" + ) + raise + except Exception as e: + error_count += 1 + logger.error(f"Error processing {file_path}: {e}") + # Update progress bar even on errors + pbar.set_postfix({ + 'Processed': processed_count, + 'Skipped': skipped_count, + 'Errors': error_count, + 'Lines': f"{total_lines:,}" + }) + pbar.update(1) # Advance progress bar even on error + pbar.refresh() # Force immediate refresh + sys.stderr.flush() # Force flush stderr to ensure progress bar displays + finally: + pbar.close() + + logger.info( + f"Processing complete: {processed_count}/{total_files} files processed successfully, " + f"{skipped_count} skipped, {error_count} errors, {total_lines:,} lines extracted" + ) + + def process_to_list( + self, + directory: Path, + recursive: bool = True, + include_patterns: Optional[List[str]] = None, + exclude_patterns: Optional[List[str]] = None, + min_length: int = 10, + max_samples: Optional[int] = None, + ) -> List[str]: + """ + Process directory and return list of text lines. + + Args: + directory: Directory path + recursive: Whether to process subdirectories + include_patterns: Optional list of glob patterns to include + exclude_patterns: Optional list of glob patterns to exclude + min_length: Minimum length for extracted text lines + max_samples: Maximum number of samples to return (None = all) + + Returns: + List of text lines + """ + logger.info(f"Starting data extraction from {directory}...") + logger.info("This may take a while for large directories. Progress will be shown below.") + sys.stderr.flush() # Force flush to show message immediately + + texts = [] + try: + for text in self.process_directory( + directory=directory, + recursive=recursive, + include_patterns=include_patterns, + exclude_patterns=exclude_patterns, + min_length=min_length, + ): + texts.append(text) + if max_samples and len(texts) >= max_samples: + logger.info(f"Reached max_samples limit ({max_samples}). Stopping extraction.") + break + except KeyboardInterrupt: + # Return partial results if interrupted + logger.warning( + f"Data processing interrupted. Returning {len(texts):,} text samples collected so far." + ) + # Re-raise to allow caller to handle if needed + raise + + logger.info(f"āœ… Extracted {len(texts):,} text samples from {directory}") + return texts + + +def extract_text_from_directory( + directory: Path, + recursive: bool = True, + use_ocr: bool = True, + use_pdf_extraction: bool = True, + min_length: int = 10, + max_samples: Optional[int] = None, +) -> List[str]: + """ + Convenience function to extract text from a directory. + + Args: + directory: Directory path + recursive: Whether to process subdirectories + use_ocr: Whether to use OCR for images + use_pdf_extraction: Whether to extract text from PDFs + min_length: Minimum length for extracted text lines + max_samples: Maximum number of samples to return (None = all) + + Returns: + List of text lines + """ + processor = DataProcessor(use_ocr=use_ocr, use_pdf_extraction=use_pdf_extraction) + try: + return processor.process_to_list( + directory=directory, + recursive=recursive, + min_length=min_length, + max_samples=max_samples, + ) + except KeyboardInterrupt: + logger.error( + "\nāš ļø Data processing interrupted by user (Ctrl+C).\n" + " No data was loaded. Please run the training command again to retry." + ) + # Re-raise to stop training + raise + + +# Try to import BPE tokenizer for direct access +try: + from .bpe_tokenizer import BPETokenizer + __all__ = [ + 'TextDataset', + 'SimpleTokenizer', + 'BPETokenizer', + 'create_dataloader', + 'DataProcessor', + 'extract_text_from_directory', + ] +except ImportError: + __all__ = [ + 'TextDataset', + 'SimpleTokenizer', + 'create_dataloader', + 'DataProcessor', + 'extract_text_from_directory', + ] + diff --git a/data.example/bpe_tokenizer.py b/data.example/bpe_tokenizer.py new file mode 100644 index 0000000..7b8f36f --- /dev/null +++ b/data.example/bpe_tokenizer.py @@ -0,0 +1,423 @@ +""" +Improved BPE Tokenizer based on GPT-4 tokenization approach +Addresses common tokenization challenges: +- UTF-8 byte-level encoding +- Better Python code handling +- Case-insensitive contraction matching +- Limited number merging (1-3 digits) +- Proper special token handling +- Trailing whitespace warnings +""" +import re +import json +from typing import List, Dict, Tuple, Optional, Set +from collections import defaultdict +from pathlib import Path + + +class BPETokenizer: + """ + Byte Pair Encoding Tokenizer with GPT-4-inspired improvements. + + Key features: + - UTF-8 byte-level encoding + - BPE merging algorithm + - GPT-4 style regex pattern for text splitting + - Better whitespace handling for Python code + - Case-insensitive matching for contractions + - Limited number merging (1-3 digits) + """ + + def __init__( + self, + vocab_size: int = 50257, + special_tokens: Optional[Dict[str, int]] = None, + merges_file: Optional[str] = None, + vocab_file: Optional[str] = None, + ): + """ + Initialize BPE tokenizer. + + Args: + vocab_size: Target vocabulary size (default 50257 for GPT-2 style) + special_tokens: Dictionary of special token names to IDs + merges_file: Path to saved merges file + vocab_file: Path to saved vocab file + """ + # Special tokens + self.special_tokens = special_tokens or { + '': 0, + '': 1, + '': 2, + '': 3, + } + + # Initialize byte vocabulary (0-255) + self.byte_to_token = {i: i for i in range(256)} + self.token_to_byte = {i: bytes([i]) for i in range(256)} + self.next_token_id = 256 + + # BPE merges: (left, right) -> merged_token_id + self.merges: Dict[Tuple[int, int], int] = {} + + # Vocabulary: token_id -> bytes + self.vocab: Dict[int, bytes] = {} + self.inv_vocab: Dict[bytes, int] = {} + + # Initialize vocab with bytes + for i in range(256): + self.vocab[i] = bytes([i]) + self.inv_vocab[bytes([i])] = i + + # GPT-4 style regex pattern for splitting text + # Improvements over GPT-2: + # - Case-insensitive matching (flag) + # - Better whitespace handling + # - Limit number merging to 1-3 digits + self.pattern = self._create_gpt4_pattern() + + # Load pre-trained tokenizer if files provided + if merges_file and vocab_file: + self.load(merges_file, vocab_file) + else: + self.target_vocab_size = vocab_size + + # Token IDs for special tokens + self.pad_token_id = self.special_tokens.get('', 0) + self.unk_token_id = self.special_tokens.get('', 1) + self.bos_token_id = self.special_tokens.get('', 2) + self.eos_token_id = self.special_tokens.get('', 3) + + def _create_gpt4_pattern(self) -> re.Pattern: + """ + Create GPT-4 style regex pattern for splitting text. + + Improvements over GPT-2: + - Case-insensitive matching for contractions + - Better whitespace handling (groups multiple spaces) + - Limit number merging (1-3 digits) + """ + # GPT-4 style pattern with improvements + # Pattern breakdown: + # 1. Contractions: '(?i:[sdmt]|ll|ve|re) - case-insensitive + # 2. Letters: [^\r\n\p{L}\p{N}]?+\p{L}+ - optional space + letters + # 3. Numbers: \p{N}{1,3} - 1-3 digits only + # 4. Punctuation: ?[^\s\p{L}\p{N}]++ - optional space + non-letter/number + # 5. Whitespace: \r?\n - newlines + # 6. Trailing whitespace: \s+ - multiple spaces + pattern = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++|\r?\n|\s+""" + + # Compile with case-insensitive flag for contractions + return re.compile(pattern, re.IGNORECASE | re.UNICODE) + + def _get_stats(self, tokens: List[int]) -> Dict[Tuple[int, int], int]: + """ + Get statistics of consecutive token pairs. + + Args: + tokens: List of token IDs + + Returns: + Dictionary mapping pair tuples to counts + """ + stats = defaultdict(int) + for i in range(len(tokens) - 1): + pair = (tokens[i], tokens[i + 1]) + stats[pair] += 1 + return dict(stats) + + def _merge(self, tokens: List[int], pair: Tuple[int, int], new_id: int) -> List[int]: + """ + Merge consecutive occurrences of a pair into a new token. + + Args: + tokens: List of token IDs + pair: Tuple of (left, right) tokens to merge + new_id: New token ID to replace the pair + + Returns: + New list with merged tokens + """ + if len(tokens) < 2: + return tokens + + new_tokens = [] + i = 0 + while i < len(tokens): + # Check if we can merge at position i + if i < len(tokens) - 1 and tokens[i] == pair[0] and tokens[i + 1] == pair[1]: + new_tokens.append(new_id) + i += 2 + else: + new_tokens.append(tokens[i]) + i += 1 + + return new_tokens + + def train( + self, + texts: List[str], + num_merges: Optional[int] = None, + verbose: bool = False, + ): + """ + Train BPE tokenizer on a corpus of texts. + + Args: + texts: List of training texts + num_merges: Number of merges to perform (default: vocab_size - 256) + verbose: Whether to print progress + """ + if num_merges is None: + num_merges = self.target_vocab_size - 256 + + # Convert all texts to byte sequences + all_tokens = [] + for text in texts: + # Split text using regex pattern + chunks = self.pattern.findall(text) + + # Convert each chunk to bytes and tokenize + for chunk in chunks: + bytes_seq = chunk.encode('utf-8') + tokens = list(bytes_seq) + all_tokens.extend(tokens) + # Add separator between chunks (optional) + # all_tokens.append(256) # separator token + + # Perform BPE merges + for merge_num in range(num_merges): + # Get statistics + stats = self._get_stats(all_tokens) + + if not stats: + break + + # Find most frequent pair + pair = max(stats, key=stats.get) + + # Create new token + new_id = self.next_token_id + self.next_token_id += 1 + + # Merge + all_tokens = self._merge(all_tokens, pair, new_id) + + # Store merge + self.merges[pair] = new_id + + # Update vocabulary + left_bytes = self.vocab.get(pair[0], bytes([pair[0]])) + right_bytes = self.vocab.get(pair[1], bytes([pair[1]])) + merged_bytes = left_bytes + right_bytes + self.vocab[new_id] = merged_bytes + self.inv_vocab[merged_bytes] = new_id + + if verbose and (merge_num + 1) % 1000 == 0: + print(f"Merged {merge_num + 1}/{num_merges} pairs") + + if verbose: + print(f"Training complete. Vocabulary size: {len(self.vocab)}") + + def _encode_chunk(self, text: str) -> List[int]: + """ + Encode a single text chunk using BPE. + + Args: + text: Text chunk to encode + + Returns: + List of token IDs + """ + # Convert to bytes + bytes_seq = text.encode('utf-8') + tokens = list(bytes_seq) + + # If no merges trained yet, return byte tokens + if not self.merges: + return tokens + + # Apply merges in order + # Sort merges by their token ID (merge order) + sorted_merges = sorted(self.merges.items(), key=lambda x: x[1]) + + # Keep merging until no more merges are possible + changed = True + while changed: + changed = False + best_pair = None + best_idx = float('inf') + + # Find the earliest merge we can apply + for i in range(len(tokens) - 1): + pair = (tokens[i], tokens[i + 1]) + if pair in self.merges: + merge_idx = self.merges[pair] + if merge_idx < best_idx: + best_idx = merge_idx + best_pair = pair + + # Apply the best merge + if best_pair is not None: + merged_id = self.merges[best_pair] + tokens = self._merge(tokens, best_pair, merged_id) + changed = True + + return tokens + + def encode(self, text: str, allowed_special: Optional[Set[str]] = None) -> List[int]: + """ + Encode text into token IDs. + + Args: + text: Input text + allowed_special: Set of special tokens to allow in text + + Returns: + List of token IDs + """ + # Check for trailing whitespace (warn if present) + if text and text[-1] == ' ': + import warnings + warnings.warn( + "Text ends with trailing whitespace. This may cause worse performance " + "due to how the tokenizer splits text into tokens.", + UserWarning + ) + + # Handle special tokens + if allowed_special: + for special_name, special_id in self.special_tokens.items(): + if special_name in allowed_special and special_name in text: + # Simple special token replacement (can be improved) + if text == special_name: + return [special_id] + + # Split text using regex pattern + chunks = self.pattern.findall(text) + + # Encode each chunk + tokens = [] + for chunk in chunks: + chunk_tokens = self._encode_chunk(chunk) + tokens.extend(chunk_tokens) + + return tokens + + def decode(self, token_ids: List[int], errors: str = 'replace') -> str: + """ + Decode token IDs back to text. + + Args: + token_ids: List of token IDs + errors: Error handling for invalid UTF-8 ('strict', 'replace', 'ignore') + + Returns: + Decoded text string + """ + # Handle special tokens + if self.eos_token_id in token_ids: + # Stop at EOS token + eos_idx = token_ids.index(self.eos_token_id) + token_ids = token_ids[:eos_idx] + + # Convert tokens to bytes + bytes_parts = [] + for token_id in token_ids: + if token_id in self.special_tokens.values(): + # Skip special tokens (except maybe keep them for debugging) + continue + + if token_id in self.vocab: + bytes_parts.append(self.vocab[token_id]) + else: + # Unknown token - try to use byte representation + if token_id < 256: + bytes_parts.append(bytes([token_id])) + else: + # Unknown token - use replacement character + bytes_parts.append(b'\ufffd') + + # Concatenate bytes + if not bytes_parts: + return '' + + try: + combined_bytes = b''.join(bytes_parts) + return combined_bytes.decode('utf-8', errors=errors) + except UnicodeDecodeError: + # Fallback with replacement + return combined_bytes.decode('utf-8', errors='replace') + + def save(self, merges_file: str, vocab_file: str): + """ + Save tokenizer to files. + + Args: + merges_file: Path to save merges + vocab_file: Path to save vocabulary + """ + # Save merges + merges_list = [ + (left, right, merged_id) + for (left, right), merged_id in sorted(self.merges.items(), key=lambda x: x[1]) + ] + + with open(merges_file, 'w') as f: + json.dump(merges_list, f, indent=2) + + # Save vocabulary (convert bytes to base64 or hex) + vocab_dict = { + str(token_id): token_bytes.hex() + for token_id, token_bytes in self.vocab.items() + } + + with open(vocab_file, 'w') as f: + json.dump({ + 'vocab': vocab_dict, + 'special_tokens': self.special_tokens, + 'next_token_id': self.next_token_id, + }, f, indent=2) + + def load(self, merges_file: str, vocab_file: str): + """ + Load tokenizer from files. + + Args: + merges_file: Path to merges file + vocab_file: Path to vocabulary file + """ + # Load merges + with open(merges_file, 'r') as f: + merges_list = json.load(f) + + for left, right, merged_id in merges_list: + self.merges[(left, right)] = merged_id + self.next_token_id = max(self.next_token_id, merged_id + 1) + + # Load vocabulary + with open(vocab_file, 'r') as f: + vocab_data = json.load(f) + + vocab_dict = vocab_data['vocab'] + for token_id_str, token_bytes_hex in vocab_dict.items(): + token_id = int(token_id_str) + token_bytes = bytes.fromhex(token_bytes_hex) + self.vocab[token_id] = token_bytes + self.inv_vocab[token_bytes] = token_id + + if 'special_tokens' in vocab_data: + self.special_tokens.update(vocab_data['special_tokens']) + + if 'next_token_id' in vocab_data: + self.next_token_id = vocab_data['next_token_id'] + + @property + def vocab_size(self) -> int: + """Get vocabulary size.""" + return len(self.vocab) + len(self.special_tokens) + + +# Backward compatibility alias +SimpleTokenizer = BPETokenizer + diff --git a/docs/ARCHITECTURE.md b/docs/ARCHITECTURE.md new file mode 100644 index 0000000..618cad2 --- /dev/null +++ b/docs/ARCHITECTURE.md @@ -0,0 +1,729 @@ +# SheepOp LLM - Complete Architecture Documentation + +Complete documentation of the SheepOp Language Model project architecture, data flow, training pipeline, and inference system. + +## Table of Contents + +1. [System Overview](#system-overview) +2. [Data Ingestion Pipeline](#data-ingestion-pipeline) +3. [Training Pipeline](#training-pipeline) +4. [Model Architecture](#model-architecture) +5. [Inference Pipeline](#inference-pipeline) +6. [Complete Workflow](#complete-workflow) + +--- + +## System Overview + +```mermaid +graph TB + subgraph "Data Sources" + A[PDF Files] --> DataProcessor + B[Images - PNG/JPG/etc] --> DataProcessor + C[Code Files - .py/.js/etc] --> DataProcessor + D[Text Files - .txt/.md/etc] --> DataProcessor + end + + DataProcessor[DataProcessor
Multi-Format Extractor] --> TextList[Text Lines] + + TextList --> Tokenizer[SimpleTokenizer
Character-Level] + Tokenizer --> DataLoader[PyTorch DataLoader
Batched Sequences] + + DataLoader --> Trainer[Trainer
Training Loop] + + subgraph "Training Components" + Trainer --> Model[TransformerModel] + Trainer --> Optimizer[AdamW Optimizer] + Trainer --> Scheduler[CosineAnnealingLR] + Trainer --> Loss[CrossEntropyLoss] + end + + Model --> Checkpoint[Model Checkpoints
checkpoints/*.pt] + + Checkpoint --> Inference[Inference Script] + Inference --> GeneratedText[Generated Text] + + style DataProcessor fill:#e1f5ff + style Model fill:#fff4e1 + style Trainer fill:#ffe1f5 + style Checkpoint fill:#e1ffe1 +``` + +--- + +## Data Ingestion Pipeline + +### Multi-Format Data Processing Flow + +```mermaid +flowchart TD + Start([Start:
train.py --data path]) --> CheckPath{Path Type?} + + CheckPath -->|File| SingleFile[Process Single File] + CheckPath -->|Directory| Directory[Process Directory] + + SingleFile --> DataProcessor[DataProcessor.process_file] + Directory --> RecursiveScan[Recursive Directory Scan
Find all files] + + RecursiveScan --> FileType{File Extension?} + + FileType -->|.txt/.md/.json/etc| TextExtract[Read as Text File
Line by line] + FileType -->|.py/.js/.java/etc| CodeExtract[Read as Code File
Line by line] + FileType -->|.pdf| PDFExtract[PDF Extraction
PyPDF2/pdfplumber] + FileType -->|.png/.jpg/.tiff/etc| ImageExtract[OCR Extraction
pytesseract] + FileType -->|Unknown| Fallback[Try Text Fallback] + + PDFExtract --> PDFPages[Extract Each Page] + PDFPages --> PDFLines[Split into Lines] + + ImageExtract --> OCR[Perform OCR
pytesseract] + OCR --> OCRLines[Split OCR Text into Lines] + + TextExtract --> FilterLines[Filter Lines
min_length=10] + CodeExtract --> FilterLines + PDFLines --> FilterLines + OCRLines --> FilterLines + Fallback --> FilterLines + + FilterLines --> Combine[List of Text Lines] + Combine --> Validate{Texts Empty?} + + Validate -->|Yes| Error[Raise Error:
No data extracted] + Validate -->|No| Success[āœ… Success
N text samples loaded] + + Success --> TokenizerStep[Next: Tokenization] + + style DataProcessor fill:#e1f5ff + style PDFExtract fill:#ffe1f5 + style ImageExtract fill:#fff4e1 + style Success fill:#e1ffe1 + style Error fill:#ffe1e1 +``` + +### Data Processing Components + +```mermaid +classDiagram + class DataProcessor { + +process_file(file_path) Iterator[str] + +process_directory(directory) Iterator[str] + +process_to_list(...) List[str] + -_process_text_file() Iterator[str] + -_process_code_file() Iterator[str] + -_process_pdf() Iterator[str] + -_process_image() Iterator[str] + -_check_dependencies() + } + + class SimpleTokenizer { + +vocab: Dict[str, int] + +inv_vocab: Dict[int, str] + +vocab_size: int + +encode(text: str) List[int] + +decode(token_ids: List[int]) str + +save_vocab(path: str) + } + + class TextDataset { + +texts: List[str] + +tokenizer: SimpleTokenizer + +max_length: int + +sequences: List[torch.Tensor] + +__getitem__(idx) Dict + +_prepare_sequences() List[Tensor] + } + + class DataLoader { + +batch_size: int + +shuffle: bool + +num_workers: int + +collate_fn: Callable + } + + DataProcessor --> TextDataset : extracts text + SimpleTokenizer --> TextDataset : tokenizes + TextDataset --> DataLoader : creates dataset +``` + +--- + +## Training Pipeline + +### Complete Training Flow + +```mermaid +flowchart TD + Start([python train.py
--data path]) --> Args[Parse Arguments
--data, --config, --resume, --device] + + Args --> ConfigLoad{Config File
Provided?} + ConfigLoad -->|Yes| LoadConfig[Load config.json] + ConfigLoad -->|No| DefaultConfig[Use Default Config] + + LoadConfig --> Config[Config Object
ModelConfig
TrainingConfig
DataConfig
seed=42] + DefaultConfig --> Config + + Config --> SetSeed[Set Random Seed
torch.manual_seed
torch.cuda.manual_seed_all
CUDNN deterministic] + + SetSeed --> Device[Detect Device
CUDA/MPS/CPU] + + Device --> DataIngestion[Data Ingestion Pipeline
Extract text from all files] + + DataIngestion --> TextList[List of Text Lines
N samples] + + TextList --> CreateTokenizer[Create SimpleTokenizer
Character-level vocab] + + CreateTokenizer --> Tokenizer[Tokenizer Ready
vocab_size calculated] + + Tokenizer --> CreateDataLoader[Create DataLoader
Batch size
Max length
Shuffle] + + CreateDataLoader --> TrainLoader[PyTorch DataLoader
Batched sequences] + + TrainLoader --> CheckResume{Resume
Checkpoint?} + + CheckResume -->|Yes| LoadCheckpoint[Load Checkpoint
Model state
Optimizer state
Scheduler state
Epoch/Step] + CheckResume -->|No| CreateModel[Create New Model
TransformerModel] + + LoadCheckpoint --> CreateModel + + CreateModel --> Model[Model Ready
N parameters] + + Model --> CreateOptimizer[Create Optimizer
AdamW
lr, weight_decay] + + CreateOptimizer --> CreateScheduler[Create Scheduler
CosineAnnealingLR
T_max=total_steps] + + CreateScheduler --> CreateTrainer[Create Trainer
Model
DataLoader
Optimizer
Scheduler
Device] + + CreateTrainer --> Trainer[Trainer Ready] + + Trainer --> TrainingLoop[Training Loop
For each epoch] + + TrainingLoop --> EpochLoop[For each batch] + + EpochLoop --> Forward[Forward Pass
Model prediction] + + Forward --> Loss[Compute Loss
CrossEntropyLoss] + + Loss --> Backward[Backward Pass
Compute gradients] + + Backward --> GradientAccum{Gradient
Accumulation?} + + GradientAccum -->|Not yet| Accumulate[Accumulate gradients] + Accumulate --> EpochLoop + + GradientAccum -->|Ready| ClipGrad[Gradient Clipping
max_grad_norm] + + ClipGrad --> Update[Update Weights
Optimizer.step] + + Update --> UpdateLR[Update Learning Rate
Scheduler.step] + + UpdateLR --> ZeroGrad[Zero Gradients] + + ZeroGrad --> Log{Log Interval?} + + Log -->|Yes| LogMetrics[Log Metrics
Loss, LR
Save to metrics.json] + Log -->|No| EvalCheck{Evaluation
Interval?} + + LogMetrics --> EvalCheck + + EvalCheck -->|Yes| Evaluate[Evaluate on
Validation Set] + EvalCheck -->|No| SaveCheck{End of
Epoch?} + + Evaluate --> SaveCheck + + SaveCheck -->|No| EpochLoop + SaveCheck -->|Yes| SaveCheckpoint[Save Checkpoint
Model state
Optimizer state
Scheduler state
Epoch/Step] + + SaveCheckpoint --> MoreEpochs{More
Epochs?} + + MoreEpochs -->|Yes| TrainingLoop + MoreEpochs -->|No| GeneratePlots[Generate Training Plots
loss_by_epoch.png
training_curve.png] + + GeneratePlots --> End([Training Complete!
Checkpoints saved]) + + style SetSeed fill:#ffe1f5 + style DataIngestion fill:#e1f5ff + style Model fill:#fff4e1 + style TrainingLoop fill:#ffe1f5 + style End fill:#e1ffe1 +``` + +### Seed Initialization Details + +```mermaid +sequenceDiagram + participant TrainScript as train.py + participant Config as Config + participant PyTorch as PyTorch + participant CUDA as CUDA Backend + + TrainScript->>Config: Load config (seed=42) + TrainScript->>PyTorch: torch.manual_seed(42) + TrainScript->>CUDA: torch.cuda.manual_seed_all(42) + TrainScript->>PyTorch: torch.backends.cudnn.deterministic = True + TrainScript->>PyTorch: torch.backends.cudnn.benchmark = False + + Note over TrainScript,CUDA: Seed ensures reproducibility
across runs and devices +``` + +### Training Loop Details + +```mermaid +graph LR + subgraph "Single Training Step" + A[Batch Input
input_ids, labels] --> B[Forward Pass
Model forward] + B --> C[Logits
batch_size Ɨ seq_len Ɨ vocab_size] + C --> D[Compute Loss
CrossEntropyLoss] + D --> E[Backward Pass
Compute gradients] + E --> F{Gradient
Accumulation
Steps reached?} + F -->|No| G[Accumulate Gradients] + F -->|Yes| H[Gradient Clipping] + H --> I[Optimizer Step
Update weights] + I --> J[Scheduler Step
Update LR] + J --> K[Zero Gradients] + K --> L[Log Metrics] + end + + G --> A + + style B fill:#e1f5ff + style D fill:#ffe1f5 + style I fill:#fff4e1 +``` + +--- + +## Model Architecture + +### Transformer Model Structure + +```mermaid +graph TB + Input[Input Tokens
Token IDs] --> Embed[Token Embedding
vocab_size → d_model] + + Embed --> PosEnc[Positional Encoding
Sinusoidal/Cosine] + + PosEnc --> Dropout1[Dropout] + + Dropout1 --> Layer1[Transformer Block 1] + Layer1 --> Layer2[Transformer Block 2] + Layer2 --> Layer3[Transformer Block 3] + Layer3 --> LayerN[Transformer Block N
num_layers] + + LayerN --> LayerNorm[Final Layer Norm] + + LayerNorm --> OutputProj[Output Projection
d_model → vocab_size] + + OutputProj --> Logits[Logits
batch Ɨ seq_len Ɨ vocab_size] + + subgraph "Transformer Block Details" + TBInput[Input x] --> Attention[Multi-Head
Self-Attention] + Attention --> AddNorm1[Add & Norm
Residual + LayerNorm] + AddNorm1 --> FFN[Feed-Forward
Network] + FFN --> AddNorm2[Add & Norm
Residual + LayerNorm] + AddNorm2 --> TBOutput[Output] + end + + style Embed fill:#e1f5ff + style Attention fill:#ffe1f5 + style FFN fill:#fff4e1 + style Logits fill:#e1ffe1 +``` + +### Multi-Head Attention Mechanism + +```mermaid +graph LR + Input[Input
batch Ɨ seq_len Ɨ d_model] --> Q[Query
Linear Layer] + Input --> K[Key
Linear Layer] + Input --> V[Value
Linear Layer] + + Q --> SplitQ[Split into
num_heads heads] + K --> SplitK[Split into
num_heads heads] + V --> SplitV[Split into
num_heads heads] + + SplitQ --> ScaledDot[Scaled Dot-Product
Attention] + SplitK --> ScaledDot + SplitV --> ScaledDot + + ScaledDot --> Mask[Causal Mask
Lower triangular] + + Mask --> Softmax[Softmax] + + Softmax --> AttentionOutput[Attention Output
per head] + + AttentionOutput --> Concat[Concat Heads] + + Concat --> OutputProj[Output Projection
Linear Layer] + + OutputProj --> Output[Output
batch Ɨ seq_len Ɨ d_model] + + style ScaledDot fill:#ffe1f5 + style Mask fill:#fff4e1 + style Output fill:#e1ffe1 +``` + +### Complete Model Component Diagram + +```mermaid +classDiagram + class TransformerModel { + +vocab_size: int + +d_model: int + +num_layers: int + +num_heads: int + +token_embedding: Embedding + +pos_encoding: PositionalEncoding + +layers: ModuleList[TransformerBlock] + +final_norm: LayerNorm + +output_proj: Linear + +forward(input_ids) Tuple[Tensor, Tensor] + +generate(...) Tensor + +get_num_params() int + } + + class TransformerBlock { + +attention: MultiHeadAttention + +ffn: FeedForward + +norm1: LayerNorm + +norm2: LayerNorm + +dropout: Dropout + +forward(x, mask) Tensor + } + + class MultiHeadAttention { + +num_heads: int + +d_model: int + +d_k: int + +q_proj: Linear + +k_proj: Linear + +v_proj: Linear + +out_proj: Linear + +forward(q, k, v, mask) Tensor + } + + class FeedForward { + +linear1: Linear + +linear2: Linear + +activation: GELU/ReLU + +dropout: Dropout + +forward(x) Tensor + } + + class PositionalEncoding { + +d_model: int + +max_len: int + +pe: Tensor + +forward(x) Tensor + } + + TransformerModel --> TransformerBlock : contains N layers + TransformerModel --> PositionalEncoding : adds positional info + TransformerBlock --> MultiHeadAttention : self-attention + TransformerBlock --> FeedForward : feed-forward network +``` + +--- + +## Inference Pipeline + +### Text Generation Flow + +```mermaid +flowchart TD + Start([python inference.py
--checkpoint path
--prompt text]) --> LoadModel[Load Model from Checkpoint
Load state dict
Set to eval mode] + + LoadModel --> CreateTokenizer[Create Tokenizer
SimpleTokenizer] + + CreateTokenizer --> EncodePrompt[Encode Prompt
Text → Token IDs] + + EncodePrompt --> CheckOptimized{Use Optimized
Inference?} + + CheckOptimized -->|Yes| OptimizedGen[OptimizedInference
with KV Caching] + CheckOptimized -->|No| StandardGen[Standard Generation] + + StandardGen --> InitGen[Initialize Generation
generated = input_ids] + + InitGen --> LoopStart[Generation Loop
For max_length steps] + + LoopStart --> Forward[Forward Pass
Model prediction] + + Forward --> NextToken[Get Next Token Logits
Last position] + + NextToken --> Temperature[Apply Temperature
Scale logits] + + Temperature --> TopK{Top-K
Filtering?} + + TopK -->|Yes| FilterK[Filter Top-K Tokens] + TopK -->|No| TopP{Top-P
Nucleus Sampling?} + + FilterK --> TopP + + TopP -->|Yes| FilterP[Filter by Cumulative Prob] + TopP -->|No| Sample[Sample Token
Multinomial] + + FilterP --> Sample + + Sample --> Append[Append Token
to Generated] + + Append --> CheckStop{Stop
Condition?} + + CheckStop -->|No| LoopStart + CheckStop -->|Yes| Decode[Decode Tokens
Token IDs → Text] + + OptimizedGen --> KVCache[Use KV Cache
Cache previous KV] + KVCache --> LoopStart + + Decode --> Output[Generated Text
Output] + + Output --> End([End]) + + style OptimizedGen fill:#e1f5ff + style Forward fill:#ffe1f5 + style Sample fill:#fff4e1 + style Output fill:#e1ffe1 +``` + +### Optimized Inference with KV Caching + +```mermaid +graph TB + subgraph "Standard Generation" + A1[Input Token] --> B1[Forward Pass
Compute Q, K, V] + B1 --> C1[Attention
Full Sequence] + C1 --> D1[Next Token] + D1 --> E1[Append Token] + E1 --> A1 + end + + subgraph "Optimized Generation with KV Cache" + A2[Input Token] --> B2{First
Token?} + B2 -->|Yes| C2[Forward Pass
Compute Q, K, V] + B2 -->|No| C2Cache[Use Cached K, V
Only compute Q] + C2 --> D2[Cache K, V] + D2 --> E2[Attention
Only with New Token] + C2Cache --> E2 + E2 --> F2[Next Token] + F2 --> G2[Append Token] + G2 --> A2 + end + + style C2 fill:#e1f5ff + style C2Cache fill:#ffe1f5 + style E2 fill:#e1ffe1 +``` + +--- + +## Complete Workflow + +### End-to-End System Flow + +```mermaid +flowchart TB + subgraph "Phase 1: Data Preparation" + A1[Raw Data Files
PDFs, Images, Code, Text] --> A2[DataProcessor
Extract Text] + A2 --> A3[Text Lines
List of Strings] + A3 --> A4[SimpleTokenizer
Build Vocabulary] + A4 --> A5[Tokenize & Chunk
Create Sequences] + A5 --> A6[DataLoader
Batched Data] + end + + subgraph "Phase 2: Model Initialization" + B1[Load Config
ModelConfig] --> B2[Set Random Seed
seed=42] + B2 --> B3[Create Model
TransformerModel] + B3 --> B4[Initialize Weights
Normal Distribution] + B4 --> B5[Create Optimizer
AdamW] + B5 --> B6[Create Scheduler
CosineAnnealingLR] + end + + subgraph "Phase 3: Training" + C1[Trainer Setup] --> C2[Training Loop
Epochs] + C2 --> C3[Batch Loop] + C3 --> C4[Forward Pass] + C4 --> C5[Compute Loss] + C5 --> C6[Backward Pass] + C6 --> C7[Gradient Clipping] + C7 --> C8[Update Weights] + C8 --> C9[Save Checkpoint] + C9 --> C10{More Epochs?} + C10 -->|Yes| C2 + C10 -->|No| C11[Generate Plots
Training Metrics] + end + + subgraph "Phase 4: Inference" + D1[Load Checkpoint] --> D2[Load Model State] + D2 --> D3[Encode Prompt] + D3 --> D4[Generate Text
Autoregressive] + D4 --> D5[Decode Tokens] + D5 --> D6[Output Text] + end + + A6 --> B1 + B6 --> C1 + C11 --> D1 + + style A2 fill:#e1f5ff + style B3 fill:#fff4e1 + style C4 fill:#ffe1f5 + style D4 fill:#e1ffe1 +``` + +### Checkpoint Structure + +```mermaid +graph TB + Checkpoint[Checkpoint File
checkpoint_epoch_N.pt] --> ModelState[model_state_dict
Model weights] + Checkpoint --> OptimizerState[optimizer_state_dict
AdamW state] + Checkpoint --> SchedulerState[scheduler_state_dict
LR scheduler state] + Checkpoint --> ModelConfig[model_config
Model hyperparameters] + Checkpoint --> Epoch[epoch
Current epoch number] + Checkpoint --> GlobalStep[global_step
Training step count] + Checkpoint --> BestValLoss[best_val_loss
Best validation loss] + + ModelState --> Resume[Resume Training
Restore model state] + OptimizerState --> Resume + SchedulerState --> Resume + ModelConfig --> Resume + Epoch --> Resume + GlobalStep --> Resume + + style Checkpoint fill:#e1f5ff + style Resume fill:#e1ffe1 +``` + +### Configuration Hierarchy + +```mermaid +graph TB + Config[Config
Root Configuration] --> ModelConfig[ModelConfig
vocab_size
d_model
num_layers
num_heads
d_ff
max_seq_len
dropout
activation] + + Config --> TrainingConfig[TrainingConfig
batch_size
max_epochs
learning_rate
weight_decay
warmup_steps
max_grad_norm
gradient_accumulation_steps
use_amp] + + Config --> DataConfig[DataConfig
data_dir
max_length
stride
num_workers] + + Config --> Global[Global Settings
device
seed] + + ModelConfig --> Model[TransformerModel
Model Architecture] + TrainingConfig --> Trainer[Trainer
Training Parameters] + DataConfig --> DataLoader[DataLoader
Data Parameters] + + style Config fill:#e1f5ff + style Model fill:#fff4e1 + style Trainer fill:#ffe1f5 + style DataLoader fill:#e1ffe1 +``` + +--- + +## Key Components Summary + +### 1. **Data Processing** +- **DataProcessor**: Multi-format text extraction (PDFs, images, code, text) +- **SimpleTokenizer**: Character-level tokenization +- **TextDataset**: PyTorch dataset for training +- **DataLoader**: Batched data loading + +### 2. **Model Architecture** +- **TransformerModel**: Complete transformer language model +- **TransformerBlock**: Multi-head attention + feed-forward +- **MultiHeadAttention**: Scaled dot-product attention with causal masking +- **FeedForward**: Position-wise feed-forward network +- **PositionalEncoding**: Sinusoidal position embeddings + +### 3. **Training** +- **Trainer**: Complete training loop with: + - Gradient accumulation + - Mixed precision training (AMP) + - Gradient clipping + - Learning rate scheduling + - Checkpointing + - Metrics tracking + +### 4. **Inference** +- **Standard Generation**: Autoregressive text generation +- **OptimizedInference**: KV caching for faster generation +- **RetrievalCache**: Caching for RAG systems + +### 5. **Configuration** +- **Config System**: Hierarchical configuration (Model, Training, Data) +- **JSON Support**: Save/load configurations +- **Default Values**: Sensible defaults for all parameters + +--- + +## Usage Examples + +### Training +```bash +# Basic training +python train.py --data /path/to/data + +# With custom config +python train.py --data /path/to/data --config config.json + +# Resume from checkpoint +python train.py --data /path/to/data --resume checkpoints/checkpoint_epoch_5.pt + +# Specify device +python train.py --data /path/to/data --device cuda +``` + +### Inference +```bash +# Basic inference +python inference.py --checkpoint checkpoints/best_checkpoint.pt --prompt "Hello world" + +# With sampling parameters +python inference.py \ + --checkpoint checkpoints/best_checkpoint.pt \ + --prompt "The future of AI" \ + --max-length 200 \ + --temperature 0.8 \ + --top-k 50 \ + --top-p 0.95 + +# Optimized inference +python inference.py \ + --checkpoint checkpoints/best_checkpoint.pt \ + --prompt "Hello" \ + --optimized +``` + +--- + +## File Structure + +``` +sheepOp/ +ā”œā”€ā”€ train.py # Main training script +ā”œā”€ā”€ inference.py # Inference script +ā”œā”€ā”€ config.py # Configuration management +ā”œā”€ā”€ config.json # Configuration file +ā”œā”€ā”€ data/ # Data module (symlink) +│ └── __init__.py # Tokenizer, DataLoader, DataProcessor +ā”œā”€ā”€ models/ # Model definitions +│ ā”œā”€ā”€ transformer.py # Main transformer model +│ ā”œā”€ā”€ blocks.py # Transformer blocks +│ ā”œā”€ā”€ attention.py # Attention mechanisms +│ └── optimized_attention.py # Optimized inference +ā”œā”€ā”€ training/ # Training utilities +│ ā”œā”€ā”€ __init__.py # Trainer class +│ └── metrics.py # Training metrics +ā”œā”€ā”€ checkpoints/ # Saved model checkpoints +└── requirements.txt # Dependencies +``` + +--- + +## Flow Summary + +1. **Data Ingestion**: Raw files → Text extraction → Text lines +2. **Tokenization**: Text lines → Token sequences → Batched data +3. **Model Setup**: Config → Model → Optimizer → Scheduler +4. **Training**: Batches → Forward → Loss → Backward → Update → Checkpoint +5. **Inference**: Checkpoint → Model → Prompt → Generate → Output + +--- + +*This documentation provides a complete view of the SheepOp LLM project architecture and workflow.* + diff --git a/docs/ATTENTION_EXPLAINED.md b/docs/ATTENTION_EXPLAINED.md new file mode 100644 index 0000000..d384204 --- /dev/null +++ b/docs/ATTENTION_EXPLAINED.md @@ -0,0 +1,410 @@ +# What is Attention? Step-by-Step Explanation + +Complete step-by-step explanation of attention mechanisms in transformer models: how models understand relationships between words. + +## Table of Contents + +1. [The Problem Attention Solves](#21-the-problem-attention-solves) +2. [What is Attention?](#22-what-is-attention) +3. [How Attention Works: Step-by-Step](#23-how-attention-works-step-by-step) +4. [Complete Example: Attention in "Hello World"](#24-complete-example-attention-in-hello-world) +5. [Why Attention Matters](#25-why-attention-matters) +6. [Multi-Head Attention](#26-multi-head-attention) +7. [Visual Representation of Attention](#27-visual-representation-of-attention) +8. [Key Takeaways](#28-key-takeaways) + +--- + +## 2.1 The Problem Attention Solves + +### The Challenge + +**In a sentence, words depend on each other:** + +``` +"He saw the cat with binoculars" +``` + +Two possible meanings: +1. He used binoculars to see the cat +2. The cat has binoculars + +**Context matters!** The model needs to understand which words relate to each other. + +### The Solution: Attention + +**Attention allows the model to "look" at other words when processing each word.** + +--- + +## 2.2 What is Attention? + +### Simple Definition + +**Attention** is a mechanism that determines **how much each word should consider other words** when processing information. + +### Intuitive Analogy + +**Think of reading a sentence:** + +When you read "cat" in: +``` +"The cat sat on the mat" +``` + +You might: +- Pay attention to "sat" (what the cat did) +- Pay attention to "mat" (where the cat is) +- Pay less attention to "the" (just a word) + +**Attention does the same thing mathematically!** + +--- + +## 2.3 How Attention Works: Step-by-Step + +### High-Level Overview + +``` +Step 1: Create Query, Key, Value for each word +Step 2: Compare queries and keys (find similarities) +Step 3: Calculate attention weights (how much to attend) +Step 4: Combine values weighted by attention +``` + +### Detailed Step-by-Step + +#### Step 1: Create Query, Key, Value (Q, K, V) + +**For each word, create three representations:** + +**Query (Q):** "What am I looking for?" +**Key (K):** "What am I offering?" +**Value (V):** "What information do I contain?" + +**Example with "Hello World":** + +``` +Word: "Hello" + Query: [0.2, -0.1, 0.3, ...] ← What should I look for? + Key: [0.1, 0.2, -0.1, ...] ← What do I represent? + Value: [0.15, 0.1, 0.2, ...] ← What information do I have? + +Word: "World" + Query: [0.18, 0.15, 0.25, ...] + Key: [0.12, 0.19, -0.08, ...] + Value: [0.14, 0.12, 0.18, ...] +``` + +**How Q, K, V are created:** +``` +Q = Word Ɨ W_Q (learned matrix) +K = Word Ɨ W_K (learned matrix) +V = Word Ɨ W_V (learned matrix) +``` + +#### Step 2: Compute Similarity Scores + +**Compare each query with all keys:** + +``` +Score[i, j] = How much should word i attend to word j? +``` + +**Mathematical Formula:** +``` +Score[i, j] = (Query[i] Ā· Key[j]) / √d_k +``` + +**Example:** + +**Query for "Hello":** `[0.2, -0.1, 0.3]` +**Key for "Hello":** `[0.1, 0.2, -0.1]` +**Key for "World":** `[0.12, 0.19, -0.08]` + +**Calculate similarity:** + +``` +Score["Hello", "Hello"] = (0.2Ɨ0.1 + (-0.1)Ɨ0.2 + 0.3Ɨ(-0.1)) / √3 + = (0.02 - 0.02 - 0.03) / 1.732 + = -0.03 / 1.732 + ā‰ˆ -0.017 + +Score["Hello", "World"] = (0.2Ɨ0.12 + (-0.1)Ɨ0.19 + 0.3Ɨ(-0.08)) / √3 + = (0.024 - 0.019 - 0.024) / 1.732 + = -0.019 / 1.732 + ā‰ˆ -0.011 +``` + +**Result:** Similarity scores tell us how related words are + +#### Step 3: Convert Scores to Attention Weights + +**Use softmax to convert scores to probabilities:** + +``` +Attention[i, j] = exp(Score[i, j]) / Ī£ exp(Score[i, k]) +``` + +**Example:** + +**Raw Scores:** +``` +Score["Hello", "Hello"] = -0.017 +Score["Hello", "World"] = -0.011 +``` + +**Compute exponentials:** +``` +exp(-0.017) ā‰ˆ 0.983 +exp(-0.011) ā‰ˆ 0.989 +Sum = 0.983 + 0.989 = 1.972 +``` + +**Compute attention weights:** +``` +Attention["Hello", "Hello"] = 0.983 / 1.972 ā‰ˆ 0.499 (49.9%) +Attention["Hello", "World"] = 0.989 / 1.972 ā‰ˆ 0.501 (50.1%) +``` + +**Meaning:** "Hello" attends 49.9% to itself and 50.1% to "World" + +#### Step 4: Weighted Combination + +**Combine values using attention weights:** + +``` +Output["Hello"] = Attention["Hello", "Hello"] Ɨ Value["Hello"] + + Attention["Hello", "World"] Ɨ Value["World"] +``` + +**Example:** + +``` +Value["Hello"] = [0.15, 0.1, 0.2] +Value["World"] = [0.14, 0.12, 0.18] + +Output["Hello"] = 0.499 Ɨ [0.15, 0.1, 0.2] + 0.501 Ɨ [0.14, 0.12, 0.18] + = [0.075, 0.050, 0.100] + [0.070, 0.060, 0.090] + = [0.145, 0.110, 0.190] +``` + +**Result:** New representation that combines information from both words! + +--- + +## 2.4 Complete Example: Attention in "Hello World" + +### Input + +``` +Words: ["Hello", "World"] +Position 0: "Hello" +Position 1: "World" +``` + +### Step-by-Step Processing + +#### Step 1: Embeddings + +``` +E["Hello"] = [0.10, -0.20, 0.30, ..., 0.05] +E["World"] = [0.15, -0.18, 0.28, ..., 0.10] +``` + +#### Step 2: Create Q, K, V + +``` +Q["Hello"] = E["Hello"] Ɨ W_Q = [0.2, -0.1, 0.3, ...] +K["Hello"] = E["Hello"] Ɨ W_K = [0.1, 0.2, -0.1, ...] +V["Hello"] = E["Hello"] Ɨ W_V = [0.15, 0.1, 0.2, ...] + +Q["World"] = E["World"] Ɨ W_Q = [0.18, 0.15, 0.25, ...] +K["World"] = E["World"] Ɨ W_K = [0.12, 0.19, -0.08, ...] +V["World"] = E["World"] Ɨ W_V = [0.14, 0.12, 0.18, ...] +``` + +#### Step 3: Compute Attention Scores + +``` +Score Matrix (2Ɨ2): + + "Hello" "World" +"Hello" 0.5 0.3 +"World" 0.4 0.6 +``` + +**Interpretation:** +- "Hello" attends to itself (0.5) more than "World" (0.3) +- "World" attends to itself (0.6) more than "Hello" (0.4) + +#### Step 4: Apply Softmax + +``` +Attention Matrix: + + "Hello" "World" +"Hello" 0.62 0.38 +"World" 0.40 0.60 +``` + +**Interpretation:** +- "Hello" gives 62% attention to itself, 38% to "World" +- "World" gives 40% attention to "Hello", 60% to itself + +#### Step 5: Weighted Combination + +``` +Output["Hello"] = 0.62 Ɨ V["Hello"] + 0.38 Ɨ V["World"] + = 0.62 Ɨ [0.15, 0.1, 0.2] + 0.38 Ɨ [0.14, 0.12, 0.18] + = [0.093, 0.062, 0.124] + [0.053, 0.046, 0.068] + = [0.146, 0.108, 0.192] + +Output["World"] = 0.40 Ɨ V["Hello"] + 0.60 Ɨ V["World"] + = 0.40 Ɨ [0.15, 0.1, 0.2] + 0.60 Ɨ [0.14, 0.12, 0.18] + = [0.060, 0.040, 0.080] + [0.084, 0.072, 0.108] + = [0.144, 0.112, 0.188] +``` + +**Result:** Each word now contains information from both words! + +--- + +## 2.5 Why Attention Matters + +### Benefit 1: Context Understanding + +**Without Attention:** +``` +"Hello" is processed in isolation +"World" is processed in isolation +Result: No understanding of relationship +``` + +**With Attention:** +``` +"Hello" considers "World" (38% attention) +"World" considers "Hello" (40% attention) +Result: Understands they're related +``` + +### Benefit 2: Long-Range Dependencies + +**Attention can connect distant words:** + +``` +"The cat that I saw yesterday sat on the mat" +``` + +- "cat" can attend to "yesterday" (even though far apart) +- Model understands the cat from yesterday + +### Benefit 3: Selective Focus + +**Attention focuses on relevant information:** + +``` +"He saw the cat with binoculars" +``` + +- "saw" attends strongly to "binoculars" (how he saw) +- "cat" attends strongly to "sat" (what it did) +- Each word focuses on what's relevant to it + +--- + +## 2.6 Multi-Head Attention + +### What is Multi-Head Attention? + +**Multiple attention "heads" look at different aspects:** + +``` +Head 1: Focuses on syntax (grammar relationships) +Head 2: Focuses on semantics (meaning relationships) +Head 3: Focuses on position (spatial relationships) +... +Head 8: Focuses on another aspect +``` + +### Visual Representation + +``` +Input: "Hello World" + +Head 1 (Syntax): + "Hello" → attends to "World" (subject-object relationship) + +Head 2 (Semantics): + "Hello" → attends to "World" (greeting relationship) + +Head 3 (Position): + "Hello" → attends more to itself (being first) + +... (other heads) + +Final: Combine all heads → Richer representation +``` + +### Why Multiple Heads? + +**Different heads capture different relationships:** + +- **Head 1:** Grammatical relationships +- **Head 2:** Semantic relationships +- **Head 3:** Positional relationships +- **Head 4:** Other patterns... + +**Together:** Comprehensive understanding! + +--- + +## 2.7 Visual Representation of Attention + +### Attention Heatmap + +``` +Attention Weights for "Hello World" + + Position 0 Position 1 + ("Hello") ("World") + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” +Position 0 │ 0.62 │ │ 0.38 │ +("Hello") ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” +Position 1 │ 0.40 │ │ 0.60 │ +("World") ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ +``` + +**Reading:** +- Row 0: "Hello" attends 62% to itself, 38% to "World" +- Row 1: "World" attends 40% to "Hello", 60% to itself + +### Attention Flow Diagram + +``` +"Hello" ──── 0.38 ────→ "World" + ↑ ↑ + │ │ + 0.62 0.60 + │ │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + (self-attention) +``` + +**Meaning:** Information flows between words based on attention weights. + +--- + +## 2.8 Key Takeaways: Attention + +āœ… **Attention determines which words to focus on** +āœ… **Calculates similarity between words** +āœ… **Creates weighted combinations of information** +āœ… **Enables understanding of relationships** +āœ… **Multiple heads capture different aspects** + +--- + +*This document provides a step-by-step explanation of attention mechanisms, the core component that enables transformers to understand relationships between words.* + diff --git a/docs/BENCHMARKING_GUIDE.md b/docs/BENCHMARKING_GUIDE.md new file mode 100644 index 0000000..7c2729a --- /dev/null +++ b/docs/BENCHMARKING_GUIDE.md @@ -0,0 +1,757 @@ +# Inference Benchmarking Guide + +This guide explains how to use the benchmarking feature to compare optimized vs non-optimized inference performance for research purposes. + +## Overview + +The benchmarking feature runs inference both with and without optimizations (KV caching, optimized attention) and generates: + +- **Performance metrics** (tokens/sec, latency, memory usage) +- **Comparison plots** (visual charts showing improvements) +- **CSV export** (data for further analysis) + +## Data Storage Location + +**All benchmark data is saved to:** `./inference_benchmarks/` (default) + +**You can customize the location:** + +```bash +python inference.py --benchmark --benchmark-dir ./research/results +``` + +**Data files created:** + +- `inference_metrics.json` - All raw metrics (JSON format) +- `inference_metrics.csv` - Spreadsheet-friendly data (CSV format) +- `optimization_comparison.png` - Visual comparison charts +- `performance_over_time.png` - Trend analysis over multiple runs + +**Note:** All runs accumulate in the same files, so you can run multiple benchmarks and build trends over time. + +## Quick Start + +### Basic Benchmark + +```bash +python inference.py \ + --checkpoint checkpoints/best_checkpoint.pt \ + --prompt "The future of artificial intelligence" \ + --max-length 100 \ + --benchmark +``` + +This will: + +1. Run inference **without** optimizations +2. Run inference **with** optimizations (KV cache) +3. Collect metrics for both runs +4. Generate comparison plots +5. Save all data to `./inference_benchmarks/` + +### Custom Benchmark Directory + +```bash +python inference.py \ + --checkpoint checkpoints/best_checkpoint.pt \ + --prompt "Your prompt here" \ + --max-length 100 \ + --benchmark \ + --benchmark-dir ./research/results +``` + +### Running Multiple Prompts for Trends + +**Use the batch benchmark script** to run multiple prompts and create trends: + +```bash +# Create a prompts file +cat > prompts.txt << EOF +The future of artificial intelligence +Machine learning is transforming +Deep neural networks enable +Natural language processing requires +EOF + +# Run batch benchmarks +python benchmark_batch.py \ + --checkpoint checkpoints/best_checkpoint.pt \ + --prompt-file prompts.txt \ + --max-length 100 \ + --benchmark-dir ./research/results +``` + +**Or use command-line prompts:** + +```bash +python benchmark_batch.py \ + --checkpoint checkpoints/best_checkpoint.pt \ + --prompts "Prompt 1" "Prompt 2" "Prompt 3" \ + --max-length 100 +``` + +**Results accumulate** in the same files, allowing you to: + +- Build trends across multiple prompts +- Analyze performance consistency +- Create comprehensive research reports + +## Output Files + +After running a benchmark, you'll get: + +### 1. JSON Metrics File + +**Location:** `inference_benchmarks/inference_metrics.json` + +Contains all raw metrics data: + +```json +{ + "runs": [ + { + "run_name": "run_1234567890_optimized", + "optimized": true, + "tokens_per_second": 150.5, + "time_per_token": 6.64, + "memory_used_mb": 245.3, + ... + }, + ... + ] +} +``` + +### 2. CSV Export + +**Location:** `inference_benchmarks/inference_metrics.csv` + +For spreadsheet analysis: + +```csv +run_name,timestamp,optimized,prompt_length,generated_length,total_time,tokens_per_second,time_per_token,memory_used_mb,device +run_1234567890_optimized,1234567890.5,true,20,100,0.663,150.8,6.63,245.3,cuda +... +``` + +### 3. Comparison Plot + +**Location:** `inference_benchmarks/optimization_comparison.png` + +Shows 4 charts: + +- **Tokens per Second** (speed comparison) +- **Time per Token** (latency comparison) +- **Total Generation Time** (overall speed) +- **Memory Usage** (memory efficiency) + +### 4. Performance Over Time Plot + +**Location:** `inference_benchmarks/performance_over_time.png` + +Shows how performance varies across multiple benchmark runs. + +## Metrics Collected + +### Performance Metrics + +- **Tokens per Second**: Generation speed +- **Time per Token**: Latency per token (milliseconds) +- **Total Time**: Complete generation time + +### Resource Metrics + +- **Memory Usage**: GPU memory consumption (MB) +- **Device**: Device used (cuda/cpu/mps) + +### Derived Metrics + +- **Speedup**: Ratio of optimized vs non-optimized speed +- **Memory Reduction**: Percentage reduction in memory usage + +## Example Output + +``` +šŸ”¬ BENCHMARK MODE: Comparing optimized vs non-optimized inference +====================================================================== + +BENCHMARK RUN: run_1234567890 +====================================================================== + +šŸ”“ Running NON-OPTIMIZED inference... + ā±ļø Total Time: 1.234 s + šŸ“Š Tokens/Second: 81.0 + ⚔ Time/Token: 12.35 ms + šŸ’¾ Memory Used: 512.3 MB + šŸ“ Generated: The future of artificial intelligence is bright... + +🟢 Running OPTIMIZED inference... + ā±ļø Total Time: 0.663 s + šŸ“Š Tokens/Second: 150.8 + ⚔ Time/Token: 6.63 ms + šŸ’¾ Memory Used: 245.3 MB + šŸ“ Generated: The future of artificial intelligence is bright... + +šŸš€ SPEEDUP: 1.86x faster with optimizations +šŸ’¾ MEMORY REDUCTION: 52.1% + +šŸ“Š Generating comparison plots and data... +šŸ“Š Comparison plot saved to: ./inference_benchmarks/optimization_comparison.png +šŸ“Š Performance over time plot saved to: ./inference_benchmarks/performance_over_time.png +šŸ“Š Metrics exported to CSV: ./inference_benchmarks/inference_metrics.csv + +āœ… Benchmark complete! Results saved to: ./inference_benchmarks +``` + +## Running Multiple Benchmarks for Trends + +### Method 1: Individual Runs (Manual) + +```bash +# Run 1 +python inference.py --checkpoint checkpoints/best.pt --prompt "Prompt 1" --benchmark + +# Run 2 +python inference.py --checkpoint checkpoints/best.pt --prompt "Prompt 2" --benchmark + +# Run 3 +python inference.py --checkpoint checkpoints/best.pt --prompt "Prompt 3" --max-length 200 --benchmark +``` + +All runs accumulate in the same files: + +- `inference_metrics.json` - All runs appended +- `inference_metrics.csv` - All runs in CSV format +- Plots update automatically with new data + +### Method 2: Batch Script (Recommended) + +**Create a prompts file:** + +```bash +cat > research_prompts.txt << EOF +The future of artificial intelligence is bright. +Machine learning models are becoming more efficient. +Deep neural networks can process complex patterns. +Natural language processing enables human-computer interaction. +Transformer architectures revolutionized NLP. +EOF +``` + +**Run batch benchmarks:** + +```bash +python benchmark_batch.py \ + --checkpoint checkpoints/best_checkpoint.pt \ + --prompt-file research_prompts.txt \ + --max-length 100 \ + --benchmark-dir ./research/results \ + --delay 2.0 +``` + +**Benefits:** + +- āœ… Runs all prompts automatically +- āœ… Accumulates data for trend analysis +- āœ… Creates comprehensive performance reports +- āœ… Handles errors gracefully + +**After running multiple benchmarks:** + +- Check `performance_over_time.png` for trends +- Analyze `inference_metrics.csv` in Excel/Python +- Review aggregated statistics in console output + +## Research Use Cases + +### 1. Performance Analysis + +Compare how optimizations affect inference speed: + +```bash +python inference.py \ + --checkpoint checkpoints/best.pt \ + --prompt "Your research prompt" \ + --benchmark +``` + +### 2. Memory Efficiency Study + +Analyze memory usage improvements: + +```bash +# Check memory reduction +python inference.py --checkpoint checkpoints/best.pt --prompt "Long prompt" --max-length 500 --benchmark +``` + +### 3. Scalability Testing + +Test with different generation lengths: + +```bash +# Short sequences +python inference.py --checkpoint checkpoints/best.pt --prompt "Test" --max-length 50 --benchmark + +# Medium sequences +python inference.py --checkpoint checkpoints/best.pt --prompt "Test" --max-length 200 --benchmark + +# Long sequences +python inference.py --checkpoint checkpoints/best.pt --prompt "Test" --max-length 1000 --benchmark +``` + +## Plot Interpretation + +### Comparison Plot (`optimization_comparison.png`) + +**Top Left - Tokens per Second:** + +- Higher is better +- Shows generation speed +- Speedup annotation shows improvement factor + +**Top Right - Time per Token:** + +- Lower is better +- Shows latency per token +- Important for real-time applications + +**Bottom Left - Total Generation Time:** + +- Lower is better +- Overall generation time +- Most user-visible metric + +**Bottom Right - Memory Usage:** + +- Lower is better +- GPU memory consumption +- Memory reduction annotation shows savings + +### Performance Over Time Plot (`performance_over_time.png`) + +Shows performance trends across multiple benchmark runs: + +- **Green line**: Optimized performance +- **Red line**: Non-optimized performance +- Useful for finding performance regressions or improvements + +## Reporting Results + +### Speedup Calculation + +``` +Speedup = Optimized Tokens/Second / Non-Optimized Tokens/Second +``` + +**Example:** + +- Optimized: 150 tokens/sec +- Non-Optimized: 81 tokens/sec +- Speedup: 150/81 = 1.85x faster + +### Memory Reduction Calculation + +``` +Memory Reduction % = (1 - Optimized Memory / Non-Optimized Memory) Ɨ 100 +``` + +**Example:** + +- Optimized: 245 MB +- Non-Optimized: 512 MB +- Reduction: (1 - 245/512) Ɨ 100 = 52.1% + +## Tips for Best Results + +1. **Warm Up GPU**: Run a few inference calls before benchmarking to warm up the GPU +2. **Clear Cache**: The benchmark automatically clears CUDA cache between runs +3. **Multiple Runs**: Run multiple benchmarks for statistical significance +4. **Consistent Prompts**: Use the same prompt for fair comparison +5. **Device Consistency**: Use the same device for all runs + +## Command Line Options + +```bash +python inference.py \ + --checkpoint PATH # Path to model checkpoint (required) + --prompt TEXT # Prompt text (required) + --max-length INT # Maximum generation length (default: 100) + --temperature FLOAT # Sampling temperature (default: 1.0) + --top-k INT # Top-k sampling (default: 50) + --top-p FLOAT # Top-p sampling (default: 0.95) + --device DEVICE # Device: cuda/cpu/mps (default: cuda) + --benchmark # Enable benchmarking mode + --benchmark-dir DIR # Benchmark output directory (default: ./inference_benchmarks) +``` + +## Troubleshooting + +### No GPU Memory Stats + +If memory stats show as `None`: + +- CUDA: Memory tracking should work automatically +- MPS (Apple Silicon): Memory tracking not available +- CPU: Memory tracking not available + +### Plots Not Generated + +If plots fail to generate: + +- Ensure `matplotlib` is installed: `pip install matplotlib` +- Check file permissions for output directory + +### Inconsistent Results + +For consistent results: + +- Use same device for all runs +- Use same prompt length +- Allow GPU to warm up +- Close other GPU applications + +## Example Research Workflow + +```bash +# 1. Run initial benchmark +python inference.py --checkpoint checkpoints/best.pt --prompt "Test prompt" --benchmark + +# 2. Review results +ls inference_benchmarks/ +cat inference_benchmarks/inference_metrics.json + +# 3. Generate plots (already done automatically) +# View: inference_benchmarks/optimization_comparison.png + +# 4. Analyze CSV data +# Open: inference_benchmarks/inference_metrics.csv in Excel/Python + +# 5. Run additional benchmarks +python inference.py --checkpoint checkpoints/best.pt --prompt "Different prompt" --max-length 200 --benchmark + +# 6. Compare results +python inference.py --checkpoint checkpoints/best.pt --prompt "Same prompt" --benchmark +``` + +## Optimization Architecture & Code Injection + +### Overview: Optimization Layers + +The optimizations are implemented as layers that wrap the standard inference pipeline: + +```mermaid +flowchart TB + subgraph subGraph0["Standard Inference (Non-Optimized)"] + B["Tokenize"] + A["Input Prompt"] + C["Embedding Layer"] + D["Transformer Blocks"] + E["Attention: Recompute All"] + F["Forward Pass: O(n²)"] + G["Output Tokens"] + H["Detokenize"] + I["Generated Text"] + end + subgraph subGraph1["Optimized Inference (With KV Cache)"] + B2["Tokenize"] + A2["Input Prompt"] + C2["Embedding Layer"] + D2["Transformer Blocks"] + E2["Optimized Attention"] + F2["KV Cache Layer"] + G2["Forward Pass: O(n)"] + H2["Output Tokens"] + I2["Detokenize"] + J2["Generated Text"] + end + A --> B + B --> C + C --> D + D --> E + E --> F + F --> G + G --> H + H --> I + A2 --> B2 + B2 --> C2 + C2 --> D2 + D2 --> E2 + E2 --> F2 + F2 --> G2 + G2 --> H2 + H2 --> I2 + I2 --> J2 + style E fill:#ffcccc + style F fill:#ffcccc + style E2 fill:#ccffcc + style F2 fill:#ccffcc +``` + +### Detailed Optimization Flow + +```mermaid + +flowchart LR + subgraph subGraph0["Request Flow"] + Mode{"Optimized?"} + Start["Benchmark Request"] + Standard["Standard Path"] + Optimized["Optimized Path"] + end + subgraph subGraph1["Standard Path"] + S1["Model.generate"] + S2["Transformer Forward"] + S3["MultiHeadAttention"] + S4["Compute Q, K, V"] + S5["Recompute All KVs"] + S6["Attention Scores: O(n²)"] + S7["Generate Token"] + end + subgraph subGraph2["Optimized Path"] + O1["OptimizedInference"] + O2["Init KV Cache"] + O3["Transformer Forward"] + O4["OptimizedMultiHeadAttention"] + O5["Compute Q, K, V"] + O6["KV Cache Layer"] + O7["Append to Cache"] + O8["Reuse Cached KVs"] + O9["Attention Scores: O(n)"] + O10["Generate Token"] + end + Start --> Mode + Mode -- No --> Standard + Mode -- Yes --> Optimized + Standard --> S1 + S1 --> S2 + S2 --> S3 + S3 --> S4 + S4 --> S5 + S5 --> S6 + S6 --> S7 + Optimized --> O1 + O1 --> O2 + O2 --> O3 + O3 --> O4 + O4 --> O5 + O5 --> O6 + O6 --> O7 + O7 --> O8 + O8 --> O9 + O9 --> O10 + S7 --> Metrics["Collect Metrics"] + O10 --> Metrics + style Standard fill:#ffcccc + style Optimized fill:#ccffcc + style S5 fill:#ffcccc + style O8 fill:#ccffcc + +``` + +### Code Injection Points + +```mermaid +graph TB + subgraph "Standard Model Architecture" + A[TransformerModel] --> B[TransformerBlock] + B --> C[MultiHeadAttention] + C --> D[Q, K, V Projections] + D --> E[Attention Computation] + E --> F[Output Projection] + F --> G[Feed Forward] + end + + subgraph "Optimization Injection Points" + H[OptimizedInference Wrapper] --> A + A --> B2[TransformerBlock] + B2 --> C2[OptimizedMultiHeadAttention] + C2 --> D2[Q, K, V Projections] + D2 --> I[KV Cache Injection] + I --> E2[Optimized Attention] + E2 --> F2[Output Projection] + F2 --> G2[Feed Forward] + end + + subgraph "KV Cache Layer Details" + I --> J[Cache Check] + J --> K{Cache Exists?} + K -->|No| L[Compute K, V] + K -->|Yes| M[Retrieve from Cache] + L --> N[Store in Cache] + M --> O[Append New K, V] + N --> O + O --> P[Use Cached KVs] + end + + style H fill:#90EE90 + style I fill:#90EE90 + style K fill:#FFD700 + style P fill:#90EE90 +``` + +### Benchmark Execution Flow + +```mermaid +sequenceDiagram + participant User + participant InferenceScript + participant BenchmarkModule + participant OptimizedInference + participant StandardModel + participant MetricsCollector + + User->>InferenceScript: python inference.py --benchmark + InferenceScript->>BenchmarkModule: Initialize Metrics + BenchmarkModule->>MetricsCollector: Create InferenceMetrics + + Note over InferenceScript: Run 1: Non-Optimized + InferenceScript->>StandardModel: model.generate() + StandardModel->>StandardModel: Forward Pass (O(n²)) + StandardModel-->>InferenceScript: Generated Tokens + InferenceScript->>MetricsCollector: Log Run (optimized=false) + + Note over InferenceScript: Run 2: Optimized + InferenceScript->>OptimizedInference: get_optimized_inference() + OptimizedInference->>OptimizedInference: Init KV Cache + OptimizedInference->>OptimizedInference: generate_with_cache() + + loop For each token + OptimizedInference->>OptimizedInference: Forward Pass (O(n)) + OptimizedInference->>OptimizedInference: Update KV Cache + end + + OptimizedInference-->>InferenceScript: Generated Tokens + InferenceScript->>MetricsCollector: Log Run (optimized=true) + + MetricsCollector->>MetricsCollector: Calculate Speedup + MetricsCollector->>MetricsCollector: Generate Plots + MetricsCollector->>MetricsCollector: Export CSV + MetricsCollector-->>User: Results & Plots +``` + +### Optimization Components Stack + +```mermaid +graph TD + subgraph "Application Layer" + A[inference.py] --> B[benchmark_inference] + B --> C[Generate Text] + end + + subgraph "Optimization Layer" + C --> D{Optimized?} + D -->|Yes| E[OptimizedInference] + D -->|No| F[Standard Model] + E --> G[KV Cache Manager] + E --> H[Optimized Attention] + end + + subgraph "Core Model Layer" + F --> I[TransformerModel] + E --> I + I --> J[TransformerBlock] + J --> K[MultiHeadAttention] + H --> K + K --> L[Attention Computation] + end + + subgraph "Cache Layer" + G --> M[KVCache Data Structure] + M --> N[Keys Cache] + M --> O[Values Cache] + N --> P[Retrieve Previous K] + O --> Q[Retrieve Previous V] + end + + subgraph "Compute Layer" + L --> R[Q Ɨ K^T] + P --> R + Q --> R + R --> S[Softmax] + S --> T[Attention Weights] + T --> U[Output] + end + + style E fill:#90EE90 + style G fill:#90EE90 + style H fill:#90EE90 + style M fill:#FFD700 +``` + +### Performance Comparison Schema + +```mermaid + +flowchart LR + subgraph subGraph0["Metrics Collection"] + B["Non-Optimized Metrics"] + A["Benchmark Run"] + C["Optimized Metrics"] + D["Time: T1
Memory: M1
Speed: S1"] + E["Time: T2
Memory: M2
Speed: S2"] + end + subgraph Analysis["Analysis"] + F["Calculate Speedup"] + G["Speedup = S2/S1"] + H["Calculate Memory Reduction"] + I["Reduction = (M1-M2)/M1 Ɨ 100%"] + end + subgraph Visualization["Visualization"] + J["Comparison Plot"] + K["Trend Analysis"] + L["Performance Over Time"] + end + subgraph subGraph3["Data Export"] + M["JSON Metrics"] + N["CSV Export"] + end + A --> B & C + B --> D + C --> E + D --> F & H + E --> F & H + F --> G & K + H --> I + G --> J + I --> J + K --> L + J --> M & N + L --> M & N + style F fill:#FFD700 + style G fill:#90EE90 + style I fill:#90EE90 + +``` + +## Data File Locations Summary + +**All benchmark data is saved to:** + +``` +./inference_benchmarks/ +ā”œā”€ā”€ inference_metrics.json # All raw metrics (JSON) +ā”œā”€ā”€ inference_metrics.csv # Spreadsheet data (CSV) +ā”œā”€ā”€ optimization_comparison.png # Comparison charts +└── performance_over_time.png # Trend analysis +``` + +**Custom location:** + +```bash +--benchmark-dir ./research/results +``` + +**Data accumulates:** Each benchmark run appends to the same files, building trends over time. + +## Next Steps + +1. āœ… Run your first benchmark +2. āœ… Review the comparison plots +3. āœ… Analyze CSV data for deeper insights +4. āœ… Run multiple benchmarks for statistical analysis +5. āœ… Use batch script for trend analysis +6. āœ… Include results in your research paper/presentation + +--- + +**Happy Benchmarking!** šŸ“ŠšŸ”¬ diff --git a/docs/COMPLETE_GUIDE.md b/docs/COMPLETE_GUIDE.md new file mode 100644 index 0000000..8072bf2 --- /dev/null +++ b/docs/COMPLETE_GUIDE.md @@ -0,0 +1,854 @@ +# SheepOp LLM šŸ‘āž”ļøšŸ¤– + +A modern language model implementation from scratch, incorporating insights from recent research papers. + +## License + +This project is licensed under the **Apache License 2.0**. + +See [LICENSE](../LICENSE) or [LICENSE.txt](../LICENSE.txt) for the full license text. + +**Summary:** +- āœ… Free to use, modify, and distribute +- āœ… Commercial use allowed +- āœ… Patent grant included +- āœ… Private use allowed +- āš ļø Must include license and copyright notice +- āš ļø Must state changes if modifying + +## Table of Contents + +- [What is This?](#what-is-this) +- [Features](#features) +- [Quick Start](#quick-start) +- [Mathematical Foundations](#mathematical-foundations) +- [Architecture Explained](#architecture-explained) +- [Project Structure](#project-structure) +- [Installation](#installation) +- [Usage](#usage) +- [Configuration](#configuration) +- [Diagrams](#diagrams) +- [References](#references) + +## What is This? + +A Transformer-based language model implementing autoregressive next-token prediction using multi-head self-attention, positional encoding, and modern training optimizations (mixed precision, gradient accumulation, KV caching). + +The model learns to write by reading large amounts of text, discovering patterns like "after 'the cat' usually comes 'sat' or 'ran'", enabling it to generate coherent sentences. It processes text sequentially, predicting each next word based on the context provided by previous words. + +## Features + +- **Transformer Architecture**: Multi-head self-attention mechanism from "Attention Is All You Need" +- **Long Context Support**: Efficient handling of long sequences with Rotary Positional Encoding (RoPE) +- **Training Optimizations**: Mixed precision training, gradient accumulation, gradient clipping +- **Modern Best Practices**: Pre-norm architecture, GELU activation, weight tying +- **Comprehensive Evaluation**: Perplexity, accuracy metrics, and generation utilities + +## Quick Start + +**Option 1: Automated Setup (Recommended)** + +```bash +# Run the setup script - it handles everything automatically! +./setup.sh + +# Then activate the virtual environment +source venv/bin/activate + +# Download data and train +python3 download_large_data.py wiki --version 103 +python3 train.py --data data/wikitext_103.txt --config config.json --device cuda +``` + +**Option 2: Manual Setup** + +```bash +# 1. Create and activate virtual environment (REQUIRED on modern Linux systems) +python3 -m venv venv +source venv/bin/activate + +# 2. Upgrade pip +pip install --upgrade pip + +# 3. Install dependencies +pip install -r requirements.txt + +# 4. Download data +python3 download_large_data.py wiki --version 103 + +# 5. Train +python3 train.py --data data/wikitext_103.txt --config config.json --device cuda + +# 6. Generate +python3 inference.py --checkpoint checkpoints/checkpoint_epoch_10.pt \ + --prompt "The future of artificial intelligence" --device cuda +``` + +**Note:** On modern Debian/Ubuntu systems (especially Python 3.12+), pip prevents system-wide installations to protect system packages. Always use a virtual environment (`python3 -m venv venv`) before installing dependencies. + +See [GETTING_STARTED.md](GETTING_STARTED.md) for detailed instructions. + +## Mathematical Foundations + +### 1. Token Embedding + +Words are converted into numerical representations that the model can process. Each word (token) is assigned a unique ID, which is then mapped to a dense vector representation - think of it as converting words into a format the computer can understand and manipulate mathematically. + +**Mathematical Formulation**: + +Given a vocabulary of size $V$ and token ID $t \in \{0, 1, \ldots, V-1\}$: + +$$ +\mathbf{E}_t = \text{EmbeddingTable}[t] \in \mathbb{R}^{d_{\text{model}}} +$$ + +where $\mathbf{E} \in \mathbb{R}^{V \times d_{\text{model}}}$ is the learnable embedding matrix. + +### 2. Positional Encoding + +Word order is crucial in language - "Cat bites dog" is very different from "Dog bites cat". Since transformers process all tokens simultaneously, we need to inject positional information so the model knows which word comes first, second, third, etc. + +**Mathematical Formulation**: + +**Sinusoidal Positional Encoding** (for position $pos$ and dimension $i$): + +$$ +PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right) +$$ + +$$ +PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right) +$$ + +The final embedding combines token and position: + +$$ +\mathbf{h}_i = \mathbf{E}_{t_i} + PE(i) +$$ + +where $t_i$ is the token at position $i$. + +### 3. Multi-Head Self-Attention + +Attention mechanisms allow the model to understand relationships between words in a sentence. When reading "The cat sat on the mat", the model learns to connect "cat" with "sat" because they're related. Multi-head attention enables the model to focus on different types of relationships simultaneously - syntax, semantics, and context. + +**Mathematical Formulation**: + +Given input $\mathbf{X} \in \mathbb{R}^{n \times d_{\text{model}}}$ with $n$ tokens: + +1. **Project to Query, Key, Value**: + +```math + \mathbf{Q} = \mathbf{X}\mathbf{W}_Q, \quad \mathbf{K} = \mathbf{X}\mathbf{W}_K, \quad \mathbf{V} = \mathbf{X}\mathbf{W}_V + + where $\mathbf{W}_Q, \mathbf{W}_K, \mathbf{W}_V \in \mathbb{R}^{d_{\text{model}} \times d_{\text{model}}}$ +``` + +2. **Scaled Dot-Product Attention**: + +```math + \text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{d_k}}\right)\mathbf{V} +``` + +3. **Multi-Head Attention** splits into $h$ heads: + +```math + \text{head}_i = \text{Attention}(\mathbf{Q}_i, \mathbf{K}_i, \mathbf{V}_i) + + \text{MultiHead}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\mathbf{W}_O + + where $d_k = d_{\text{model}} / h$ (each head has dimension $d_k$). +``` + +4. **Causal Masking** (for autoregressive generation): + +```math + M_{ij} = \begin{cases} + 0 & \text{if } i \geq j \\ + -\infty & \text{if } i < j + \end{cases} + \text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}, M) = \text{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{d_k}} + M\right)\mathbf{V} +``` + +### 4. Feed-Forward Network + +After attention identifies which words relate to each other, the feed-forward network performs non-linear transformations on the information. This step allows the model to synthesize and combine the attended information, similar to mixing ingredients to create a final output. + +**Mathematical Formulation**: + +```math +\text{FFN}(\mathbf{x}) = \text{GELU}(\mathbf{x}\mathbf{W}_1 + \mathbf{b}_1)\mathbf{W}_2 + \mathbf{b}_2 + +where $\mathbf{W}_1 \in \mathbb{R}^{d_{\text{model}} \times d_{ff}}$, $\mathbf{W}_2 \in \mathbb{R}^{d_{ff} \times d_{\text{model}}}$, and typically $d_{ff} = 4 \times d_{\text{model}}$. +``` + +**GELU Activation**: + +```math +\text{GELU}(x) = x \cdot \Phi(x) = x \cdot \frac{1}{2}\left(1 + \text{erf}\left(\frac{x}{\sqrt{2}}\right)\right) + +where $\Phi(x)$ is the cumulative distribution function of the standard normal distribution. +``` + +### 5. Transformer Block (Pre-Norm Architecture) + +**For Kids**: Imagine building with LEGO blocks. Each block does something special (attention or thinking), and you can stack many blocks on top of each other to build something amazing! + +**Mathematical Formulation**: + +For a transformer block with input $\mathbf{x}$: + +1. **Self-Attention Sublayer**: + +```math + \mathbf{x}' = \mathbf{x} + \text{Dropout}\left(\text{MultiHead}(\text{LayerNorm}(\mathbf{x}))\right) +``` + +2. **Feed-Forward Sublayer**: + +```math + \mathbf{x}'' = \mathbf{x}' + \text{Dropout}\left(\text{FFN}(\text{LayerNorm}(\mathbf{x}'))\right) +``` + +**Layer Normalization**: + +```math +\text{LayerNorm}(\mathbf{x}) = \gamma \odot \frac{\mathbf{x} - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta + +where $\mu = \frac{1}{d}\sum_{i=1}^d x_i$, $\sigma^2 = \frac{1}{d}\sum_{i=1}^d (x_i - \mu)^2$, and $\gamma, \beta$ are learnable parameters. +``` + +### 6. Complete Forward Pass + +**For Kids**: Think of it like an assembly line! The words go through many stations (layers), each doing something special, until finally we get a prediction of what word comes next. + +**Mathematical Formulation**: + +Given input token sequence $\mathbf{t} = [t_1, t_2, \ldots, t_n]$: + +$$ +\mathbf{h}_0 = \mathbf{E}[\mathbf{t}] + PE(\mathbf{t}) \quad \text{(token embeddings + positional encoding)} +$$ + +$$ +\mathbf{h}_1 = \text{TransformerBlock}_1(\mathbf{h}_0) +$$ + +$$ +\mathbf{h}_2 = \text{TransformerBlock}_2(\mathbf{h}_1) +$$ + +$$ +\vdots +$$ + +$$ +\mathbf{h}_L = \text{TransformerBlock}_L(\mathbf{h}_{L-1}) +$$ + +$$ +\mathbf{o} = \text{LayerNorm}(\mathbf{h}_L) \quad \text{(final normalization)} +$$ + +$$ +\mathbf{y} = \mathbf{o}\mathbf{W}_{\text{out}} \quad \text{(output projection)} +$$ + +$$ +p(t_{n+1} | t_1, \ldots, t_n) = \text{softmax}(\mathbf{y}_n) \quad \text{(next token probability)} +$$ + +### 7. Training Objective (Cross-Entropy Loss) + +**For Kids**: When we train, we show the computer a sentence and ask "what word comes next?" If it guesses wrong, we say "try again!" and it learns from its mistake. + +**Mathematical Formulation**: + +For a sequence of tokens $\mathbf{t} = [t_1, t_2, \ldots, t_n]$: + +$$ +\mathcal{L} = -\frac{1}{n-1}\sum_{i=1}^{n-1} \log p(t_{i+1} | t_1, \ldots, t_i) +$$ + +**Perplexity** (measure of model confidence): + +$$ +\text{Perplexity} = \exp(\mathcal{L}) = \exp\left(-\frac{1}{n-1}\sum_{i=1}^{n-1} \log p(t_{i+1} | t_1, \ldots, t_i)\right) +$$ + +Lower perplexity = better model! + +### 8. Text Generation (Autoregressive Sampling) + +**For Kids**: The computer writes one word at a time. After each word, it thinks "what word makes sense next?" and picks one. + +**Mathematical Formulation**: + +Given prompt $\mathbf{p} = [p_1, \ldots, p_k]$: + +1. Initialize: $\mathbf{s} = \mathbf{p}$ +2. For $i = k+1, \ldots, k+n$: + +$$ +\mathbf{h}_i = \text{Transformer}(\mathbf{s}) +$$ + +$$ +\mathbf{p}_i = \text{softmax}(\mathbf{h}_i / \tau) \quad \text{(with temperature $\tau \geq 1$)} +$$ + +$$ +t_i \sim \text{Sample}(\mathbf{p}_i) \quad \text{(sample next token)} +$$ + +$$ +\mathbf{s} = \mathbf{s} \cup \{t_i\} \quad \text{(append to sequence)} +$$ + +**Top-k Sampling**: + +$$ +\mathbf{p}_i' = \begin{cases} +\mathbf{p}_i[j] & \text{if } j \in \text{top}_k(\mathbf{p}_i) \\ +0 & \text{otherwise} +\end{cases} +$$ + +**Top-p (Nucleus) Sampling**: +Find smallest set $S$ such that $\sum_{j \in S} \mathbf{p}_i[j] \geq p$, then: + +$$ +\mathbf{p}_i'[j] = \begin{cases} +\mathbf{p}_i[j] & \text{if } j \in S \\ +0 & \text{otherwise} +\end{cases} +$$ + +### 9. Optimization (AdamW) + +**For Kids**: Learning is like climbing a mountain. You take small steps in the right direction. AdamW is like having a smart guide that knows which direction is best and adjusts your step size automatically! + +**Mathematical Formulation**: + +For parameter $\theta$ with gradient $\mathbf{g}_t$ at step $t$: + +**Momentum**: + +$$ +\mathbf{m}_t = \beta_1 \mathbf{m}_{t-1} + (1 - \beta_1) \mathbf{g}_t +$$ + +**RMSprop**: + +$$ +\mathbf{v}_t = \beta_2 \mathbf{v}_{t-1} + (1 - \beta_2) \mathbf{g}_t^2 +$$ + +**Bias Correction**: + +$$ +\hat{\mathbf{m}}_t = \frac{\mathbf{m}_t}{1 - \beta_1^t}, \quad \hat{\mathbf{v}}_t = \frac{\mathbf{v}_t}{1 - \beta_2^t} +$$ + +**Parameter Update**: + +$$ +\theta_t = \theta_{t-1} - \eta \left(\frac{\hat{\mathbf{m}}_t}{\sqrt{\hat{\mathbf{v}}_t} + \epsilon} + \lambda \theta_{t-1}\right) +$$ + +where $\eta$ is learning rate, $\lambda$ is weight decay, and typically $\beta_1 = 0.9$, $\beta_2 = 0.999$. + +### 10. Gradient Clipping + +**For Kids**: Sometimes the computer gets too excited and tries to learn too fast (like running too fast down a hill). We slow it down so it doesn't fall! + +**Mathematical Formulation**: + +$$ +\mathbf{g}_{\text{clipped}} = \begin{cases} +\mathbf{g} & \text{if } \|\mathbf{g}\| \leq \theta_{\max} \\ +\mathbf{g} \cdot \frac{\theta_{\max}}{\|\mathbf{g}\|} & \text{if } \|\mathbf{g}\| > \theta_{\max} +\end{cases} +$$ + +where $\theta_{\max}$ is the maximum gradient norm (typically 1.0). + +## Architecture Explained + +### High-Level Overview + +**For Kids**: + +``` +šŸ“š Text Input → šŸ”¤ Turn into Numbers → 🧠 Thinking Layers → ✨ Predict Next Word → šŸ“ Generate Text +``` + +**For Scientists**: + +``` +Token IDs → Embeddings → Positional Encoding → NƗ Transformer Blocks → Output Projection → Logits → Sampling → Text +``` + +### Detailed Architecture + +``` +ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” +│ INPUT TEXT │ +│ "The cat sat on the mat" │ +ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + │ + ā–¼ +ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” +│ TOKENIZATION │ +│ [1, 45, 123, 67, 45, 234] (each word → number) │ +ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + │ + ā–¼ +ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” +│ TOKEN EMBEDDING │ +│ E ∈ ā„^(VƗd_model) | Each token → vector of size d_model │ +│ [1, 45, ...] → [[0.1, 0.3, ...], [0.2, -0.1, ...], ...] │ +ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + │ + ā–¼ +ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” +│ POSITIONAL ENCODING │ +│ PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) │ +│ h_i = E[t_i] + PE(i) (add position info) │ +ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + │ + ā–¼ +ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” +│ TRANSFORMER BLOCK 1 │ +│ ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” │ +│ │ Pre-Norm Multi-Head Attention │ │ +│ │ Attention(Q, K, V) = softmax(QK^T/√d_k)V │ │ +│ │ + Residual Connection │ │ +│ ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ │ +│ ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” │ +│ │ Pre-Norm Feed-Forward Network │ │ +│ │ FFN(x) = GELU(xW₁ + b₁)Wā‚‚ + bā‚‚ │ │ +│ │ + Residual Connection │ │ +│ ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ │ +ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + │ + ā–¼ +ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” +│ TRANSFORMER BLOCK 2 │ +│ (Same structure as Block 1) │ +ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + │ + ā–¼ + ... (N-2 more blocks) ... + │ + ā–¼ +ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” +│ TRANSFORMER BLOCK N │ +│ (Same structure) │ +ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + │ + ā–¼ +ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” +│ FINAL LAYER NORM │ +│ LayerNorm(h_L) │ +ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + │ + ā–¼ +ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” +│ OUTPUT PROJECTION │ +│ y = h_L Ā· W_out (W_out ∈ ā„^(d_model Ɨ V)) │ +│ Output: logits for each token in vocabulary │ +ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + │ + ā–¼ +ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” +│ SOFTMAX │ +│ p(t | context) = softmax(y) │ +│ Probability distribution over vocabulary │ +ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + │ + ā–¼ +ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” +│ SAMPLING │ +│ t_{next} ~ Sample(p) (with temperature, top-k, top-p) │ +ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + │ + ā–¼ +ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” +│ OUTPUT TEXT │ +│ "The cat sat on the mat and..." │ +ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ +``` + +## Project Structure + +``` +sheepOp/ +ā”œā”€ā”€ models/ # Model architectures +│ ā”œā”€ā”€ __init__.py # Module exports +│ ā”œā”€ā”€ transformer.py # Main transformer model +│ ā”œā”€ā”€ attention.py # Attention mechanisms +│ ā”œā”€ā”€ blocks.py # Building blocks (FFN, TransformerBlock) +│ ā”œā”€ā”€ optimized_attention.py # KV caching, optimized inference +│ └── prefetching.py # Data prefetching utilities +ā”œā”€ā”€ data/ # Data loading utilities +│ └── __init__.py # Dataset and tokenizer +ā”œā”€ā”€ training/ # Training utilities +│ ā”œā”€ā”€ __init__.py # Trainer class +│ └── metrics.py # Training metrics and plotting +ā”œā”€ā”€ config.py # Configuration management +ā”œā”€ā”€ config.json # Example configuration file +ā”œā”€ā”€ train.py # Training script +ā”œā”€ā”€ inference.py # Inference script +ā”œā”€ā”€ example.py # Usage examples +ā”œā”€ā”€ utils.py # Evaluation utilities +└── requirements.txt # Dependencies +``` + +## Installation + +**Important:** On modern Debian/Ubuntu systems (Python 3.12+), you must use a virtual environment. The system prevents system-wide pip installations to protect system packages. + +```bash +# 1. Create virtual environment +python3 -m venv venv + +# 2. Activate virtual environment +source venv/bin/activate + +# 3. Upgrade pip +pip install --upgrade pip + +# 4. Install dependencies +pip install -r requirements.txt + +# 5. For large dataset downloads (optional) +pip install datasets +``` + +**Alternative:** Use the automated setup script which handles everything: +```bash +./setup.sh +source venv/bin/activate +``` + +## Usage + +### Training + +```bash +python3 train.py \ + --data data/amazon_reviews.txt \ + --config config.json \ + --device cuda +``` + +### Inference + +```bash +python3 inference.py \ + --checkpoint checkpoints/checkpoint_epoch_10.pt \ + --prompt "The future of artificial intelligence" \ + --optimized \ + --device cuda +``` + +## Configuration + +See `config.json` for all available settings. Key parameters: + +- **Model**: `vocab_size`, `d_model`, `num_layers`, `num_heads` +- **Training**: `batch_size`, `learning_rate`, `max_epochs` +- **Data**: `max_length`, `data_dir` + +## Diagrams + +### Complete Model Architecture Flow + +```mermaid +graph TB + subgraph "Input Processing" + A[Text Input] --> B[Tokenization] + B --> C[Token Embedding
E ∈ ā„^VƗd] + C --> D[Positional Encoding
PE pos,2i = sin pos/10000^2i/d] + D --> E[Embedding + Position
h = E + PE] + end + + subgraph "Transformer Stack" + E --> F[Transformer Block 1] + F --> G[Transformer Block 2] + G --> H[...] + H --> I[Transformer Block N] + end + + subgraph "Transformer Block Detail" + F --> F1[Layer Norm 1] + F1 --> F2[Multi-Head Attention
QK^T/√d_k → softmax → V] + F2 --> F3[Residual + Dropout] + F --> F3 + F3 --> F4[Layer Norm 2] + F4 --> F5[Feed-Forward
GELU xW₁ + b₁Wā‚‚ + bā‚‚] + F5 --> F6[Residual + Dropout] + F3 --> F6 + end + + subgraph "Output Processing" + I --> J[Final Layer Norm] + J --> K[Output Projection
y = hW_out] + K --> L[Softmax
p = softmax y] + L --> M[Sampling
t ~ p with temp, top-k, top-p] + M --> N[Generated Text] + end + + style A fill:#e1f5ff + style N fill:#fff4e1 + style F fill:#e8f5e9 + style F2 fill:#ffe1f5 + style F5 fill:#ffe1f5 +``` + +### Attention Mechanism Visualization + +```mermaid +graph LR + subgraph "Input" + X[Input Tokens
x₁, xā‚‚, ..., xā‚™] + end + + subgraph "Q, K, V Projections" + X --> Q[Query Q
Q = XW_Q] + X --> K[Key K
K = XW_K] + X --> V[Value V
V = XW_V] + end + + subgraph "Attention Computation" + Q --> M[Matrix Multiply
QK^T] + K --> M + M --> S[Scale
÷√d_k] + S --> Mask{Causal Mask?} + Mask -->|Yes| CM[Mask
M_ij = -āˆž if i < j] + Mask -->|No| SM[Skip Mask] + CM --> Soft[Softmax
exp scores / Σexp] + SM --> Soft + Soft --> WV[Weighted Sum
Attention = softmax scores Ɨ V] + V --> WV + end + + subgraph "Multi-Head" + WV --> Split[Split into h heads] + Split --> H1[Head 1] + Split --> H2[Head 2] + Split --> H3[...] + Split --> Hh[Head h] + H1 --> Concat[Concatenate] + H2 --> Concat + H3 --> Concat + Hh --> Concat + Concat --> Out[Output Projection
WO] + end + + style X fill:#e1f5ff + style Out fill:#fff4e1 + style Soft fill:#ffe1f5 + style WV fill:#e8f5e9 +``` + +### Training Loop Flow + +```mermaid +graph TD + A[Start Training] --> B[Load Data] + B --> C[Create Tokenizer] + C --> D[Build Vocabulary] + D --> E[Create DataLoader] + E --> F[Initialize Model] + F --> G[Setup Optimizer
AdamW] + G --> H[Setup LR Scheduler
Cosine Annealing] + + H --> I[For Each Epoch] + I --> J[For Each Batch] + + J --> K[Forward Pass
y = Model x] + K --> L[Compute Loss
L = -log p t_next] + L --> M[Backward Pass
āˆ‡L] + M --> N[Gradient Accumulation] + N --> O{Accumulation
Complete?} + O -->|No| J + O -->|Yes| P[Gradient Clipping
Clip gradient norm] + P --> Q[Optimizer Step
Īø = Īø - Ī·āˆ‡L] + Q --> R[LR Scheduler Step] + R --> S{End of
Epoch?} + S -->|No| J + S -->|Yes| T[Evaluate on
Validation Set] + T --> U[Compute Perplexity
exp L] + U --> V{Best
Model?} + V -->|Yes| W[Save Checkpoint] + V -->|No| X[Save Regular Checkpoint] + W --> Y{More
Epochs?} + X --> Y + Y -->|Yes| I + Y -->|No| Z[Training Complete] + + style A fill:#e1f5ff + style Z fill:#fff4e1 + style K fill:#e8f5e9 + style L fill:#ffe1f5 + style P fill:#fff4e1 +``` + +### Inference Flow with Sampling + +```mermaid +graph TD + A[Load Checkpoint] --> B[Initialize Model] + B --> C[Set Eval Mode] + C --> D[Input Prompt] + D --> E[Tokenize
t = t1, t2, ..., tk] + E --> F[Encode
h = E t + PE] + + F --> G[Forward Pass
y = Model h] + G --> H[Get Logits
y ∈ ā„^V] + H --> I[Apply Temperature
p = softmax y/Ļ„] + + I --> J{Top-k
Filter?} + J -->|Yes| K[Keep Top-k
p' = filter p, k] + J -->|No| L[p' = p] + K --> M{Top-p
Filter?} + L --> M + M -->|Yes| N[Nucleus Sampling
p' = filter p', p] + M -->|No| O[p'' = p'] + N --> O + + O --> P[Sample Token
t_i ~ p''] + P --> Q[Append to Sequence
s = s ∪ t_i] + Q --> R{Max Length
Reached?} + R -->|No| G + R -->|Yes| S[Decode Tokens
text = decode s] + S --> T[Output Text] + + style A fill:#e1f5ff + style T fill:#fff4e1 + style G fill:#e8f5e9 + style P fill:#ffe1f5 +``` + +### Component Interaction + +```mermaid +graph TB + subgraph "Configuration" + CFG[config.py
Config Classes] + CFGJSON[config.json
JSON Settings] + end + + subgraph "Models" + TRANS[transformer.py
TransformerModel] + ATT[attention.py
MultiHeadAttention] + BLOCKS[blocks.py
TransformerBlock, FFN] + OPT[optimized_attention.py
KV Cache, Optimized Inference] + end + + subgraph "Data" + DATA[data/__init__.py
TextDataset, SimpleTokenizer] + TOKEN[SimpleTokenizer
encode/decode] + DATASET[TextDataset
PyTorch Dataset] + end + + subgraph "Training" + TRAIN[training/__init__.py
Trainer] + METRICS[training/metrics.py
TrainingMetrics] + end + + subgraph "Scripts" + TRAIN_SCRIPT[train.py
Training Entry Point] + INFER[inference.py
Generation Script] + EX[example.py
Usage Examples] + end + + subgraph "Utils" + UTILS[utils.py
Evaluation Functions] + end + + CFG --> TRAIN_SCRIPT + CFGJSON --> TRAIN_SCRIPT + TRANS --> TRAIN_SCRIPT + TRANS --> INFER + ATT --> TRANS + BLOCKS --> TRANS + OPT --> TRANS + DATA --> TRAIN_SCRIPT + DATA --> INFER + TOKEN --> DATASET + DATASET --> TRAINER + TRAIN --> TRAIN_SCRIPT + METRICS --> TRAIN + UTILS --> TRAIN_SCRIPT + + style TRANS fill:#e1f5ff + style TRAINER fill:#fff4e1 + style DATA fill:#e8f5e9 + style ATT fill:#ffe1f5 +``` + +## Mathematical Summary Table + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
ConceptFormulaDescription
Token Embedding$$\mathbf{E}_t = \text{EmbeddingTable}[t]$$Maps token ID to vector
Positional Encoding$$PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d}}\right)$$Adds position information
Attention$$\text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$Computes attention weights
Feed-Forward$$\text{FFN}(x) = \text{GELU}(xW_1 + b_1)W_2 + b_2$$Non-linear transformation
Layer Norm$$\text{LayerNorm}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$Normalizes activations
Loss$$\mathcal{L} = -\frac{1}{n}\sum_{i=1}^{n} \log p(t_{i+1} \mid t_1, \ldots, t_i)$$Cross-entropy loss
Perplexity$$\text{Perplexity} = \exp(\mathcal{L})$$Measure of model confidence
AdamW Update$$\theta_t = \theta_{t-1} - \eta\left(\frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda\theta_{t-1}\right)$$Optimizer step
+ +## References + +- **Attention Is All You Need** (Vaswani et al., 2017) - Original Transformer paper +- **Optimizing LLM Inference and Retrieval** - Production RAG systems optimizations +- **RoPE: Rotary Position Embedding** - Efficient positional encoding for long sequences +- Various papers on LLM training, hallucinations, and long context handling + +--- + +**For Kids**: You've learned how computers can read and write! Just like you practice writing stories, computers practice by reading millions of books. The more they practice, the better they get! šŸŽ‰ + +**For Scientists**: This implementation follows modern best practices including pre-norm architecture, weight tying, mixed precision training, and efficient inference optimizations. The codebase is modular, well-documented, and optimized for both training and production deployment. diff --git a/docs/CONTROL_SYSTEM_MODEL.md b/docs/CONTROL_SYSTEM_MODEL.md new file mode 100644 index 0000000..90a7719 --- /dev/null +++ b/docs/CONTROL_SYSTEM_MODEL.md @@ -0,0 +1,2574 @@ +# SheepOp LLM - Mathematical Control System Model + +Complete mathematical control system formulation of the SheepOp Language Model, treating the entire system as a unified mathematical control system with state-space representations, transfer functions, and step-by-step explanations. + +## Table of Contents + +1. [System Overview](#1-system-overview) +2. [State-Space Representation](#2-state-space-representation) +3. [Tokenizer as Input Encoder](#3-tokenizer-as-input-encoder) +4. [Seed Control System](#4-seed-control-system) +5. [Embedding Layer Control](#5-embedding-layer-control) +6. [Positional Encoding State](#6-positional-encoding-state) +7. [Self-Attention Control System](#7-self-attention-control-system) +8. [Feed-Forward Control](#8-feed-forward-control) +9. [Layer Normalization Feedback](#9-layer-normalization-feedback) +10. [Complete System Dynamics](#10-complete-system-dynamics) +11. [Training as Optimization Control](#11-training-as-optimization-control) +12. [Inference Control Loop](#12-inference-control-loop) + +--- + +## 1. System Overview + +### 1.1 Control System Architecture + +The SheepOp LLM can be modeled as a **nonlinear dynamical control system** with: + +- **Input**: Character sequence $\mathbf{c} = [c_1, c_2, ..., c_n]$ +- **State**: Hidden representations $\mathbf{h}\_t $at each layer and time step +- **Control**: Model parameters $\theta = \{W_Q, W_K, W_V, W_1, W_2, ...\} + $ +- **Output**: Probability distribution over vocabulary $\mathbf{p}\_t \in \mathbb{R}^V$ + +**System Block Diagram:** + +``` +Input Sequence → Tokenizer → Embeddings → Positional Encoding → + ↓ + [Transformer Layer 1] → [Transformer Layer 2] → ... → [Transformer Layer L] + ↓ + Output Projection → Logits → Softmax → Output Probabilities +``` + +### 1.2 Mathematical System Formulation + +The complete system can be expressed as: + +```math + +\mathbf{y}_t = \mathcal{F}(\mathbf{x}_t, \mathbf{h}_t, \theta, \mathbf{s}) + +``` + +where: + +- $\mathbf{x}\_t $= input at time$ t$ +- $\mathbf{h}\_t $= hidden state at time$ t$ +- $\theta $= system parameters (weights) +- $\mathbf{s} $= seed for randomness +- $\mathcal{F} $= complete forward function + +--- + +## 2. State-Space Representation + +### 2.1 Discrete-Time State-Space Model + +For a transformer with L layers and sequence length n : + +**State Vector:** + +```math +\mathbf{H}_t = \begin{bmatrix} +\mathbf{h}_t^{(1)} \\ +\mathbf{h}_t^{(2)} \\ +\vdots \\ +\mathbf{h}_t^{(L)} +\end{bmatrix} \in \mathbb{R}^{L \times n \times d} +``` + +where + +$\mathbf{h}_t^{(l)} \in \mathbb{R}^{n \times d} is the hidden state at layer l .$ + +**State Update Equation:** + +```math + +\mathbf{h}_t^{(l+1)} = f_l(\mathbf{h}_t^{(l)}, \theta_l), \quad l = 0, 1, ..., L-1 + + +where f_l is the transformation at layer l . +``` + +**Output Equation:** + +```math + +\mathbf{y}_t = g(\mathbf{h}_t^{(L)}, \theta_{out}) + +``` + +### 2.2 System Linearity Analysis + +The system is **nonlinear** due to: + +- Attention mechanism (softmax) +- Activation functions (GELU) +- Layer normalization + +However, individual components can be analyzed as **piecewise linear** systems. + +--- + +## 3. Tokenizer as Input Encoder + +### 3.1 Tokenizer Control Function + +The tokenizer maps a character sequence to a discrete token sequence: + +```math + +\mathcal{T}: \mathcal{C}^* \rightarrow \mathbb{N}^* + +``` + +**Mathematical Formulation:** + +For input sequence $\mathbf{c} = [c_1, c_2, ..., c_n] $: + +```math + +\mathbf{t} = \mathcal{T}(\mathbf{c}) = [V(c_1), V(c_2), ..., V(c_n)] + + +where V: \mathcal{C} \rightarrow \mathbb{N} is the vocabulary mapping function. +``` + +### 3.2 Vocabulary Mapping Function + +```math + +V(c) = \begin{cases} +0 & \text{if } c = \text{} \\ +1 & \text{if } c = \text{} \\ +2 & \text{if } c = \text{} \\ +3 & \text{if } c = \text{} \\ +v & \text{if } c \in \mathcal{C}_{vocab} +\end{cases} + +``` + +**Control Properties:** + +- **Deterministic**: Same input always produces same output +- **Invertible**: For most tokens, $V^{-1}$ exists +- **Bijective**: Each character maps to unique token ID + +### 3.3 Tokenizer State Space + +The tokenizer maintains internal state: + +```math + +\Sigma_{\mathcal{T}} = \{V, V^{-1}, \text{padding\_strategy}, \text{max\_length}\} + +``` + +**State Transition:** + +```math + +\Sigma_{\mathcal{T}}' = \Sigma_{\mathcal{T}} \quad \text{(static during operation)} + +``` + +### 3.4 Step-by-Step Explanation + +**Step 1: Character Extraction** + +- Input: Raw text string "Hello" +- Process: Extract each character $c \in \{'H', 'e', 'l', 'l', 'o'\}$ +- Meaning: Break down text into atomic units + +**Step 2: Vocabulary Lookup** + +- Process: Apply $V(c)$ to each character +- Example: $V('H') = 72, V('e') = 101, V('l') = 108, V('o') = 111$ +- Meaning: Convert characters to numerical indices + +**Step 3: Sequence Formation** + +- Output: $\mathbf{t} = [72, 101, 108, 108, 111]$ +- Meaning: Numerical representation ready for embedding + +**Control Impact**: Tokenizer creates the **foundation** for all subsequent processing. Any error here propagates through the entire system. + +--- + +## 4. Seed Control System + +### 4.1 Seed as System Initialization + +The seed $s \in \mathbb{N}$ controls **randomness** throughout the system: + +```math + +\mathcal{R}(\mathbf{x}, s) = \text{deterministic\_random}(\mathbf{x}, s) + +``` + +### 4.2 Seed Propagation Function + +**Initialization:** + +```math + +\text{seed\_torch}(s): \text{torch.manual\_seed}(s) + + +\text{seed\_cuda}(s): \text{torch.cuda.manual\_seed\_all}(s) + + +\text{seed\_cudnn}(s): \text{torch.backends.cudnn.deterministic} = \text{True} + +``` + +**Mathematical Model:** + +```math + +\mathbb{P}(\mathbf{W} | s) = \begin{cases} +\delta(\mathbf{W} - \mathbf{W}_s) & \text{if deterministic} \\ +\text{some distribution} & \text{if stochastic} +\end{cases} + + +where \delta is the Dirac delta and \mathbf{W}_s is the weight initialization given seed s . +``` + +### 4.3 Seed Control Equation + +For weight initialization: + +```math + +\mathbf{W}_0 = \mathcal{I}(\mathbf{s}, \text{init\_method}) + + +where \mathcal{I} is the initialization function. +``` + +**Example - Normal Initialization:** + +```math + +\mathbf{W}_0 \sim \mathcal{N}(0, \sigma^2) \quad \text{with random state } r(s) + + + +W_{ij} = \sigma \cdot \Phi^{-1}(U_{ij}(s)) + + +where: +- \mathcal{N}(0, \sigma^2) = normal distribution +- \Phi^{-1} = inverse CDF +- U_{ij}(s) = uniform random number from seed s +- \sigma = 0.02 (typical value) +``` + +### 4.4 Step-by-Step Explanation + +**Step 1: Seed Input** + +- Input: $s = 42$ +- Meaning: Provides reproducibility guarantee + +**Step 2: RNG State Initialization** + +- Process: Set all random number generators to state based on $s$ +- Meaning: Ensures deterministic behavior + +**Step 3: Weight Initialization** + +- Process: Generate all weights using RNG with seed $s$ +- Example: $W\_{ij} = \text{normal}(0, 0.02, \text{seed}=42)$ +- Meaning: Starting point for optimization + +**Step 4: Training Determinism** + +- Process: Same seed + same data → same gradients → same updates +- Meaning: Complete reproducibility + +**Control Impact**: Seed controls **initial conditions** and **stochastic processes** throughout training. It's the **control parameter** for reproducibility. + +--- + +## 5. Embedding Layer Control + +### 5.1 Embedding as Linear Transformation + +The embedding layer performs a **lookup operation**: + +```math + +\mathcal{E}: \mathbb{N} \rightarrow \mathbb{R}^d + +``` + +**Mathematical Formulation:** + +```math + +\mathbf{E} \in \mathbb{R}^{V \times d} \quad \text{(embedding matrix)} + + + +\mathbf{x}_t = \mathbf{E}[\mathbf{t}_t] = \mathbf{E}_t \in \mathbb{R}^d + + +where \mathbf{t}_t \in \mathbb{N} is the token ID at position t . +``` + +### 5.2 Embedding Control System + +**Batch Processing:** + +```math + +\mathbf{X} = \mathbf{E}[\mathbf{T}] \in \mathbb{R}^{B \times n \times d} + + +where \mathbf{T} \in \mathbb{N}^{B \times n} is the batch of token IDs. +``` + +**Control Function:** + +```math + +\mathbf{X} = \mathcal{E}(\mathbf{T}, \mathbf{E}) + +``` + +**Gradient Flow:** + +```math + +\frac{\partial \mathcal{L}}{\partial \mathbf{E}} = \sum_{b,t} \frac{\partial \mathcal{L}}{\partial \mathbf{X}_{b,t}} \cdot \mathbf{1}[\mathbf{T}_{b,t}] + + +where \mathbf{1}[\mathbf{T}_{b,t}] is a one-hot indicator. +``` + +### 5.3 Step-by-Step Explanation + +**Step 1: Token ID Input** + +- Input: $t = 72$ (token ID for 'H') +- Meaning: Discrete index into vocabulary + +**Step 2: Matrix Lookup** + +- Process: $\mathbf{x} = \mathbf{E}[72]$ +- Example: $\mathbf{x} = [0.1, -0.2, 0.3, ..., 0.05] \in \mathbb{R}^{512}$ +- Meaning: Continuous vector representation + +**Step 3: Semantic Encoding** + +- Property: Similar tokens have similar embeddings (after training) +- Meaning: Embeddings capture semantic relationships + +**Control Impact**: Embedding layer **projects** discrete tokens into continuous space, enabling gradient-based optimization. + +--- + +## 6. Positional Encoding State + +### 6.1 Positional Encoding as Additive Control + +```math + +\mathbf{X}_{pos} = \mathbf{X} + \mathbf{PE} \in \mathbb{R}^{B \times n \times d} + + +where \mathbf{PE} \in \mathbb{R}^{n \times d} is the positional encoding matrix. +``` + +### 6.2 Positional Encoding Function + +```math + +PE_{(pos, i)} = \begin{cases} +\sin\left(\frac{pos}{10000^{2i/d}}\right) & \text{if } i \text{ is even} \\ +\cos\left(\frac{pos}{10000^{2(i-1)/d}}\right) & \text{if } i \text{ is odd} +\end{cases} + +``` + +### 6.3 Control System Interpretation + +**Additive Control:** + +```math + +\mathbf{X}_{out} = \mathbf{X}_{in} + \mathbf{U}_{pos} + + +where \mathbf{U}_{pos} is the **control input** representing position information. +``` + +**Meaning**: Positional encoding **injects** positional information into the embeddings. + +### 6.4 Step-by-Step Explanation + +**Step 1: Position Index** + +- Input: Position $pos = 0, 1, 2, ..., n-1$ +- Meaning: Absolute position in sequence + +**Step 2: Encoding Generation** + +- Process: Compute $PE\_{(pos, i)}$ for each dimension $ i$ +- Example: $PE*{(0, 0)} = 0, PE*{(0, 1)} = 1, PE\_{(1, 0)} \approx 0.84$ +- Meaning: Unique pattern for each position + +**Step 3: Addition Operation** + +- Process: $\mathbf{X}\_{pos} = \mathbf{X} + PE$ +- Meaning: Position information added to embeddings + +**Step 4: Multi-Scale Representation** + +- Property: Different dimensions encode different frequency scales +- Meaning: Model can learn both local and global positional patterns + +**Control Impact**: Positional encoding provides **temporal/spatial awareness** to the model, enabling it to understand sequence order. + +--- + +## 7. Self-Attention Control System + +### 7.1 Attention as Information Routing + +Self-attention can be modeled as a **dynamical control system** that routes information: + +```math + +\mathbf{O} = \text{Attention}(\mathbf{X}, \mathbf{W}_Q, \mathbf{W}_K, \mathbf{W}_V) + +``` + +### 7.2 State-Space Model for Attention + +**Query, Key, Value Generation:** + +```math + +\mathbf{Q} = \mathbf{X} \mathbf{W}_Q \in \mathbb{R}^{B \times n \times d} + + +\mathbf{K} = \mathbf{X} \mathbf{W}_K \in \mathbb{R}^{B \times n \times d} + + +\mathbf{V} = \mathbf{X} \mathbf{W}_V \in \mathbb{R}^{B \times n \times d} + +``` + +**Attention Scores (Transfer Function):** + +```math + +\mathbf{S} = \frac{\mathbf{Q} \mathbf{K}^T}{\sqrt{d_k}} \in \mathbb{R}^{B \times h \times n \times n} + +``` + +**Attention Weights (Control Signal):** + +```math + +\mathbf{A} = \text{softmax}(\mathbf{S}) \in \mathbb{R}^{B \times h \times n \times n} + +``` + +**Output (Controlled Response):** + +```math + +\mathbf{O} = \mathbf{A} \mathbf{V} \in \mathbb{R}^{B \times h \times n \times d_k} + +``` + +### 7.3 Control System Interpretation + +**Attention as Feedback Control:** + +```math + +\mathbf{O}_i = \sum_{j=1}^{n} A_{ij} \mathbf{V}_j + + +where A_{ij} is the **control gain** determining how much information flows from position j to position i . +``` + +**Meaning**: Attention acts as a **learnable routing mechanism** controlled by similarities between queries and keys. + +### 7.4 Multi-Head Attention Control + +**Head Splitting:** + +```math + +\mathbf{Q}_h = \mathbf{Q}[:, :, h \cdot d_k : (h+1) \cdot d_k] \in \mathbb{R}^{B \times n \times d_k} + +``` + +**Parallel Processing:** + +```math + +\mathbf{O}_h = \text{Attention}(\mathbf{Q}_h, \mathbf{K}_h, \mathbf{V}_h), \quad h = 1, ..., H + +``` + +**Concatenation:** + +```math + +\mathbf{O} = \text{Concat}[\mathbf{O}_1, \mathbf{O}_2, ..., \mathbf{O}_H] \in \mathbb{R}^{B \times n \times d} + +``` + +### 7.5 Causal Masking Control + +**Causal Mask:** + +```math + +M_{ij} = \begin{cases} +0 & \text{if } i \geq j \text{ (allowed)} \\ +-\infty & \text{if } i < j \text{ (masked)} +\end{cases} + +``` + +**Masked Attention:** + +```math + +\mathbf{S}_{masked} = \mathbf{S} + M + +``` + +**Effect**: Prevents information flow from future positions. + +### 7.6 Step-by-Step Explanation + +**Step 1: Query, Key, Value Generation** + +- Process: Linear transformations of input +- Meaning: Create three representations: what to look for (Q), what to match (K), what to retrieve (V) + +**Step 2: Similarity Computation** + +- Process: $S\_{ij} = Q_i \cdot K_j / \sqrt{d_k}$ +- Meaning: Measure similarity/relevance between positions $i$ and $ j + $ + +**Step 3: Softmax Normalization** + +- Process: $A*{ij} = \exp(S*{ij}) / \sum*k \exp(S*{ik})$ +- Meaning: Convert similarities to probability distribution (attention weights) + +**Step 4: Weighted Aggregation** + +- Process: $O*i = \sum_j A*{ij} V_j$ +- Meaning: Combine values weighted by attention probabilities + +**Step 5: Information Flow** + +- Property: Each position receives information from all other positions (with causal masking) +- Meaning: Enables long-range dependencies and context understanding + +**Control Impact**: Self-attention is the **core control mechanism** that determines **what information flows where** in the sequence. + +--- + +## 8. Feed-Forward Control + +### 8.1 Feed-Forward as Nonlinear Transformation + +```math + +\text{FFN}(\mathbf{X}) = \text{GELU}(\mathbf{X} \mathbf{W}_1 + \mathbf{b}_1) \mathbf{W}_2 + \mathbf{b}_2 + +``` + +### 8.2 Control System Model + +**Two-Stage Transformation:** + +```math + +\mathbf{H} = \mathbf{X} \mathbf{W}_1 \in \mathbb{R}^{B \times n \times d_{ff}} + + + +\mathbf{H}' = \text{GELU}(\mathbf{H}) \in \mathbb{R}^{B \times n \times d_{ff}} + + + +\mathbf{O} = \mathbf{H}' \mathbf{W}_2 \in \mathbb{R}^{B \times n \times d} + +``` + +### 8.3 GELU Activation Control + +```math + +\text{GELU}(x) = x \cdot \Phi(x) = x \cdot \frac{1}{2}\left(1 + \text{erf}\left(\frac{x}{\sqrt{2}}\right)\right) + +``` + +**Control Interpretation**: GELU applies **smooth gating** - values near zero are suppressed, positive values pass through. + +### 8.4 Step-by-Step Explanation + +**Step 1: Expansion** + +- Process: $\mathbf{H} = \mathbf{X} \mathbf{W}_1 expands to d_{ff} > d$ +- Example: $d = 512 \rightarrow d\_{ff} = 2048$ +- Meaning: Increases capacity for complex transformations + +**Step 2: Nonlinear Activation** + +- Process: $\mathbf{H}' = \text{GELU}(\mathbf{H})$ +- Meaning: Introduces nonlinearity, enabling complex function approximation + +**Step 3: Compression** + +- Process: $\mathbf{O} = \mathbf{H}' \mathbf{W}\_2 $compresses back to$ d$ +- Meaning: Projects back to original dimension + +**Control Impact**: FFN provides **nonlinear processing power** and **feature transformation** at each position. + +--- + +## 9. Layer Normalization Feedback + +### 9.1 Normalization as Feedback Control + +```math + +\text{LayerNorm}(\mathbf{x}) = \gamma \odot \frac{\mathbf{x} - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta + + +where: +- \mu = \frac{1}{d} \sum_{i=1}^{d} x_i (mean) +- \sigma^2 = \frac{1}{d} \sum_{i=1}^{d} (x_i - \mu)^2 (variance) +- \gamma, \beta = learnable parameters (scale and shift) +``` + +### 9.2 Control System Interpretation + +**Normalization as State Regulation:** + +```math + +\mathbf{x}_{norm} = \gamma \odot \frac{\mathbf{x} - \mu(\mathbf{x})}{\sigma(\mathbf{x})} + \beta + +``` + +**Meaning**: Normalization **regulates** the distribution of activations, preventing saturation and improving gradient flow. + +### 9.3 Pre-Norm Architecture + +**Transformer Block with Pre-Norm:** + +```math + +\mathbf{x}_{norm} = \text{LayerNorm}(\mathbf{x}_{in}) + + +\mathbf{x}_{attn} = \text{Attention}(\mathbf{x}_{norm}) + + +\mathbf{x}_{out} = \mathbf{x}_{in} + \mathbf{x}_{attn} \quad \text{(residual connection)} + +``` + +**Control Impact**: Pre-norm architecture provides **stability** and **better gradient flow**. + +### 9.4 Step-by-Step Explanation + +**Step 1: Mean Computation** + +- Process: $\mu = \frac{1}{d} \sum x_i$ +- Meaning: Find center of distribution + +**Step 2: Variance Computation** + +- Process: $\sigma^2 = \frac{1}{d} \sum (x_i - \mu)^2$ +- Meaning: Measure spread of distribution + +**Step 3: Normalization** + +- Process: $\hat{x}\_i = (x_i - \mu) / \sqrt{\sigma^2 + \epsilon}$ +- Meaning: Standardize to zero mean, unit variance + +**Step 4: Scale and Shift** + +- Process: $x\_{out} = \gamma \odot \hat{x} + \beta$ +- Meaning: Allow model to learn optimal scale and shift + +**Control Impact**: Layer normalization provides **stability** and **faster convergence** by maintaining consistent activation distributions. + +--- + +## 10. Complete System Dynamics + +### 10.1 Complete Forward Pass + +**System State Evolution:** + +```math + +\mathbf{h}_0 = \mathcal{E}(\mathbf{T}) + \mathbf{PE} \quad \text{(embedding + positional)} + + + +\mathbf{h}_l = \text{TransformerBlock}_l(\mathbf{h}_{l-1}), \quad l = 1, ..., L + + + +\mathbf{y} = \mathbf{h}_L \mathbf{W}_{out} \in \mathbb{R}^{B \times n \times V} + +``` + +### 10.2 Recursive System Equation + +```math + +\mathbf{h}_t^{(l)} = f_l(\mathbf{h}_t^{(l-1)}, \theta_l) + + +where: + + +f_l(\mathbf{x}, \theta_l) = \mathbf{x} + \text{Dropout}(\text{Attention}(\text{LayerNorm}(\mathbf{x}))) + \text{Dropout}(\text{FFN}(\text{LayerNorm}(\mathbf{x} + \text{Attention}(\text{LayerNorm}(\mathbf{x}))))) + +``` + +### 10.3 System Transfer Function + +The complete system can be viewed as: + +```math + +\mathbf{Y} = \mathcal{F}(\mathbf{T}, \theta, \mathbf{s}) + + +where: +- \mathbf{T} = input tokens +- \theta = all parameters +- \mathbf{s} = seed +``` + +**Properties:** + +- **Nonlinear**: Due to softmax, GELU, normalization +- **Differentiable**: All operations have gradients +- **Compositional**: Built from simpler functions + +### 10.4 Step-by-Step System Flow + +**Step 1: Input Encoding** + +- Input: Token sequence $\mathbf{T}$ +- Process: Embedding + Positional Encoding +- Output: $\mathbf{h}\_0 \in \mathbb{R}^{B \times n \times d}$ +- Meaning: Convert discrete tokens to continuous vectors with position info + +**Step 2: Layer Processing** + +- For each layer $l = 1, ..., L $: + - Process: Self-attention + FFN with residual connections + - Output: $\mathbf{h}\_l \in \mathbb{R}^{B \times n \times d}$ + - Meaning: Transform representations through attention and processing + +**Step 3: Output Generation** + +- Process: Final layer norm + output projection +- Output: $\mathbf{L} \in \mathbb{R}^{B \times n \times V} (logits)$ +- Meaning: Predict probability distribution over vocabulary + +**Step 4: Probability Computation** + +- Process: Softmax over logits +- Output: $\mathbf{p} \in \mathbb{R}^{B \times n \times V} + (probabilities)$ +- Meaning: Normalized probability distribution for next token prediction + +--- + +## 11. Training as Optimization Control + +### 11.1 Training as Optimal Control Problem + +**Objective Function:** + +```math + +J(\theta) = \frac{1}{N} \sum_{i=1}^{N} \mathcal{L}(\mathbf{y}_i, \hat{\mathbf{y}}_i(\theta)) + + +where: +- \mathcal{L} = loss function (cross-entropy) +- \mathbf{y}_i = true labels +- \hat{\mathbf{y}}_i(\theta) = model predictions +``` + +**Optimization Problem:** + +```math + +\theta^* = \arg\min_{\theta} J(\theta) + +``` + +### 11.2 Gradient-Based Control + +**Gradient Computation:** + +```math + +\mathbf{g}_t = \nabla_\theta J(\theta_t) = \frac{\partial J}{\partial \theta_t} + +``` + +**Parameter Update (AdamW):** + +```math + +\theta_{t+1} = \theta_t - \eta_t \left(\frac{\hat{\mathbf{m}}_t}{\sqrt{\hat{\mathbf{v}}_t} + \epsilon} + \lambda \theta_t\right) + + +where: +- \hat{\mathbf{m}}_t = biased-corrected momentum +- \hat{\mathbf{v}}_t = biased-corrected variance +- \eta_t = learning rate (controlled by scheduler) +- \lambda = weight decay coefficient +``` + +### 11.3 Learning Rate Control + +**Cosine Annealing Schedule:** + +```math + +\eta_t = \eta_{min} + (\eta_{max} - \eta_{min}) \cdot \frac{1 + \cos(\pi \cdot \frac{t}{T_{max}})}{2} + +``` + +**Control Interpretation**: Learning rate acts as **gain scheduling** - high gain initially for fast convergence, low gain later for fine-tuning. + +### 11.4 Gradient Clipping Control + +**Clipping Function:** + +```math + +\mathbf{g}_{clipped} = \begin{cases} +\mathbf{g} & \text{if } ||\mathbf{g}|| \leq \theta \\ +\mathbf{g} \cdot \frac{\theta}{||\mathbf{g}||} & \text{if } ||\mathbf{g}|| > \theta +\end{cases} + +``` + +**Purpose**: Prevents **explosive gradients** that could destabilize training. + +### 11.5 Step-by-Step Training Control + +**Step 1: Forward Pass** + +- Process: $\hat{\mathbf{y}} = \mathcal{F}(\mathbf{x}, \theta_t)$ +- Meaning: Compute predictions with current parameters + +**Step 2: Loss Computation** + +- Process: $\mathcal{L} = \text{CrossEntropy}(\hat{\mathbf{y}}, \mathbf{y})$ +- Meaning: Measure prediction error + +**Step 3: Backward Pass** + +- Process: $\mathbf{g} = \nabla\_\theta \mathcal{L}$ +- Meaning: Compute gradients for all parameters + +**Step 4: Gradient Clipping** + +- Process: $\mathbf{g}\_{clipped} = \text{Clip}(\mathbf{g}, \theta)$ +- Meaning: Prevent gradient explosion + +**Step 5: Optimizer Update** + +- Process: $\theta*{t+1} = \text{AdamW}(\theta_t, \mathbf{g}*{clipped}, \eta_t)$ +- Meaning: Update parameters using adaptive learning rate + +**Step 6: Learning Rate Update** + +- Process: $\eta\_{t+1} = \text{Scheduler}(\eta_t, t)$ +- Meaning: Adjust learning rate according to schedule + +**Control Impact**: Training process is a **closed-loop control system** where: + +- **Error signal**: Loss +- **Controller**: Optimizer (AdamW) +- **Actuator**: Parameter updates +- **Plant**: Model forward pass + +--- + +## 12. Inference Control Loop + +### 12.1 Autoregressive Generation as Control Loop + +**State-Space Model:** + +```math + +\mathbf{h}_t = \mathcal{F}(\mathbf{x}_t, \mathbf{h}_{t-1}, \theta) + + + +\mathbf{p}_t = \text{softmax}(\mathbf{h}_t \mathbf{W}_{out}) + + + +\mathbf{x}_{t+1} \sim \text{Categorical}(\mathbf{p}_t) + +``` + +### 12.2 Generation Control Function + +**Step-by-Step:** + +1. **Current State**: $\mathbf{h}\_t$ +2. **Output Generation**: $\mathbf{p}_t = \text{softmax}(\mathbf{h}\_t \mathbf{W}_{out})$ +3. **Sampling**: $x\_{t+1} \sim \mathbf{p}\_t (with temperature, top-k, top-p)$ +4. **State Update**: $\mathbf{h}_{t+1} = \mathcal{F}([\mathbf{h}\_t, x_{t+1}], \theta)$ +5. **Repeat**: Until max length or stop token + +### 12.3 Sampling Control Parameters + +**Temperature Control:** + +```math + +\mathbf{p}_t^{temp} = \text{softmax}\left(\frac{\mathbf{h}_t \mathbf{W}_{out}}{T}\right) + + +- T < 1 : More deterministic (sharp distribution) +- T > 1 : More random (flat distribution) +- T = 1 : Default +``` + +**Top-k Filtering:** + +```math + +\mathbf{p}_t^{topk}[v] = \begin{cases} +\mathbf{p}_t[v] & \text{if } v \in \text{top-k}(\mathbf{p}_t) \\ +0 & \text{otherwise} +\end{cases} + +``` + +**Top-p (Nucleus) Sampling:** + +```math + +\mathbf{p}_t^{topp}[v] = \begin{cases} +\mathbf{p}_t[v] & \text{if } v \in S_p \\ +0 & \text{otherwise} +\end{cases} + + +where S_p is the smallest set such that \sum_{v \in S_p} \mathbf{p}_t[v] \geq p . +``` + +### 12.4 Step-by-Step Inference Control + +**Step 1: Initialization** + +- Input: Prompt tokens $\mathbf{P} = [p_1, ..., p_k]$ +- Process: Initialize state $\mathbf{h}\_0 = \mathcal{E}(\mathbf{P}) + \mathbf{PE}$ +- Meaning: Set initial state from prompt + +**Step 2: Forward Pass** + +- Process: $\mathbf{h}_t = \text{Transformer}(\mathbf{h}_{t-1})$ +- Output: Hidden state $\mathbf{h}\_t$ +- Meaning: Process current sequence + +**Step 3: Logit Generation** + +- Process: $\mathbf{l}_t = \mathbf{h}\_t \mathbf{W}_{out}$ +- Output: Logits $\mathbf{l}\_t \in \mathbb{R}^V$ +- Meaning: Unnormalized scores for each token + +**Step 4: Probability Computation** + +- Process: $\mathbf{p}\_t = \text{softmax}(\mathbf{l}\_t / T)$ +- Output: Probability distribution $\mathbf{p}\_t$ +- Meaning: Normalized probabilities with temperature + +**Step 5: Sampling** + +- Process: $x\_{t+1} \sim \mathbf{p}\_t (with optional top-k/top-p)$ +- Output: Next token $x\_{t+1}$ +- Meaning: Stochastically select next token + +**Step 6: State Update** + +- Process: Append $x*{t+1}$ to sequence, update $\mathbf{h}*{t+1}$ +- Meaning: Incorporate new token into state + +**Step 7: Termination Check** + +- Condition: $t < \text{max_length} and x\_{t+1} \neq \text{}$ +- If true: Go to Step 2 +- If false: Return generated sequence + +**Control Impact**: Inference is a **recurrent control system** where: + +- **State**: Current hidden representation +- **Control**: Sampling strategy (temperature, top-k, top-p) +- **Output**: Generated token sequence + +--- + +## Summary: Unified Control System Model + +### Complete System Equation + +```math + +\mathbf{Y} = \mathcal{G}(\mathbf{C}, \theta, \mathbf{s}, \mathbf{T}, \{k, p\}) + + +where: +- \mathbf{C} = input characters +- \theta = model parameters +- \mathbf{s} = seed +- \mathbf{T} = temperature +- \{k, p\} = top-k and top-p parameters +``` + +### System Components as Control Elements + +1. **Tokenizer**: Input encoder $\mathcal{T}$ +2. **Seed**: Initialization control $\mathbf{s}$ +3. **Embeddings**: State projection $\mathcal{E}$ +4. **Positional Encoding**: Temporal control $\mathbf{PE}$ +5. **Attention**: Information routing $\mathcal{A}$ +6. **FFN**: Nonlinear transformation $\mathcal{F}$ +7. **Normalization**: State regulation $\mathcal{N}$ +8. **Optimizer**: Parameter control $\mathcal{O}$ +9. **Scheduler**: Learning rate control $\mathcal{S}$ +10. **Sampling**: Output control $\mathcal{P}$ + +### Control Flow Summary + +``` +Input Characters + ↓ [Tokenizer Control] +Token IDs + ↓ [Seed Control] +Initialized Parameters + ↓ [Embedding Control] +Vector Representations + ↓ [Positional Control] +Position-Aware Vectors + ↓ [Attention Control] +Context-Aware Representations + ↓ [FFN Control] +Transformed Features + ↓ [Normalization Control] +Stabilized Activations + ↓ [Output Control] +Probability Distributions + ↓ [Sampling Control] +Generated Tokens +``` + +Each component acts as a **control element** in a unified dynamical system, working together to transform input text into meaningful language model outputs. + +--- + +## 13. Block Diagram Analysis + +### 13.1 Single Transformer Block Control System + +**Block Diagram (a): Detailed Single Transformer Block** + +``` +Input X + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ LayerNorm │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ Multi-Head │ + │ Attention │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ Dropout │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ + │ ←─── (Residual Connection from X) + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ LayerNorm │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ Feed-Forward│ + │ Network │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ Dropout │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ + │ ←─── (Residual Connection) + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + Output X' +``` + +**Mathematical Transfer Function:** + +```math + +\mathbf{X}_{out} = \mathbf{X}_{in} + \text{Dropout}(\text{FFN}(\text{LayerNorm}(\mathbf{X}_{in} + \text{Dropout}(\text{Attention}(\text{LayerNorm}(\mathbf{X}_{in}))))) + +``` + +### 13.2 Simplified Transformer Block + +**Block Diagram (b): Simplified Single Block** + +``` +Input X + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ TransformerBlock │ + │ G_block(X) = X + Attn(LN(X)) + │ + │ FFN(LN(X + Attn(LN(X))))│ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + Output X' +``` + +**Transfer Function:** + +```math + +G_{block}(\mathbf{X}) = \mathbf{X} + G_{attn}(\text{LN}(\mathbf{X})) + G_{ffn}(\text{LN}(\mathbf{X} + G_{attn}(\text{LN}(\mathbf{X})))) + + +where: +- G_{attn} = Attention transfer function +- G_{ffn} = Feed-forward transfer function +- \text{LN} = Layer normalization +``` + +### 13.3 Complete Model with Multiple Layers + +**Block Diagram (c): Cascaded Transformer Blocks** + +``` +Input Tokens T + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ Embedding │ + │ G_emb │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ Positional │ + │ G_pos │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ Block 1 │ + │ G_block₁ │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ Block 2 │ + │ G_blockā‚‚ │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ ... │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ Block L │ + │ G_block_L │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ Final Norm │ + │ G_norm │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ Output Proj │ + │ G_out │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + Output Logits +``` + +**Overall Transfer Function:** + +```math + +\mathbf{Y} = G_{out} \circ G_{norm} \circ G_{block_L} \circ ... \circ G_{block_2} \circ G_{block_1} \circ G_{pos} \circ G_{emb}(\mathbf{T}) + +``` + +### 13.4 Closed-Loop Training System + +**Block Diagram (d): Training Control Loop** + +``` +Input Data X + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ Model │ + │ Forward │ + │ F │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ Output │ + │ Å· │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ Loss │ + │ L(Å·, y) │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ Gradient │ + │ āˆ‡Īø │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ Clipping │ + │ Clip │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ Optimizer │ + │ AdamW │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ Parameter │ + │ Update │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ - │ ←─── (Feedback to Model) + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ +``` + +**Closed-Loop Transfer Function:** + +```math + +\theta_{t+1} = \theta_t - \eta_t \cdot \text{AdamW}(\text{Clip}(\nabla_\theta L(\mathcal{F}(\mathbf{X}, \theta_t), \mathbf{y}))) + +``` + +--- + +## 14. Vector Visualization and Examples + +### 14.1 Example Phrase: "Hello World" + +We'll trace through the complete system with the phrase **"Hello World"**. + +#### Step 1: Tokenization + +**Input:** `"Hello World"` + +**Process:** + +``` +Characters: ['H', 'e', 'l', 'l', 'o', ' ', 'W', 'o', 'r', 'l', 'd'] +Token IDs: [72, 101, 108, 108, 111, 32, 87, 111, 114, 108, 100] +``` + +**Mathematical:** + +```math + +\mathbf{c} = \text{"Hello World"} + + +\mathbf{t} = \mathcal{T}(\mathbf{c}) = [72, 101, 108, 108, 111, 32, 87, 111, 114, 108, 100] + +``` + +**Vector Representation:** + +- Dimension: $n = 11$ tokens +- Token IDs: $\mathbf{t} \in \mathbb{N}^{11}$ + +#### Step 2: Embedding + +**Embedding Matrix:** $\mathbf{E} \in \mathbb{R}^{128 \times 512}$ + +**Lookup Operation:** + +```math + +\mathbf{X} = \mathbf{E}[\mathbf{t}] = \begin{bmatrix} +\mathbf{E}[72] \\ +\mathbf{E}[101] \\ +\mathbf{E}[108] \\ +\mathbf{E}[108] \\ +\mathbf{E}[111] \\ +\mathbf{E}[32] \\ +\mathbf{E}[87] \\ +\mathbf{E}[111] \\ +\mathbf{E}[114] \\ +\mathbf{E}[108] \\ +\mathbf{E}[100] +\end{bmatrix} \in \mathbb{R}^{11 \times 512} + +``` + +**Example Values (first 3 dimensions):** + +```math + +\mathbf{E}[72] = [0.1, -0.2, 0.3, ...]^T \\ +\mathbf{E}[101] = [-0.1, 0.3, -0.1, ...]^T \\ +\mathbf{E}[108] = [0.05, 0.15, -0.05, ...]^T + +``` + +**Vector Visualization:** + +``` +Token 'H' (ID=72): [0.10, -0.20, 0.30, ..., 0.05] (512-dim vector) +Token 'e' (ID=101): [-0.10, 0.30, -0.10, ..., 0.02] (512-dim vector) +Token 'l' (ID=108): [0.05, 0.15, -0.05, ..., 0.01] (512-dim vector) +... +``` + +#### Step 3: Positional Encoding + +**Positional Encoding Matrix:** $\mathbf{PE} \in \mathbb{R}^{11 \times 512}$ + +**Computation:** + +```math + +PE_{(0, 0)} = \sin(0 / 10000^0) = 0 \\ +PE_{(0, 1)} = \cos(0 / 10000^0) = 1 \\ +PE_{(1, 0)} = \sin(1 / 10000^0) = \sin(1) \approx 0.8415 \\ +PE_{(1, 1)} = \cos(1 / 10000^0) = \cos(1) \approx 0.5403 + +``` + +**Addition:** + +```math + +\mathbf{X}_{pos} = \mathbf{X} + \mathbf{PE} + +``` + +**Example (first token, first 3 dimensions):** + +```math + +\mathbf{X}_{pos}[0, :3] = \begin{bmatrix} +0.1 \\ -0.2 \\ 0.3 +\end{bmatrix} + \begin{bmatrix} +0 \\ 1 \\ 0 +\end{bmatrix} = \begin{bmatrix} +0.1 \\ 0.8 \\ 0.3 +\end{bmatrix} + +``` + +#### Step 4: Multi-Head Attention + +**Query, Key, Value Projections:** + +Let $\mathbf{W}\_Q, \mathbf{W}\_K, \mathbf{W}\_V \in \mathbb{R}^{512 \times 512}$ + +```math + +\mathbf{Q} = \mathbf{X}_{pos} \mathbf{W}_Q \in \mathbb{R}^{11 \times 512} + +``` + +**Example Calculation (head 0, token 0):** + +For $h = 0 , d_k = 512/8 = 64 $: + +```math + +\mathbf{Q}[0, :64] = \mathbf{X}_{pos}[0] \mathbf{W}_Q[:, :64] + +``` + +**Attention Score Computation:** + +```math + +S_{0,1} = \frac{\mathbf{Q}[0] \cdot \mathbf{K}[1]}{\sqrt{64}} = \frac{\sum_{i=0}^{63} Q_{0,i} \cdot K_{1,i}}{8} + +``` + +**Example Numerical Calculation:** + +Assume: + +```math + +\mathbf{Q}[0, :3] = [0.2, -0.1, 0.3] \\ +\mathbf{K}[1, :3] = [0.1, 0.2, -0.1] + + + +S_{0,1} = \frac{0.2 \times 0.1 + (-0.1) \times 0.2 + 0.3 \times (-0.1)}{8} \\ += \frac{0.02 - 0.02 - 0.03}{8} = \frac{-0.03}{8} = -0.00375 + +``` + +**Attention Weights:** + +```math + +A_{0,:} = \text{softmax}(S_{0,:}) = \frac{\exp(S_{0,:})}{\sum_{j=0}^{10} \exp(S_{0,j})} + +``` + +**Example:** + +If $S\_{0,:} = [-0.004, 0.05, 0.02, 0.02, 0.08, -0.01, 0.03, 0.08, 0.01, 0.02, 0.04]$ + +```math + +\exp(S_{0,:}) = [0.996, 1.051, 1.020, 1.020, 1.083, 0.990, 1.030, 1.083, 1.010, 1.020, 1.041] + + + +\sum = 11.335 + + + +A_{0,:} = [0.088, 0.093, 0.090, 0.090, 0.096, 0.087, 0.091, 0.096, 0.089, 0.090, 0.092] + +``` + +**Output Calculation:** + +```math + +\mathbf{O}[0] = \sum_{j=0}^{10} A_{0,j} \mathbf{V}[j] + +``` + +**Example (first dimension):** + +```math + +O_{0,0} = A_{0,0} V_{0,0} + A_{0,1} V_{1,0} + ... + A_{0,10} V_{10,0} \\ += 0.088 \times 0.2 + 0.093 \times 0.1 + ... + 0.092 \times 0.15 \\ +\approx 0.12 + +``` + +#### Step 5: Feed-Forward Network + +**Input:** $\mathbf{X}\_{attn} \in \mathbb{R}^{11 \times 512}$ + +**First Linear Transformation:** + +```math + +\mathbf{H} = \mathbf{X}_{attn} \mathbf{W}_1 \in \mathbb{R}^{11 \times 2048} + +``` + +**Example (token 0, first dimension):** + +```math + +H_{0,0} = \sum_{i=0}^{511} X_{attn,0,i} \cdot W_{1,i,0} + + +Assuming X_{attn}[0, :3] = [0.12, -0.05, 0.08] and W_1[:3, :3] = \begin{bmatrix} 0.1 & 0.2 \\ -0.1 & 0.1 \\ 0.05 & -0.05 \end{bmatrix} + + +H_{0,0} = 0.12 \times 0.1 + (-0.05) \times (-0.1) + 0.08 \times 0.05 \\ += 0.012 + 0.005 + 0.004 = 0.021 + +``` + +**GELU Activation:** + +```math + +\text{GELU}(0.021) = 0.021 \cdot \frac{1}{2}\left(1 + \text{erf}\left(\frac{0.021}{\sqrt{2}}\right)\right) + + + +\text{erf}(0.021/\sqrt{2}) = \text{erf}(0.0148) \approx 0.0167 + + + +\text{GELU}(0.021) = 0.021 \times 0.5 \times (1 + 0.0167) = 0.021 \times 0.5084 \approx 0.0107 + +``` + +**Second Linear Transformation:** + +```math + +\mathbf{O}_{ffn} = \mathbf{H}' \mathbf{W}_2 \in \mathbb{R}^{11 \times 512} + +``` + +#### Step 6: Complete Forward Pass Through One Layer + +**Input:** $\mathbf{X}_{in} = \mathbf{X}_{pos} \in \mathbb{R}^{11 \times 512}$ + +**Step 6.1: Layer Normalization** + +```math + +\mu_0 = \frac{1}{512} \sum_{i=0}^{511} X_{in,0,i} + +``` + +**Example:** + +```math + +\mu_0 = \frac{0.1 + 0.8 + 0.3 + ...}{512} \approx 0.02 + + + +\sigma_0^2 = \frac{1}{512} \sum_{i=0}^{511} (X_{in,0,i} - \mu_0)^2 + + + +\sigma_0^2 \approx \frac{(0.1-0.02)^2 + (0.8-0.02)^2 + ...}{512} \approx 0.15 + + + +\hat{X}_{0,0} = \frac{0.1 - 0.02}{\sqrt{0.15 + 1e-5}} = \frac{0.08}{0.387} \approx 0.207 + +``` + +**Step 6.2: Attention Output** + +```math + +\mathbf{X}_{attn} = \text{Attention}(\hat{\mathbf{X}}) + +``` + +**Step 6.3: Residual Connection** + +```math + +\mathbf{X}_{res1} = \mathbf{X}_{in} + \mathbf{X}_{attn} + +``` + +**Example:** + +```math + +X_{res1,0,0} = 0.1 + 0.12 = 0.22 + +``` + +**Step 6.4: Second Layer Norm + FFN** + +```math + +\mathbf{X}_{ffn} = \text{FFN}(\text{LayerNorm}(\mathbf{X}_{res1})) + +``` + +**Step 6.5: Final Residual** + +```math + +\mathbf{X}_{out} = \mathbf{X}_{res1} + \mathbf{X}_{ffn} + +``` + +**Example:** + +```math + +X_{out,0,0} = 0.22 + 0.15 = 0.37 + +``` + +#### Step 7: Output Projection + +**After L layers:** + +```math + +\mathbf{H}_{final} = \text{LayerNorm}(\mathbf{X}_{out}^{(L)}) \in \mathbb{R}^{11 \times 512} + +``` + +**Output Projection:** + +```math + +\mathbf{L} = \mathbf{H}_{final} \mathbf{W}_{out} \in \mathbb{R}^{11 \times 128} + +``` + +**Example (position 0):** + +```math + +L_{0,:} = \mathbf{H}_{final}[0] \mathbf{W}_{out} \in \mathbb{R}^{128} + +``` + +**Softmax:** + +```math + +p_{0,v} = \frac{\exp(L_{0,v})}{\sum_{w=0}^{127} \exp(L_{0,w})} + +``` + +**Example:** + +If $L*{0,72} = 5.2 (logit for 'H'), L*{0,101} = 3.1 (logit for 'e'), etc.$ + +```math + +\exp(5.2) = 181.27 \\ +\exp(3.1) = 22.20 \\ +\vdots + + + +\sum_{w=0}^{127} \exp(L_{0,w}) \approx 250.0 + + + +p_{0,72} = \frac{181.27}{250.0} \approx 0.725 \quad \text{(72\% probability for H)} + +``` + +--- + +## 15. Complete Numerical Example: "Hello" + +Let's trace through the complete system with **"Hello"** step-by-step. + +### Input: "Hello" + +### Stage 1: Tokenization + +```math + +\mathbf{c} = \text{"Hello"} = ['H', 'e', 'l', 'l', 'o'] + + + +\mathbf{t} = [72, 101, 108, 108, 111] + +``` + +### Stage 2: Embedding (d=512) + +```math + +\mathbf{E} \in \mathbb{R}^{128 \times 512} + + + +\mathbf{X} = \begin{bmatrix} +\mathbf{E}[72] \\ +\mathbf{E}[101] \\ +\mathbf{E}[108] \\ +\mathbf{E}[108] \\ +\mathbf{E}[111] +\end{bmatrix} = \begin{bmatrix} +0.10 & -0.20 & 0.30 & ... & 0.05 \\ +-0.10 & 0.30 & -0.10 & ... & 0.02 \\ +0.05 & 0.15 & -0.05 & ... & 0.01 \\ +0.05 & 0.15 & -0.05 & ... & 0.01 \\ +-0.05 & 0.20 & 0.10 & ... & 0.03 +\end{bmatrix} \in \mathbb{R}^{5 \times 512} + +``` + +### Stage 3: Positional Encoding + +```math + +\mathbf{PE} = \begin{bmatrix} +0 & 1 & 0 & ... & 0 \\ +0.84 & 0.54 & 0.01 & ... & 0.00 \\ +0.91 & -0.42 & 0.02 & ... & 0.00 \\ +0.14 & -0.99 & 0.03 & ... & 0.00 \\ +-0.76 & -0.65 & 0.04 & ... & 0.00 +\end{bmatrix} \in \mathbb{R}^{5 \times 512} + + + +\mathbf{X}_{pos} = \mathbf{X} + \mathbf{PE} = \begin{bmatrix} +0.10 & 0.80 & 0.30 & ... & 0.05 \\ +0.74 & 0.84 & -0.09 & ... & 0.02 \\ +0.96 & -0.27 & -0.03 & ... & 0.01 \\ +0.19 & -0.84 & -0.02 & ... & 0.01 \\ +-0.81 & -0.45 & 0.14 & ... & 0.03 +\end{bmatrix} + +``` + +### Stage 4: Attention (h=8 heads, d_k=64) + +**Query Generation:** + +```math + +\mathbf{Q} = \mathbf{X}_{pos} \mathbf{W}_Q \in \mathbb{R}^{5 \times 512} + +``` + +**Score Matrix (head 0):** + +```math + +\mathbf{S}_0 = \frac{\mathbf{Q}_0 \mathbf{K}_0^T}{\sqrt{64}} \in \mathbb{R}^{5 \times 5} + +``` + +**Example Values:** + +```math + +\mathbf{S}_0 = \begin{bmatrix} +0.50 & -0.10 & 0.20 & 0.15 & 0.30 \\ +-0.05 & 0.45 & 0.10 & 0.08 & 0.25 \\ +0.15 & 0.05 & 0.40 & 0.30 & 0.20 \\ +0.12 & 0.08 & 0.28 & 0.35 & 0.18 \\ +0.25 & 0.15 & 0.22 & 0.20 & 0.42 +\end{bmatrix} + +``` + +**Attention Weights:** + +```math + +\mathbf{A}_0 = \text{softmax}(\mathbf{S}_0) = \begin{bmatrix} +0.35 & 0.15 & 0.22 & 0.20 & 0.28 \\ +0.15 & 0.38 & 0.20 & 0.18 & 0.27 \\ +0.23 & 0.18 & 0.32 & 0.30 & 0.26 \\ +0.21 & 0.19 & 0.28 & 0.33 & 0.25 \\ +0.27 & 0.22 & 0.26 & 0.25 & 0.36 +\end{bmatrix} + +``` + +**Output (head 0):** + +```math + +\mathbf{O}_0 = \mathbf{A}_0 \mathbf{V}_0 \in \mathbb{R}^{5 \times 64} + +``` + +**Concatenate All Heads:** + +```math + +\mathbf{O} = \text{Concat}[\mathbf{O}_0, ..., \mathbf{O}_7] \in \mathbb{R}^{5 \times 512} + +``` + +### Stage 5: Feed-Forward + +```math + +\mathbf{H} = \mathbf{O} \mathbf{W}_1 \in \mathbb{R}^{5 \times 2048} + + + +\mathbf{H}' = \text{GELU}(\mathbf{H}) \in \mathbb{R}^{5 \times 2048} + + + +\mathbf{O}_{ffn} = \mathbf{H}' \mathbf{W}_2 \in \mathbb{R}^{5 \times 512} + +``` + +### Stage 6: Output Logits + +After processing through all L layers: + +```math + +\mathbf{L} = \mathbf{H}_{final} \mathbf{W}_{out} \in \mathbb{R}^{5 \times 128} + +``` + +**Example (position 4, predicting next token):** + +```math + +L_{4,:} = [2.1, 1.5, ..., 5.2, ..., 3.1, ...] + + +Where: +- L_{4,111} = 5.2 (high score for 'o') +- L_{4,32} = 4.8 (high score for space) +- L_{4,87} = 4.5 (high score for 'W') +``` + +**Probability Distribution:** + +```math + +\mathbf{p}_4 = \text{softmax}(L_{4,:}) = [0.01, 0.008, ..., 0.25, ..., 0.18, ...] + + + +p_{4,111} \approx 0.25 \quad \text{(25\% for o)} \\ +p_{4,32} \approx 0.22 \quad \text{(22\% for space)} \\ +p_{4,87} \approx 0.18 \quad \text{(18\% for W)} + +``` + +--- + +## 16. Vector Space Visualization + +### 16.1 Embedding Space + +**2D Projection Example:** + +After embedding "Hello", tokens occupy positions in 512-dimensional space. Projected to 2D: + +``` +Token Positions (idealized 2D projection): + + 'l' (0.05, 0.15) + ā— + + 'e' (-0.10, 0.30) + ā— + +Origin (0, 0) + ā— + + 'H' (0.10, -0.20) + ā— + + 'o' (-0.05, 0.20) + ā— +``` + +**Distance in Embedding Space:** + +```math + +d(\mathbf{E}[72], \mathbf{E}[101]) = ||\mathbf{E}[72] - \mathbf{E}[101]||_2 + + + +d = \sqrt{(0.1 - (-0.1))^2 + (-0.2 - 0.3)^2 + ...} \approx \sqrt{0.04 + 0.25 + ...} \approx 2.1 + +``` + +### 16.2 Attention Weight Visualization + +**Attention Matrix Visualization:** + +``` +Position 0 1 2 3 4 + ā”Œā”€ā”€ā”€ā”€ā”€ā”“ā”€ā”€ā”€ā”€ā”€ā”“ā”€ā”€ā”€ā”€ā”€ā”“ā”€ā”€ā”€ā”€ā”€ā”“ā”€ā”€ā” +Token 0 │ 0.35 0.15 0.22 0.20 0.28 │ 'H' + │ │ +Token 1 │ 0.15 0.38 0.20 0.18 0.27 │ 'e' + │ │ +Token 2 │ 0.23 0.18 0.32 0.30 0.26 │ 'l' + │ │ +Token 3 │ 0.21 0.19 0.28 0.33 0.25 │ 'l' + │ │ +Token 4 │ 0.27 0.22 0.26 0.25 0.36 │ 'o' + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ +``` + +**Interpretation:** + +- Token 0 ('H') attends most to itself (0.35) and token 4 (0.28) +- Token 4 ('o') attends moderately to all positions +- Higher values indicate stronger attention + +### 16.3 Probability Distribution Visualization + +**Output Distribution for Position 5 (next token after "Hello"):** + +``` +Probability Distribution p[5, :] + +Probability + │ +0.3 │ ā— + │ +0.2 │ ā— ā— + │ +0.1 │ ā— ā— ā— ā— + │ +0.0 ā”œā”€ā”“ā”€ā”€ā”€ā”“ā”€ā”€ā”€ā”“ā”€ā”€ā”€ā”“ā”€ā”€ā”€ā”“ā”€ā”€ā”€ā”“ā”€ā”€ā”€ā”“ā”€ā”€ā”€ā”“ā”€ā”€ā”€ Token IDs + 32 72 87 101 108 111 ... 127 + ␣ H W e l o +``` + +**Meaning:** + +- Highest probability for space (32) ā‰ˆ 0.28 +- Next: 'o' (111) ā‰ˆ 0.23 +- Then: 'W' (87) ā‰ˆ 0.18 +- Model predicts space or continuation + +--- + +## 17. Advanced Block Diagram Simplification + +### 17.1 Complex Multi-Layer System Simplification + +Following control system reduction techniques, we can simplify the transformer model step-by-step: + +**Diagram (a): Original Complex System** + +``` +Input R (Tokens) + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ Embedding │ + │ G_emb │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ Positional │ + │ Encoding │ + │ G_pos │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ + │ ←─── Feedback from Layer 2 + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ Layer 1 │ + │ G_block₁ │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ + │ ←─── Feedback from Output + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ Layer 2 │ + │ G_blockā‚‚ │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ + │ ←─── Feedback H₁ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ Output Proj │ + │ G_out │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + Output C (Logits) +``` + +**Diagram (b): First Simplification (Combine Embedding and Positional)** + +``` +Input R + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ G_emb_pos = │ + │ G_pos ∘ G_emb │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ + │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ Layer 1 │ + │ G_block₁ │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ + │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ Layer 2 │ + │ G_blockā‚‚ │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ + │ ←─── H₁ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ G_out │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + Output C +``` + +**Diagram (c): Second Simplification (Combine Layers)** + +``` +Input R + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ G_emb_pos │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ G_layers = G_blockā‚‚ ∘ G_block₁ │ + │ Equivalent to: │ + │ X + Δ₁(X) + Δ₂(X + Δ₁(X)) │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ + │ ←─── H₁ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ G_out │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + Output C +``` + +**Diagram (d): Third Simplification (Combine with Output)** + +``` +Input R + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ G_forward = │ + │ G_out ∘ G_layers ∘ G_emb_pos │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ + │ ←─── H₁ (Feedback) + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + Output C +``` + +**Diagram (e): Final Simplified Transfer Function** + +``` +Input R + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ Overall Transfer Function: │ + │ │ + │ C/R = G_forward / (1 + G_forward Ɨ H₁) │ + │ │ + │ Where: │ + │ G_forward = G_out ∘ G_layers ∘ G_emb_pos │ + │ │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + Output C +``` + +**Mathematical Derivation:** + +**Step 1:** Combine embedding and positional encoding: + +```math + +G_{emb\_pos}(\mathbf{T}) = G_{pos}(G_{emb}(\mathbf{T})) = \mathbf{E}[\mathbf{T}] + \mathbf{PE} + +``` + +**Step 2:** Combine transformer layers: + +```math + +G_{layers}(\mathbf{X}) = G_{block_2}(G_{block_1}(\mathbf{X})) + + + +G_{layers}(\mathbf{X}) = \mathbf{X} + \Delta_1(\mathbf{X}) + \Delta_2(\mathbf{X} + \Delta_1(\mathbf{X})) + + +where \Delta_l represents the transformation inside block l . +``` + +**Step 3:** Combine with output projection: + +```math + +G_{forward}(\mathbf{T}) = G_{out}(G_{layers}(G_{emb\_pos}(\mathbf{T}))) + +``` + +**Step 4:** Apply feedback reduction: + +```math + +\frac{C}{R} = \frac{G_{forward}}{1 + G_{forward} \times H_1} + +``` + +### 17.2 Attention Block Simplification + +**Diagram (a): Detailed Attention** + +``` +Input X + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ Q │ ←─── W_Q + │ K │ ←─── W_K + │ V │ ←─── W_V + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ Scores │ + │ S = QK^T/√d │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ Softmax │ + │ A = σ(S) │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ Output │ + │ O = AV │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ Out Proj │ + │ W_O │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + Output X' +``` + +**Diagram (b): Simplified Attention Transfer Function** + +``` +Input X + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ G_attn(X) = │ + │ W_O Ā· softmax(QK^T/√d) Ā· V │ + │ │ + │ Where: │ + │ Q = XW_Q, K = XW_K, V = XW_V │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + Output X' +``` + +**Mathematical Transfer Function:** + +```math + +G_{attn}(\mathbf{X}) = \mathbf{X} \mathbf{W}_O \cdot \text{softmax}\left(\frac{(\mathbf{X} \mathbf{W}_Q)(\mathbf{X} \mathbf{W}_K)^T}{\sqrt{d_k}}\right) \cdot (\mathbf{X} \mathbf{W}_V) + +``` + +--- + +## 18. Vector Trace: "Hello World" Complete Flow + +### 18.1 Complete Vector Trace with Numerical Values + +**Input:** `"Hello World"` + +**Stage 1: Tokenization** + +```math + +\mathbf{t} = [72, 101, 108, 108, 111, 32, 87, 111, 114, 108, 100] + +``` + +**Stage 2: Embedding (showing first 4 dimensions)** + +```math + +\mathbf{X} = \begin{bmatrix} +[H] & 0.10 & -0.20 & 0.30 & 0.15 & ... \\ +[e] & -0.10 & 0.30 & -0.10 & 0.08 & ... \\ +[l] & 0.05 & 0.15 & -0.05 & 0.03 & ... \\ +[l] & 0.05 & 0.15 & -0.05 & 0.03 & ... \\ +[o] & -0.05 & 0.20 & 0.10 & 0.06 & ... \\ +[ ] & 0.02 & 0.05 & 0.02 & 0.01 & ... \\ +[W] & 0.15 & -0.15 & 0.25 & 0.12 & ... \\ +[o] & -0.05 & 0.20 & 0.10 & 0.06 & ... \\ +[r] & 0.08 & 0.10 & -0.08 & 0.04 & ... \\ +[l] & 0.05 & 0.15 & -0.05 & 0.03 & ... \\ +[d] & 0.12 & -0.08 & 0.18 & 0.09 & ... +\end{bmatrix} \in \mathbb{R}^{11 \times 512} + +``` + +**Stage 3: Positional Encoding (first 4 dimensions)** + +```math + +\mathbf{PE} = \begin{bmatrix} +[0] & 0.00 & 1.00 & 0.00 & 0.00 & ... \\ +[1] & 0.84 & 0.54 & 0.01 & 0.00 & ... \\ +[2] & 0.91 & -0.42 & 0.02 & 0.00 & ... \\ +[3] & 0.14 & -0.99 & 0.03 & 0.00 & ... \\ +[4] & -0.76 & -0.65 & 0.04 & 0.00 & ... \\ +[5] & -0.96 & 0.28 & 0.05 & 0.00 & ... \\ +[6] & -0.28 & 0.96 & 0.06 & 0.00 & ... \\ +[7] & 0.65 & 0.76 & 0.07 & 0.00 & ... \\ +[8] & 0.99 & -0.14 & 0.08 & 0.00 & ... \\ +[9] & 0.42 & -0.91 & 0.09 & 0.00 & ... \\ +[10] & -0.54 & -0.84 & 0.10 & 0.00 & ... +\end{bmatrix} + +``` + +**Stage 4: Combined Input** + +```math + +\mathbf{X}_{pos} = \mathbf{X} + \mathbf{PE} + +``` + +**Example Row 0 (token 'H'):** + +```math + +\mathbf{X}_{pos}[0, :4] = [0.10, -0.20, 0.30, 0.15] + [0.00, 1.00, 0.00, 0.00] = [0.10, 0.80, 0.30, 0.15] + +``` + +**Stage 5: Attention (Head 0, showing attention from token 0 to all tokens)** + +```math + +\mathbf{S}_0[0, :] = [0.50, -0.10, 0.20, 0.15, 0.30, -0.05, 0.18, 0.28, 0.12, 0.20, 0.22] + + + +\mathbf{A}_0[0, :] = \text{softmax}(\mathbf{S}_0[0, :]) = [0.35, 0.15, 0.22, 0.20, 0.28, 0.14, 0.19, 0.26, 0.17, 0.21, 0.23] + + +**Meaning:** Token 'H' (position 0) attends: +- 35% to itself +- 28% to token 'o' (position 4) +- 26% to token 'o' (position 7) +- 23% to token 'd' (position 10) +``` + +**Stage 6: Attention Output** + +```math + +\mathbf{O}_0[0, :] = \sum_{j=0}^{10} A_{0,j} \mathbf{V}_0[j, :] + +``` + +**Example (first dimension):** + +```math + +O_{0,0,0} = 0.35 \times 0.12 + 0.15 \times 0.08 + ... + 0.23 \times 0.15 \approx 0.115 + +``` + +**Stage 7: FFN Output** + +```math + +\mathbf{H}_{ffn}[0, :4] = [0.15, -0.08, 0.22, 0.18] + +``` + +**Stage 8: Final Output (after all layers)** + +```math + +\mathbf{H}_{final}[0, :4] = [0.42, 0.25, 0.58, 0.31] + +``` + +**Stage 9: Logits** + +```math + +\mathbf{L}[0, :] = [2.1, 1.8, ..., 5.2, ..., 3.4, ...] + + +Where L[0, 72] = 5.2 is highest (predicting 'H' at position 1). +``` + +**Stage 10: Probabilities** + +```math + +\mathbf{p}[0, :] = \text{softmax}(\mathbf{L}[0, :]) = [0.01, 0.008, ..., 0.28, ..., 0.15, ...] + + + +p[0, 72] \approx 0.28 \quad \text{(28\% probability for H)} + +``` + +--- + +## 19. Vector Plots and Visualizations + +### 19.1 Embedding Vector Trajectory + +**Trajectory Plot:** + +``` +512-Dimensional Embedding Space (2D Projection) + + 0.3 │ 'e' (pos 1) + │ ā— + 0.2 │ 'r' (pos 8) + │ ā— + 0.1 │ 'l' (pos 2,3,9) 'o' (pos 4,7) + │ ā— ā— + 0.0 ā”œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ + │ 'H' (pos 0) + -0.1 │ ā— + │ + -0.2 │ + │ + -0.3 │ 'W' (pos 6) + │ ā— + └─────────────────────────────────────────── + -0.3 -0.2 -0.1 0.0 0.1 0.2 0.3 +``` + +### 19.2 Attention Heatmap + +**Attention Weight Matrix Visualization:** + +``` +Attention Weights A[i,j] for "Hello World" + + j → 0 1 2 3 4 5 6 7 8 9 10 + ↓ ['H'] ['e'] ['l'] ['l'] ['o'] [' '] ['W'] ['o'] ['r'] ['l'] ['d'] +i=0 ['H'] │ 0.35 0.15 0.22 0.20 0.28 0.14 0.19 0.26 0.17 0.21 0.23 │ +i=1 ['e'] │ 0.15 0.38 0.20 0.18 0.27 0.16 0.18 0.25 0.19 0.22 0.20 │ +i=2 ['l'] │ 0.23 0.18 0.32 0.30 0.26 0.17 0.21 0.24 0.25 0.31 0.23 │ +i=3 ['l'] │ 0.21 0.19 0.28 0.33 0.25 0.18 0.20 0.23 0.24 0.30 0.22 │ +i=4 ['o'] │ 0.27 0.22 0.26 0.25 0.36 0.19 0.23 0.29 0.24 0.27 0.25 │ +i=5 [' '] │ 0.18 0.20 0.19 0.21 0.24 0.40 0.22 0.25 0.21 0.20 0.22 │ +i=6 ['W'] │ 0.22 0.21 0.23 0.24 0.26 0.20 0.45 0.28 0.27 0.23 0.25 │ +i=7 ['o'] │ 0.26 0.25 0.24 0.23 0.29 0.21 0.28 0.38 0.26 0.24 0.26 │ +i=8 ['r'] │ 0.19 0.21 0.25 0.24 0.24 0.19 0.27 0.26 0.42 0.27 0.28 │ +i=9 ['l'] │ 0.21 0.22 0.31 0.30 0.27 0.20 0.23 0.24 0.27 0.35 0.24 │ +i=10['d'] │ 0.23 0.20 0.23 0.22 0.25 0.22 0.25 0.26 0.28 0.24 0.48 │ + +Color Coding: +ā–ˆ = 0.48-0.50 (very high attention) +ā–ˆ = 0.35-0.48 (high attention) +ā–ˆ = 0.25-0.35 (medium attention) +ā–ˆ = 0.15-0.25 (low attention) +ā–ˆ = 0.00-0.15 (very low attention) +``` + +### 19.3 Probability Distribution Plot + +**Logits and Probabilities:** + +``` +Logits L[5, :] (predicting token after "Hello ") + +Logit +Value │ + 6.0 │ ā— (token 87 'W') + │ + 5.0 │ ā— (token 111 'o') + │ + 4.0 │ ā— (token 32 ' ') ā— (token 114 'r') + │ + 3.0 │ ā— ā— ā— + │ + 2.0 │ ā— ā— ā— ā— ā— ā— ā— ā— ā— ā— ā— + │ + 1.0 │ ā— ā— ā— ā— ā— ā— ā— ā— ā— ā— ā— + │ + 0.0 ā”œā”€ā”“ā”€ā”€ā”“ā”€ā”€ā”“ā”€ā”€ā”“ā”€ā”€ā”“ā”€ā”€ā”“ā”€ā”€ā”“ā”€ā”€ā”“ā”€ā”€ā”“ā”€ā”€ā”“ā”€ā”€ā”“ā”€ā”€ Token IDs + 32 72 87 101 108 111 114 ... + ␣ H W e l o r + +Probabilities p[5, :] + +Probability + │ + 0.3│ ā— ('W') + │ + 0.2│ ā— (' ') ā— ('o') + │ + 0.1│ ā— ā— ā— ā— ā— ā— ā— + │ + 0.0ā”œā”€ā”“ā”€ā”€ā”“ā”€ā”€ā”“ā”€ā”€ā”“ā”€ā”€ā”“ā”€ā”€ā”“ā”€ā”€ā”“ā”€ā”€ā”“ā”€ā”€ā”“ā”€ā”€ā”“ā”€ā”€ā”“ā”€ā”€ Token IDs + 32 72 87 101 108 111 114 ... +``` + +### 19.4 Hidden State Evolution Through Layers + +**Layer-by-Layer Transformation:** + +``` +Hidden State Evolution for Token 'H' (position 0) + +Dimension 0: +Layer 0: 0.10 (embedding + positional) +Layer 1: 0.42 (after attention + FFN) +Layer 2: 0.58 (after second layer) +Layer 3: 0.65 (after third layer) +... ... +Layer L: 0.72 (final hidden state) + +Dimension 1: +Layer 0: 0.80 (embedding + positional) +Layer 1: 0.25 (after attention + FFN) +Layer 2: 0.18 (after second layer) +Layer 3: 0.22 (after third layer) +... ... +Layer L: 0.15 (final hidden state) +``` + +**Visualization:** + +``` +Hidden State Magnitude ||h[l]|| Over Layers + +Magnitude + │ + 1.0│ ā— + │ ā— + 0.8│ ā— + │ ā— + 0.6│ ā— + │ ā— + 0.4│ ā— + │ ā— + 0.2│ ā— + │ ā— + 0.0ā”œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ Layer + 0 1 2 3 4 5 6 +``` + +--- + +## 20. Summary: Complete Mathematical Trace + +### Complete System Equation with Numerical Example + +**Text:** `"Hello World"` + +**Complete Mathematical Flow:** + +1. **Tokenization:** + +```math + + \mathbf{t} = \mathcal{T}(\text{"Hello World"}) = [72, 101, 108, 108, 111, 32, 87, 111, 114, 108, 100] + +``` + +2. **Embedding:** + +```math + + \mathbf{X} = \mathbf{E}[\mathbf{t}] \in \mathbb{R}^{11 \times 512} + +``` + +3. **Positional Encoding:** + +```math + + \mathbf{X}_{pos} = \mathbf{X} + \mathbf{PE} \in \mathbb{R}^{11 \times 512} + +``` + +4. **Transformer Layers (L=6):** + +```math + + \mathbf{h}_l = \text{TransformerBlock}_l(\mathbf{h}_{l-1}), \quad l = 1, ..., 6 + +``` + +5. **Output:** + +```math + + \mathbf{L} = \mathbf{h}_6 \mathbf{W}_{out} \in \mathbb{R}^{11 \times 128} + +``` + +6. **Probabilities:** + +```math + + \mathbf{p} = \text{softmax}(\mathbf{L}) \in \mathbb{R}^{11 \times 128} + +``` + +**Final Prediction:** + +For position 5 (after "Hello "): + +```math + +p[5, 87] = 0.28 \quad \text{(28\% for W)} \\ +p[5, 32] = 0.22 \quad \text{(22\% for space)} \\ +p[5, 111] = 0.18 \quad \text{(18\% for o)} + +``` + +**Most Likely:** `'W'` → Complete prediction: `"Hello World"` + +--- + +_This document provides a complete mathematical control system formulation with block diagrams, vector visualizations, numerical examples, and step-by-step calculations for every component of the SheepOp LLM._ diff --git a/docs/DATABASE_EXTRACTION_GUIDE.md b/docs/DATABASE_EXTRACTION_GUIDE.md new file mode 100644 index 0000000..6843a74 --- /dev/null +++ b/docs/DATABASE_EXTRACTION_GUIDE.md @@ -0,0 +1,217 @@ +# Database Extraction Guide + +This guide shows you how to extract text from your 1TB database for training. + +## Quick Start + +### SQLite Database + +```bash +# Extract from SQLite database +python3 extract_from_database.py \ + --type sqlite \ + --db-path /path/to/your/database.db \ + --table your_table_name \ + --column text_column_name \ + --output data/database_training.txt \ + --limit 1000000 # Limit to 1M samples (or omit for all) +``` + +### PostgreSQL Database + +```bash +# Install PostgreSQL driver first +pip install psycopg2-binary + +# Extract with SQL query +python3 extract_from_database.py \ + --type sql \ + --connection "host=localhost dbname=mydb user=myuser password=mypass" \ + --query "SELECT text_column FROM your_table WHERE length(text_column) > 50" \ + --output data/database_training.txt \ + --limit 1000000 +``` + +### MySQL Database + +```bash +# Install MySQL driver first +pip install pymysql + +# Extract with SQL query +python3 extract_from_database.py \ + --type sql \ + --connection "mysql+pymysql://user:pass@localhost/dbname" \ + --query "SELECT text_column FROM your_table" \ + --output data/database_training.txt +``` + +### JSON/JSONL Files + +```bash +# Extract from JSON Lines file +python3 extract_from_database.py \ + --type json \ + --json-path /path/to/data.jsonl \ + --text-field content \ + --output data/database_training.txt \ + --limit 1000000 +``` + +## Examples + +### Example 1: Extract All Text from SQLite Table + +```bash +python3 extract_from_database.py \ + --type sqlite \ + --db-path /Volumes/YourDisk/database.db \ + --table articles \ + --column body_text \ + --output data/training_data.txt +``` + +### Example 2: Extract Filtered Data (Longer Texts Only) + +```bash +python3 extract_from_database.py \ + --type sqlite \ + --db-path /Volumes/YourDisk/database.db \ + --table articles \ + --column body_text \ + --where "WHERE length(body_text) > 200" \ + --output data/training_data.txt \ + --min-length 50 +``` + +### Example 3: Extract from Multiple Tables + +```bash +# Extract from table 1 +python3 extract_from_database.py \ + --type sqlite \ + --db-path /Volumes/YourDisk/database.db \ + --table articles \ + --column content \ + --output data/articles.txt + +# Extract from table 2 +python3 extract_from_database.py \ + --type sqlite \ + --db-path /Volumes/YourDisk/database.db \ + --table comments \ + --column text \ + --output data/comments.txt + +# Combine files +cat data/articles.txt data/comments.txt > data/combined_training.txt +``` + +### Example 4: PostgreSQL with Complex Query + +```bash +python3 extract_from_database.py \ + --type sql \ + --connection "host=localhost dbname=mydb user=myuser password=mypass" \ + --query "SELECT description FROM products WHERE description IS NOT NULL AND length(description) > 100 UNION SELECT review_text FROM reviews WHERE review_text IS NOT NULL" \ + --output data/products_and_reviews.txt +``` + +## Options + +### Filtering Options + +```bash +# Only extract texts longer than 100 characters +--min-length 100 + +# Limit total samples +--limit 1000000 + +# Add WHERE clause (SQLite) +--where "WHERE created_at > '2024-01-01' AND length(text) > 200" +``` + +### Output Options + +```bash +# Custom output path +--output data/my_training_data.txt + +# Don't clean/split text (preserve original format) +--no-clean +``` + +## Performance Tips + +1. **Use LIMIT for Testing**: Start with `--limit 10000` to test +2. **Filter in Database**: Use `--where` clause to filter at database level (faster) +3. **Batch Processing**: The script processes in batches automatically +4. **Monitor Progress**: Progress updates every 1000 texts + +## Data Format + +The output file will have: +- One text sample per line +- Cleaned and split into sentences +- Minimum length filtering applied +- UTF-8 encoding + +## Next Steps + +After extraction: + +```bash +# Check how much data you extracted +wc -l data/database_training.txt + +# Train with the extracted data +python3 train.py --data data/database_training.txt --config config.json --device mps +``` + +## Troubleshooting + +### SQLite Database Locked +- Close any applications using the database +- Copy database to a local location first + +### Large Database (1TB) +- Use `--limit` to extract in batches +- Use `--where` to filter at database level +- Consider extracting to multiple files and combining + +### Memory Issues +- The script processes in batches (streaming) +- Use `--limit` to control size +- Process in chunks if needed + +## Example Workflow + +```bash +# 1. Extract 1M samples for testing +python3 extract_from_database.py \ + --type sqlite \ + --db-path /Volumes/YourDisk/database.db \ + --table your_table \ + --column text_column \ + --output data/test_extraction.txt \ + --limit 1000000 + +# 2. Check the data +head -20 data/test_extraction.txt +wc -l data/test_extraction.txt + +# 3. If good, extract more (or all) +python3 extract_from_database.py \ + --type sqlite \ + --db-path /Volumes/YourDisk/database.db \ + --table your_table \ + --column text_column \ + --output data/full_training.txt + +# 4. Train with the data +python3 train.py --data data/full_training.txt --config config.json --device mps +``` + +Good luck extracting your 1TB database! šŸš€ + diff --git a/docs/DATA_GUIDE.md b/docs/DATA_GUIDE.md new file mode 100644 index 0000000..24fb53b --- /dev/null +++ b/docs/DATA_GUIDE.md @@ -0,0 +1,225 @@ +# Data Collection Guide + +This guide shows you how to get training data from the internet or create your own data.txt file. + +## Option 1: Use the Download Script + +### Quick Start + +```bash +# Download Shakespeare text (recommended for testing) +python download_data.py --type shakespeare + +# Create a sample data file +python download_data.py --type sample --output data/my_data.txt --samples 200 + +# Download Wikipedia article (requires: pip install wikipedia) +python download_data.py --type wikipedia --title "Artificial Intelligence" --output data/ai_article.txt +``` + +### Available Options + +**Shakespeare Dataset:** +```bash +python download_data.py --type shakespeare +``` +Downloads classic Shakespeare text - great for testing! + +**Create Sample Data:** +```bash +python download_data.py --type sample --output data/my_data.txt --samples 100 +``` +Creates a file with sample sentences about ML/AI. + +**Wikipedia Article:** +```bash +python download_data.py --type wikipedia --title "Machine Learning" --output data/ml_article.txt +``` +Downloads a Wikipedia article (requires `pip install wikipedia`). + +## Option 2: Manual Data Collection + +### Method A: Create Your Own data.txt + +1. **Create a text file:** +```bash +nano data/my_data.txt +# or +vim data/my_data.txt +``` + +2. **Add your text** (one sentence per line): +``` +This is my first training sample. +This is my second training sample. +Add as many lines as you want. +``` + +3. **Save and use:** +```bash +python train.py --data data/my_data.txt +``` + +### Method B: Download from Public Datasets + +**Shakespeare Text:** +```bash +curl -o data/shakespeare.txt https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt +``` + +**Book Corpus Sample:** +```bash +# Download Project Gutenberg books +curl -o data/book.txt https://www.gutenberg.org/files/1342/1342-0.txt # Pride and Prejudice +``` + +**News Articles:** +```bash +# Download news text +curl -o data/news.txt https://raw.githubusercontent.com/sunnysai12345/News_Summary/master/news_summary_more.csv +``` + +### Method C: Scrape Your Own Data + +**From Wikipedia (Python):** +```python +import wikipedia + +page = wikipedia.page("Machine Learning") +with open("data/ml_article.txt", "w") as f: + f.write(page.content) +``` + +**From a Website:** +```python +import requests +from bs4 import BeautifulSoup + +url = "https://example.com/article" +response = requests.get(url) +soup = BeautifulSoup(response.text, 'html.parser') +text = soup.get_text() + +with open("data/scraped.txt", "w") as f: + f.write(text) +``` + +## Option 3: Use Existing Datasets + +### Popular NLP Datasets + +**WikiText-2:** +```bash +# Download WikiText-2 +wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip +unzip wikitext-2-v1.zip +# Use: wikitext-2/wiki.train.tokens +``` + +**OpenWebText Sample:** +```bash +# Download sample +curl -o data/openwebtext_sample.txt https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt +``` + +**BookCorpus:** +```bash +# Various book sources available +# Check: https://github.com/soskek/bookcorpus +``` + +## Data Format Requirements + +Your `data.txt` file should: +- Have **one text sample per line** +- Use **UTF-8 encoding** +- Be **plain text** (no special formatting) + +**Example format:** +``` +This is the first training example. +This is the second training example. +Each line becomes one training sample. +``` + +**Good:** +``` +Hello world! +This is a sentence. +Machine learning is cool. +``` + +**Bad:** +``` +This is paragraph 1 with multiple sentences. This is sentence 2. +This is paragraph 2. +``` + +## Preprocessing Tips + +1. **Clean your data:** +```python +import re + +with open("raw_data.txt", "r") as f: + text = f.read() + +# Remove extra whitespace +text = re.sub(r'\s+', ' ', text) + +# Split into sentences +sentences = text.split('.') + +# Write one per line +with open("data/cleaned_data.txt", "w") as f: + for sentence in sentences: + if sentence.strip(): + f.write(sentence.strip() + '\n') +``` + +2. **Split long texts:** +```python +# If you have long texts, split them into sentences +text = "Long paragraph here. Another sentence. More text." +sentences = text.split('.') +for sentence in sentences: + if sentence.strip(): + print(sentence.strip()) +``` + +## Quick Test + +1. **Create a small test file:** +```bash +cat > data/test.txt << EOF +Hello world! +This is a test. +Language models are cool. +EOF +``` + +2. **Train with it:** +```bash +python train.py --data data/test.txt --output ./checkpoints +``` + +## Recommended Data Sources + +- **Small (for testing):** Shakespeare text, sample_data.txt +- **Medium (for training):** Wikipedia articles, news articles +- **Large (for serious training):** WikiText-2, BookCorpus, OpenWebText + +## Next Steps + +Once you have your data.txt file: + +```bash +# Train your model +python train.py --data data/your_data.txt --output ./checkpoints + +# Or use the sample data +python train.py --data data/sample_data.txt --output ./checkpoints +``` + +Happy training! šŸš€ + diff --git a/docs/DATA_PROCESSING_EXPLAINED.md b/docs/DATA_PROCESSING_EXPLAINED.md new file mode 100644 index 0000000..1aca5f0 --- /dev/null +++ b/docs/DATA_PROCESSING_EXPLAINED.md @@ -0,0 +1,914 @@ +# Data Processing Explained: Step-by-Step Guide + +Complete guide to understanding data processing in the SheepOp LLM project, explaining what happens to your data from raw files to training-ready text. + +## Table of Contents + +1. [What is Data Processing?](#1-what-is-data-processing) +2. [Why Do We Need Data Processing?](#2-why-do-we-need-data-processing) +3. [The Data Processing Pipeline](#3-the-data-processing-pipeline) +4. [Step-by-Step: How Each File Type is Processed](#4-step-by-step-how-each-file-type-is-processed) +5. [Data Transformation Stages](#5-data-transformation-stages) +6. [Complete Example: Processing "Hello World.pdf"](#6-complete-example-processing-hello-worldpdf) +7. [Data Quality and Filtering](#7-data-quality-and-filtering) +8. [Common Questions](#8-common-questions) + +--- + +## 1. What is Data Processing? + +**Data processing** is the transformation of raw, unstructured data into a format that machine learning models can understand and learn from. + +### Simple Analogy + +Think of data processing like preparing ingredients for cooking: + +**Raw Ingredients (Your Files):** +- PDF documents +- Text files +- Images with text +- Code files + +**Prepared Ingredients (Processed Data):** +- Clean text lines +- Consistent format +- Ready for training + +**The Recipe (Training):** +- The model learns from the prepared ingredients + +### In Our Context + +**Input:** Mixed file types (PDFs, images, code, text) +**Output:** List of text strings ready for tokenization +**Purpose:** Extract meaningful text that the model can learn from + +--- + +## 2. Why Do We Need Data Processing? + +### 2.1 The Problem + +Machine learning models (like our transformer) understand **numbers**, not: +- PDF files +- Images +- Raw text files +- Code files + +### 2.2 The Solution + +We need to: +1. **Extract** text from different file formats +2. **Clean** the text (remove noise, handle encoding) +3. **Standardize** the format (consistent structure) +4. **Prepare** for tokenization (split into manageable pieces) + +### 2.3 Benefits + +āœ… **Unified Format**: All data becomes text lines +āœ… **Easy to Process**: Simple format for tokenization +āœ… **Flexible**: Works with many file types +āœ… **Scalable**: Can process thousands of files automatically + +--- + +## 3. The Data Processing Pipeline + +### 3.1 High-Level Overview + +``` +Raw Files + ↓ +[File Type Detection] + ↓ +[Text Extraction] + ↓ +[Text Cleaning] + ↓ +[Line Splitting] + ↓ +[Filtering] + ↓ +Clean Text Lines + ↓ +[Tokenization] ← Not part of data processing + ↓ +[Training] ← Not part of data processing +``` + +### 3.2 Detailed Pipeline + +``` +Step 1: Directory Scan + └─→ Find all files in data/ directory + └─→ Categorize by file type (.pdf, .txt, .png, etc.) + +Step 2: File Type Detection + └─→ Check file extension + └─→ Route to appropriate processor + +Step 3: Text Extraction + ā”œā”€ā†’ PDF files → PDF text extraction + ā”œā”€ā†’ Text files → Read as text + ā”œā”€ā†’ Image files → OCR (Optical Character Recognition) + └─→ Code files → Read as text + +Step 4: Text Cleaning + └─→ Remove extra whitespace + └─→ Handle encoding issues + └─→ Normalize line endings + +Step 5: Line Splitting + └─→ Split text into individual lines + └─→ Each line becomes one training sample + +Step 6: Filtering + └─→ Remove empty lines + └─→ Filter by minimum length + └─→ Remove lines that are too short + +Step 7: Output + └─→ List of text strings + └─→ Ready for tokenization +``` + +--- + +## 4. Step-by-Step: How Each File Type is Processed + +### 4.1 Text Files (.txt, .md, .log, etc.) + +**What happens:** +1. File is opened +2. Content is read line by line +3. Each line becomes a separate text sample + +**Example:** + +**Input:** `document.txt` +``` +Hello world +This is a sentence. +Machine learning is fascinating. +``` + +**Processing:** +``` +Line 1: "Hello world" +Line 2: "This is a sentence." +Line 3: "Machine learning is fascinating." +``` + +**Output:** +```python +[ + "Hello world", + "This is a sentence.", + "Machine learning is fascinating." +] +``` + +**Why this works:** Text files are already in plain text format, so extraction is straightforward. + +--- + +### 4.2 Code Files (.py, .js, .java, etc.) + +**What happens:** +1. File is opened +2. Content is read line by line +3. Each line becomes a separate text sample + +**Example:** + +**Input:** `example.py` +```python +def hello(): + print("Hello") + return True +``` + +**Processing:** +``` +Line 1: "def hello():" +Line 2: " print("Hello")" +Line 3: " return True" +``` + +**Output:** +```python +[ + "def hello():", + " print("Hello")", + " return True" +] +``` + +**Why this works:** Code files are text files, so they're processed the same way. The model learns code patterns and syntax. + +--- + +### 4.3 PDF Files (.pdf) + +**What happens:** +1. PDF file is opened +2. Text is extracted from each page +3. Text is split into lines +4. Lines are filtered for quality + +**Example:** + +**Input:** `document.pdf` (3 pages) + +**Page 1:** +``` +Introduction to Machine Learning +Machine learning is a subset of artificial intelligence. +``` + +**Page 2:** +``` +Neural Networks +Neural networks are computing systems inspired by biological neural networks. +``` + +**Page 3:** +``` +Conclusion +In conclusion, machine learning has revolutionized technology. +``` + +**Processing:** + +**Step 1: Extract text from each page** +``` +Page 1 text: "Introduction to Machine Learning\nMachine learning is a subset of artificial intelligence." +Page 2 text: "Neural Networks\nNeural networks are computing systems inspired by biological neural networks." +Page 3 text: "Conclusion\nIn conclusion, machine learning has revolutionized technology." +``` + +**Step 2: Split by newlines** +``` +Line 1: "Introduction to Machine Learning" +Line 2: "Machine learning is a subset of artificial intelligence." +Line 3: "Neural Networks" +Line 4: "Neural networks are computing systems inspired by biological neural networks." +Line 5: "Conclusion" +Line 6: "In conclusion, machine learning has revolutionized technology." +``` + +**Step 3: Filter short lines** +``` +Remove: "Introduction to Machine Learning" (too short for context) +Keep: "Machine learning is a subset of artificial intelligence." +Remove: "Neural Networks" (too short) +Keep: "Neural networks are computing systems inspired by biological neural networks." +Remove: "Conclusion" (too short) +Keep: "In conclusion, machine learning has revolutionized technology." +``` + +**Output:** +```python +[ + "Machine learning is a subset of artificial intelligence.", + "Neural networks are computing systems inspired by biological neural networks.", + "In conclusion, machine learning has revolutionized technology." +] +``` + +**Why this works:** PDFs contain text embedded in the file structure. Libraries like PyPDF2 or pdfplumber extract this text, preserving the content but losing formatting. + +--- + +### 4.4 Image Files (.png, .jpg, etc.) + +**What happens:** +1. Image file is opened +2. OCR (Optical Character Recognition) reads text from the image +3. Extracted text is split into lines +4. Lines are filtered for quality + +**Example:** + +**Input:** `screenshot.png` containing: +``` +Hello World +This is text in an image. +``` + +**Processing:** + +**Step 1: OCR Processing** +``` +Image → OCR Engine → Text +"Hello World\nThis is text in an image." +``` + +**Step 2: Split by newlines** +``` +Line 1: "Hello World" +Line 2: "This is text in an image." +``` + +**Step 3: Filter short lines** +``` +Remove: "Hello World" (might be too short) +Keep: "This is text in an image." +``` + +**Output:** +```python +[ + "This is text in an image." +] +``` + +**Why this works:** OCR software analyzes the image pixel by pixel, identifies characters, and converts them to text. Accuracy depends on image quality. + +--- + +## 5. Data Transformation Stages + +### 5.1 Stage 1: File Discovery + +**Purpose:** Find all files to process + +**Process:** +``` +Directory: data/ + ā”œā”€ā”€ document.pdf + ā”œā”€ā”€ code.py + ā”œā”€ā”€ screenshot.png + └── notes.txt + +Scan recursively: + ā”œā”€ā”€ Find: document.pdf + ā”œā”€ā”€ Find: code.py + ā”œā”€ā”€ Find: screenshot.png + └── Find: notes.txt + +Total: 4 files found +``` + +**Result:** List of file paths to process + +--- + +### 5.2 Stage 2: File Type Classification + +**Purpose:** Determine how to process each file + +**Process:** +``` +File: document.pdf + ā”œā”€ā”€ Extension: .pdf + ā”œā”€ā”€ Type: PDF + └── Processor: PDF Extractor + +File: code.py + ā”œā”€ā”€ Extension: .py + ā”œā”€ā”€ Type: Code + └── Processor: Text Reader + +File: screenshot.png + ā”œā”€ā”€ Extension: .png + ā”œā”€ā”€ Type: Image + └── Processor: OCR + +File: notes.txt + ā”œā”€ā”€ Extension: .txt + ā”œā”€ā”€ Type: Text + └── Processor: Text Reader +``` + +**Result:** Each file assigned to appropriate processor + +--- + +### 5.3 Stage 3: Text Extraction + +**Purpose:** Get raw text from each file + +**Process:** + +**PDF File:** +``` +document.pdf + → Open PDF + → Extract Page 1: "Introduction..." + → Extract Page 2: "Chapter 1..." + → Extract Page 3: "Conclusion..." + → Combine: "Introduction...\nChapter 1...\nConclusion..." +``` + +**Text File:** +``` +notes.txt + → Open file + → Read content: "Hello\nWorld\nTest" +``` + +**Image File:** +``` +screenshot.png + → Open image + → Run OCR + → Extract: "Hello World\nThis is text" +``` + +**Code File:** +``` +code.py + → Open file + → Read content: "def hello():\n print('Hi')" +``` + +**Result:** Raw text strings from each file + +--- + +### 5.4 Stage 4: Text Cleaning + +**Purpose:** Standardize and clean the extracted text + +**Process:** + +**Input:** +``` +"Hello World\n\n\nThis is a test. " +``` + +**Step 1: Remove Extra Whitespace** +``` +"Hello World\n\n\nThis is a test. " + ↓ +"Hello World\n\n\nThis is a test." +``` + +**Step 2: Normalize Line Endings** +``` +"Hello World\n\n\nThis is a test." + ↓ +"Hello World\n\n\nThis is a test." +``` + +**Step 3: Handle Encoding** +``` +"Hello World" (UTF-8) + ↓ +"Hello World" (checked and valid) +``` + +**Result:** Cleaned text strings + +--- + +### 5.5 Stage 5: Line Splitting + +**Purpose:** Break text into individual training samples + +**Process:** + +**Input:** +``` +"Hello World\nThis is a test.\nMachine learning is cool." +``` + +**Split by newlines:** +``` +Line 1: "Hello World" +Line 2: "This is a test." +Line 3: "Machine learning is cool." +``` + +**Result:** List of individual text lines + +--- + +### 5.6 Stage 6: Filtering + +**Purpose:** Keep only useful text samples + +**Process:** + +**Input:** +```python +[ + "Hello World", # Length: 11 + "Hi", # Length: 2 (too short) + "This is a sentence.", # Length: 19 + "", # Empty (remove) + "A" # Length: 1 (too short) +] +``` + +**Filter criteria:** +- Minimum length: 10 characters +- Non-empty strings + +**Filtering:** +``` +Keep: "Hello World" (length 11 ≄ 10) +Remove: "Hi" (length 2 < 10) +Keep: "This is a sentence." (length 19 ≄ 10) +Remove: "" (empty) +Remove: "A" (length 1 < 10) +``` + +**Output:** +```python +[ + "Hello World", + "This is a sentence." +] +``` + +**Result:** Filtered list of quality text samples + +--- + +## 6. Complete Example: Processing "Hello World.pdf" + +Let's trace through processing a complete PDF file step-by-step. + +### Input +**File:** `Hello World.pdf` +**Location:** `data/documents/Hello World.pdf` +**Content:** 2 pages with text + +### Step-by-Step Processing + +#### Step 1: File Discovery + +``` +Scanning: data/ + ā”œā”€ā”€ documents/ + │ └── Hello World.pdf ← Found + ā”œā”€ā”€ images/ + └── code/ + +File found: data/documents/Hello World.pdf +``` + +#### Step 2: File Type Detection + +``` +File: Hello World.pdf +Extension: .pdf +Type: PDF +Processor: PDF Extractor +``` + +#### Step 3: PDF Text Extraction + +**Page 1 Content:** +``` +Hello World +This is a simple example document. +It contains multiple sentences. +``` + +**Page 2 Content:** +``` +Second Page +Here is more content. +The end. +``` + +**Extraction Process:** +``` +Open PDF file + ↓ +Extract Page 1: + Text: "Hello World\nThis is a simple example document.\nIt contains multiple sentences." + ↓ +Extract Page 2: + Text: "Second Page\nHere is more content.\nThe end." + ↓ +Combine pages: + "Hello World\nThis is a simple example document.\nIt contains multiple sentences.\nSecond Page\nHere is more content.\nThe end." +``` + +#### Step 4: Text Cleaning + +**Input:** +``` +"Hello World\nThis is a simple example document.\nIt contains multiple sentences.\nSecond Page\nHere is more content.\nThe end." +``` + +**Process:** +``` +Remove extra whitespace: āœ“ (already clean) +Normalize encoding: āœ“ (UTF-8) +Handle special characters: āœ“ (none found) +``` + +**Output:** +``` +"Hello World\nThis is a simple example document.\nIt contains multiple sentences.\nSecond Page\nHere is more content.\nThe end." +``` + +#### Step 5: Line Splitting + +**Input:** +``` +"Hello World\nThis is a simple example document.\nIt contains multiple sentences.\nSecond Page\nHere is more content.\nThe end." +``` + +**Split by newline character (`\n`):** +``` +Line 1: "Hello World" +Line 2: "This is a simple example document." +Line 3: "It contains multiple sentences." +Line 4: "Second Page" +Line 5: "Here is more content." +Line 6: "The end." +``` + +#### Step 6: Filtering + +**Input:** +```python +[ + "Hello World", # Length: 11 + "This is a simple example document.", # Length: 36 + "It contains multiple sentences.", # Length: 31 + "Second Page", # Length: 11 + "Here is more content.", # Length: 21 + "The end." # Length: 8 (too short!) +] +``` + +**Filter: Minimum length = 10** +``` +āœ“ Keep: "Hello World" (11 ≄ 10) +āœ“ Keep: "This is a simple example document." (36 ≄ 10) +āœ“ Keep: "It contains multiple sentences." (31 ≄ 10) +āœ“ Keep: "Second Page" (11 ≄ 10) +āœ“ Keep: "Here is more content." (21 ≄ 10) +āœ— Remove: "The end." (8 < 10) +``` + +#### Step 7: Final Output + +**Result:** +```python +[ + "Hello World", + "This is a simple example document.", + "It contains multiple sentences.", + "Second Page", + "Here is more content." +] +``` + +**Statistics:** +- Files processed: 1 +- Pages extracted: 2 +- Lines extracted: 6 +- Lines kept: 5 +- Lines filtered: 1 + +--- + +## 7. Data Quality and Filtering + +### 7.1 Why Filter? + +**Problem:** Not all text is useful for training + +**Examples of Low-Quality Text:** + +``` +āœ— "" (empty line) +āœ— " " (just whitespace) +āœ— "Hi" (too short, no context) +āœ— "A" (single character) +āœ— "..." (ellipsis, no meaning) +āœ— "---" (separator line) +``` + +**Examples of High-Quality Text:** + +``` +āœ“ "Machine learning is a subset of artificial intelligence." +āœ“ "The transformer architecture uses self-attention mechanisms." +āœ“ "Gradient descent optimizes neural network parameters." +``` + +### 7.2 Filtering Criteria + +**Minimum Length Filter:** + +**Purpose:** Remove very short lines that don't provide context + +**Example:** +``` +Minimum length: 10 characters + +Keep: +āœ“ "Hello world" (11 chars) +āœ“ "This is a test." (15 chars) + +Remove: +āœ— "Hi" (2 chars) +āœ— "Test" (4 chars) +āœ— "OK" (2 chars) +``` + +**Why 10 characters?** +- Provides enough context for meaningful learning +- Filters out headers, separators, and noise +- Ensures each sample has semantic value + +### 7.3 Encoding Handling + +**Problem:** Files may have different encodings + +**Solution:** Try multiple encodings + +**Process:** +``` +Try UTF-8 first: + āœ“ Success → Use UTF-8 + āœ— Failure → Try Latin-1 + āœ“ Success → Use Latin-1 + āœ— Failure → Log error and skip file +``` + +**Example:** + +**UTF-8 file:** +``` +"Hello äø–ē•Œ" → Reads correctly +``` + +**Latin-1 file:** +``` +"Hello cafĆ©" → Reads correctly with Latin-1 +``` + +### 7.4 Error Handling + +**What happens when processing fails?** + +**Examples:** + +**Corrupted PDF:** +``` +File: corrupted.pdf + → Try to extract text + → Error: "Cannot read PDF" + → Log warning: "Failed to process corrupted.pdf" + → Skip file + → Continue with next file +``` + +**Unsupported File Type:** +``` +File: presentation.pptx + → Extension: .pptx + → Type: Not supported + → Warning: "Unsupported file type: .pptx" + → Skip file + → Continue with next file +``` + +**Image OCR Failure:** +``` +File: blurry_image.png + → Try OCR + → OCR returns empty or garbled text + → Filter removes empty lines + → No text extracted + → File processed (no output) +``` + +--- + +## 8. Common Questions + +### Q1: Why process PDFs instead of using them directly? + +**Answer:** +Models work with numbers (token IDs), not file formats. PDFs have: +- Complex structure (fonts, layouts, metadata) +- Embedded formatting +- Binary data mixed with text + +Processing extracts just the text content, which is what the model needs. + +### Q2: What if OCR doesn't work well on an image? + +**Answer:** +- Low-quality images produce poor OCR results +- The system will extract what it can +- Poor OCR output is filtered out (too short or garbled) +- The file is processed but may contribute little or no text + +**Solution:** Use high-quality images with clear text for best results. + +### Q3: Why split text into lines? + +**Answer:** +- Each line becomes a training sample +- Models predict next tokens in sequences +- Shorter sequences are easier to process +- Allows the model to learn from diverse sentence structures + +### Q4: What happens to code formatting? + +**Answer:** +- Code is processed as text +- Indentation and structure are preserved +- Each line becomes a sample +- The model learns code patterns and syntax + +**Example:** +```python +def hello(): + print("Hi") +``` + +Becomes: +``` +"def hello():" +" print("Hi")" +``` + +### Q5: Can I process files in parallel? + +**Answer:** +Currently, files are processed sequentially. Future improvements could include: +- Parallel processing of multiple files +- Multi-threaded extraction +- Batch processing for efficiency + +### Q6: What if a file is very large? + +**Answer:** +- Large files are processed line by line +- Memory usage stays manageable +- Progress is logged every 100 files +- System can handle files of any size (within memory limits) + +### Q7: How is data from different file types combined? + +**Answer:** +All extracted text is combined into a single list: + +``` +PDF file → 50 lines extracted +Text file → 30 lines extracted +Code file → 100 lines extracted +Image → 5 lines extracted + +Combined: 185 text lines total +``` + +All lines are treated equally, regardless of source file type. + +--- + +## Summary + +### What is Data Processing? + +**Data processing** is the transformation of raw files (PDFs, images, code, text) into clean text lines that can be tokenized and used for training. + +### Key Steps + +1. **Find Files**: Scan directory for all files +2. **Classify**: Determine file type (.pdf, .txt, .png, etc.) +3. **Extract**: Get text content from each file +4. **Clean**: Remove noise and standardize format +5. **Split**: Break into individual lines +6. **Filter**: Keep only quality text samples + +### Result + +A list of text strings ready for: +- Tokenization (converting to numbers) +- Training (teaching the model) +- Learning (model understanding patterns) + +### Example Flow + +``` +PDF file "document.pdf" + ↓ +Extract text from pages + ↓ +Clean and split into lines + ↓ +Filter by length + ↓ +["Sentence 1.", "Sentence 2.", "Sentence 3."] + ↓ +Ready for tokenization and training! +``` + +--- + +*This document explains what data processing means and how it transforms your raw files into training-ready text, step by step.* + diff --git a/docs/EMBEDDINGS_EXPLAINED.md b/docs/EMBEDDINGS_EXPLAINED.md new file mode 100644 index 0000000..bb480d9 --- /dev/null +++ b/docs/EMBEDDINGS_EXPLAINED.md @@ -0,0 +1,287 @@ +# What are Embeddings? Step-by-Step Explanation + +Complete step-by-step explanation of embeddings in transformer models: how words become numbers that capture meaning. + +## Table of Contents + +1. [The Problem Embeddings Solve](#11-the-problem-embeddings-solve) +2. [What is an Embedding?](#12-what-is-an-embedding) +3. [How Embeddings Work](#13-how-embeddings-work) +4. [Step-by-Step Example: Embedding "Hello"](#14-step-by-step-example-embedding-hello) +5. [Why Embeddings Matter](#15-why-embeddings-matter) +6. [Complete Example: Embedding Multiple Words](#16-complete-example-embedding-multiple-words) +7. [Visual Representation](#17-visual-representation) +8. [Key Takeaways](#18-key-takeaways) + +--- + +## 1.1 The Problem Embeddings Solve + +### The Challenge + +**Computers understand numbers, not words.** + +Your model receives: +- Input: `"Hello"` (a word, not a number) + +But neural networks need: +- Input: Numbers (like `[0.1, -0.2, 0.3, ...]`) + +### The Solution: Embeddings + +**Embeddings convert words (or tokens) into numbers (vectors) that capture meaning.** + +--- + +## 1.2 What is an Embedding? + +### Simple Definition + +An **embedding** is a numerical representation of a word or token that captures its semantic meaning. + +**Think of it like this:** +- Each word gets a unique "address" in a high-dimensional space +- Similar words end up close together +- Different words are far apart + +### Visual Analogy + +Imagine a map where: +- Words are cities +- Similar words are nearby cities +- Different words are distant cities + +``` + Semantic Space (2D visualization) + + "cat" "dog" + ā— ā— + + "car" "vehicle" + ā— ā— + + "king" "queen" + ā— ā— +``` + +In reality, embeddings use **512 dimensions** (not 2D), but the concept is the same. + +--- + +## 1.3 How Embeddings Work + +### Step 1: Vocabulary Mapping + +**Create a mapping from words to numbers:** + +``` +Vocabulary: +"Hello" → Token ID: 72 +"World" → Token ID: 87 +"the" → Token ID: 32 +... +``` + +**Result:** Each word has a unique ID number + +### Step 2: Embedding Matrix + +**Create a matrix where each row represents a word:** + +``` +Embedding Matrix E: + + Dimension 0 Dimension 1 Dimension 2 ... Dimension 511 +Token 0 [ 0.05 , -0.10 , 0.20 , ..., 0.15 ] +Token 1 [ -0.08 , 0.12 , -0.05 , ..., 0.08 ] +Token 2 [ 0.10 , -0.15 , 0.25 , ..., 0.12 ] +... +Token 72 [ 0.10 , -0.20 , 0.30 , ..., 0.05 ] ← "Hello" +... +Token 87 [ 0.15 , -0.18 , 0.28 , ..., 0.10 ] ← "World" +``` + +**Key Points:** +- Each row is a 512-dimensional vector +- Each row represents one token/word +- The values are learned during training + +### Step 3: Lookup Operation + +**When you need an embedding, look it up:** + +``` +Input: Token ID = 72 ("Hello") + ↓ +Lookup: E[72] + ↓ +Output: [0.10, -0.20, 0.30, ..., 0.05] (512 numbers) +``` + +--- + +## 1.4 Step-by-Step Example: Embedding "Hello" + +### Input + +``` +Word: "Hello" +Token ID: 72 +``` + +### Process + +**Step 1: Get Token ID** +``` +"Hello" → Lookup in vocabulary → 72 +``` + +**Step 2: Lookup Embedding** +``` +E[72] = [0.10, -0.20, 0.30, 0.15, -0.05, ..., 0.05] +``` + +**Step 3: Result** +``` +Embedding vector: [0.10, -0.20, 0.30, ..., 0.05] +Dimension: 512 numbers +Meaning: Numerical representation of "Hello" +``` + +### What These Numbers Mean + +**Individual numbers don't mean much by themselves**, but **together** they represent: +- Semantic meaning (what the word means) +- Contextual relationships (how it relates to other words) +- Syntactic information (grammatical role) + +**Key Insight:** The model learns these values during training to capture meaning. + +--- + +## 1.5 Why Embeddings Matter + +### Benefit 1: Continuous Space + +**Before Embeddings:** +``` +"Hello" = 72 +"World" = 87 +Distance: |72 - 87| = 15 (meaningless!) +``` + +**After Embeddings:** +``` +"Hello" = [0.10, -0.20, 0.30, ...] +"World" = [0.15, -0.18, 0.28, ...] +Distance: Can measure similarity mathematically! +``` + +### Benefit 2: Semantic Relationships + +**Similar words have similar embeddings:** + +``` +"cat" ā‰ˆ [0.8, 0.2, 0.1, ...] +"dog" ā‰ˆ [0.7, 0.3, 0.1, ...] ← Similar to "cat" +"car" ā‰ˆ [0.1, 0.9, 0.8, ...] ← Different from "cat" +``` + +**Distance in embedding space = semantic similarity** + +### Benefit 3: Mathematical Operations + +**You can do math with embeddings:** + +``` +"king" - "man" + "woman" ā‰ˆ "queen" +``` + +This works because embeddings capture semantic relationships! + +--- + +## 1.6 Complete Example: Embedding Multiple Words + +### Input Sentence + +``` +"Hello World" +``` + +### Step-by-Step Processing + +**Step 1: Tokenize** +``` +"Hello" → Token ID: 72 +"World" → Token ID: 87 +``` + +**Step 2: Lookup Embeddings** +``` +E[72] = [0.10, -0.20, 0.30, ..., 0.05] (512 numbers) +E[87] = [0.15, -0.18, 0.28, ..., 0.10] (512 numbers) +``` + +**Step 3: Stack Together** +``` +Embedding Matrix: +[ + [0.10, -0.20, 0.30, ..., 0.05], ← "Hello" + [0.15, -0.18, 0.28, ..., 0.10] ← "World" +] +Shape: [2, 512] +``` + +**Result:** Each word becomes a 512-dimensional vector + +--- + +## 1.7 Visual Representation + +### Embedding Space Visualization + +``` +2D Projection of 512-Dimensional Embedding Space: + + 0.3 │ "World" + │ ā— + 0.2 │ "Hello" + │ ā— + 0.1 │ + │ + 0.0 ā”œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ + │ + -0.1 │ + │ + -0.2 │ + │ + -0.3 │ +``` + +**Reality:** Embeddings exist in 512-dimensional space, but we can visualize them in 2D or 3D projections. + +### Similarity Visualization + +``` +Word Similarities (distance in embedding space): + +"cat" ──── 0.15 distance ──── "dog" (similar) + "cat" ──── 2.5 distance ──── "car" (different) +"king" ──── 0.8 distance ──── "queen" (related) +``` + +--- + +## 1.8 Key Takeaways: Embeddings + +āœ… **Embeddings convert words to numbers** +āœ… **Each word becomes a vector (list of numbers)** +āœ… **Similar words have similar vectors** +āœ… **Enables mathematical operations on words** +āœ… **Learned during training to capture meaning** + +--- + +*This document provides a step-by-step explanation of embeddings, the fundamental component that converts words into numerical representations in transformer models.* + diff --git a/docs/FEED_FORWARD_EXPLAINED.md b/docs/FEED_FORWARD_EXPLAINED.md new file mode 100644 index 0000000..22b4fa1 --- /dev/null +++ b/docs/FEED_FORWARD_EXPLAINED.md @@ -0,0 +1,470 @@ +# What is Feed-Forward? Step-by-Step Explanation + +Complete step-by-step explanation of feed-forward networks in transformer models: how models transform and refine features. + +## Table of Contents + +1. [The Problem Feed-Forward Solves](#31-the-problem-feed-forward-solves) +2. [What is Feed-Forward?](#32-what-is-feed-forward) +3. [How Feed-Forward Works: Step-by-Step](#33-how-feed-forward-works-step-by-step) +4. [Complete Example: Feed-Forward on "Hello"](#34-complete-example-feed-forward-on-hello) +5. [Why Feed-Forward Matters](#35-why-feed-forward-matters) +6. [Complete Feed-Forward Formula](#36-complete-feed-forward-formula) +7. [Visual Representation](#37-visual-representation) +8. [Why Expand and Compress?](#38-why-expand-and-compress) +9. [Key Takeaways](#39-key-takeaways) + +--- + +## 3.1 The Problem Feed-Forward Solves + +### The Challenge + +**Attention provides context, but we need to process and transform that information.** + +Think of it like cooking: +- **Attention:** Gathers ingredients (context) +- **Feed-Forward:** Cooks and transforms ingredients (processing) + +### The Solution: Feed-Forward Network + +**Feed-Forward applies complex transformations to each position independently.** + +--- + +## 3.2 What is Feed-Forward? + +### Simple Definition + +A **Feed-Forward Network (FFN)** is a two-layer neural network that: +1. **Expands** the input to a larger dimension +2. **Applies** a nonlinear transformation +3. **Compresses** back to original dimension + +### Visual Analogy + +**Think of it like a funnel:** + +``` +Input (512 dimensions) + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ EXPAND │ + │ 512 → 2048 │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ TRANSFORM │ + │ (GELU) │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ COMPRESS │ + │ 2048 → 512 │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ +Output (512 dimensions) +``` + +--- + +## 3.3 How Feed-Forward Works: Step-by-Step + +### High-Level Overview + +``` +Step 1: Expand dimension (512 → 2048) +Step 2: Apply nonlinear activation (GELU) +Step 3: Compress dimension (2048 → 512) +``` + +### Detailed Step-by-Step + +#### Step 1: Expansion (First Linear Layer) + +**Input:** Vector of size 512 +**Output:** Vector of size 2048 + +**Mathematical Operation:** +``` +H = X Ɨ W₁ + b₁ +``` + +**Example:** + +**Input X:** +``` +[0.10, -0.20, 0.30, ..., 0.05] (512 numbers) +``` + +**Weight Matrix W₁:** +``` +Shape: [512, 2048] +Each column transforms input to one output dimension +``` + +**Process:** +``` +H[0] = X[0]ƗW₁[0,0] + X[1]ƗW₁[1,0] + ... + X[511]ƗW₁[511,0] +H[1] = X[0]ƗW₁[0,1] + X[1]ƗW₁[1,1] + ... + X[511]ƗW₁[511,1] +... +H[2047] = X[0]ƗW₁[0,2047] + ... + X[511]ƗW₁[511,2047] +``` + +**Result:** +``` +H = [0.12, -0.08, 0.25, ..., 0.18] (2048 numbers) +``` + +**Why Expand?** +- More dimensions = more capacity for complex transformations +- Allows the model to learn intricate patterns +- Think of it as "more room to work" + +#### Step 2: Nonlinear Activation (GELU) + +**Apply GELU to each element:** + +**GELU Function:** +``` +GELU(x) = x Ɨ Φ(x) + +Where Φ(x) is the cumulative distribution function of standard normal distribution +``` + +**Simplified Understanding:** +- Values near zero → suppressed (close to 0) +- Positive values → pass through (modified) +- Negative values → suppressed more + +**Example:** + +**Input H:** +``` +H = [0.12, -0.08, 0.25, ..., 0.18] +``` + +**Apply GELU element-wise:** + +``` +GELU(0.12) ā‰ˆ 0.12 Ɨ 0.548 ā‰ˆ 0.066 +GELU(-0.08) ā‰ˆ -0.08 Ɨ 0.468 ā‰ˆ -0.037 +GELU(0.25) ā‰ˆ 0.25 Ɨ 0.599 ā‰ˆ 0.150 +... +GELU(0.18) ā‰ˆ 0.18 Ɨ 0.572 ā‰ˆ 0.103 +``` + +**Result:** +``` +H' = [0.066, -0.037, 0.150, ..., 0.103] (2048 numbers) +``` + +**Why Nonlinear?** +- Linear transformations can only do so much +- Nonlinearity enables complex function approximation +- Essential for learning patterns + +#### Step 3: Compression (Second Linear Layer) + +**Input:** Vector of size 2048 +**Output:** Vector of size 512 + +**Mathematical Operation:** +``` +O = H' Ɨ Wā‚‚ + bā‚‚ +``` + +**Process:** +``` +O[0] = H'[0]ƗWā‚‚[0,0] + H'[1]ƗWā‚‚[1,0] + ... + H'[2047]ƗWā‚‚[2047,0] +O[1] = H'[0]ƗWā‚‚[0,1] + H'[1]ƗWā‚‚[1,1] + ... + H'[2047]ƗWā‚‚[2047,1] +... +O[511] = H'[0]ƗWā‚‚[0,511] + ... + H'[2047]ƗWā‚‚[2047,511] +``` + +**Result:** +``` +O = [0.15, -0.10, 0.22, ..., 0.12] (512 numbers) +``` + +**Why Compress?** +- Project back to original dimension +- Maintains consistent size throughout model +- Combines expanded features into compact representation + +--- + +## 3.4 Complete Example: Feed-Forward on "Hello" + +### Input + +``` +Word: "Hello" +After Attention: [0.146, 0.108, 0.192, ..., 0.11] +Dimension: 512 +``` + +### Step-by-Step Processing + +#### Step 1: Expansion + +**Input X:** +``` +[0.146, 0.108, 0.192, ..., 0.11] (512 numbers) +``` + +**Weight Matrix W₁:** +``` +Shape: [512, 2048] +Values: Learned during training +``` + +**Compute:** +``` +H = X Ɨ W₁ +``` + +**Result:** +``` +H = [0.21, -0.15, 0.28, ..., 0.19] (2048 numbers) +``` + +**Visualization:** +``` +512 dimensions ──→ ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” ──→ 2048 dimensions + │ W₁ │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ +``` + +#### Step 2: Activation + +**Input H:** +``` +[0.21, -0.15, 0.28, ..., 0.19] (2048 numbers) +``` + +**Apply GELU element-wise:** + +``` +GELU(0.21) ā‰ˆ 0.115 +GELU(-0.15) ā‰ˆ -0.058 +GELU(0.28) ā‰ˆ 0.168 +... +GELU(0.19) ā‰ˆ 0.109 +``` + +**Result:** +``` +H' = [0.115, -0.058, 0.168, ..., 0.109] (2048 numbers) +``` + +**Visualization:** +``` +2048 dimensions ──→ ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” ──→ 2048 dimensions + │ GELU │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ +``` + +#### Step 3: Compression + +**Input H':** +``` +[0.115, -0.058, 0.168, ..., 0.109] (2048 numbers) +``` + +**Weight Matrix Wā‚‚:** +``` +Shape: [2048, 512] +Values: Learned during training +``` + +**Compute:** +``` +O = H' Ɨ Wā‚‚ +``` + +**Result:** +``` +O = [0.18, -0.12, 0.24, ..., 0.14] (512 numbers) +``` + +**Visualization:** +``` +2048 dimensions ──→ ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” ──→ 512 dimensions + │ Wā‚‚ │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ +``` + +#### Final Output + +``` +Output: [0.18, -0.12, 0.24, ..., 0.14] (512 numbers) +``` + +**Meaning:** Transformed representation that captures processed features + +--- + +## 3.5 Why Feed-Forward Matters + +### Benefit 1: Feature Transformation + +**Before FFN:** +``` +Input: Raw attention output +Information: Contextual relationships +``` + +**After FFN:** +``` +Output: Transformed features +Information: Processed and refined understanding +``` + +### Benefit 2: Non-Linear Processing + +**Linear operations** (like attention) can only do limited transformations. +**Non-linear operations** (like GELU in FFN) enable complex function learning. + +**Analogy:** +- Linear: Can only draw straight lines +- Non-linear: Can draw curves, circles, complex shapes + +### Benefit 3: Position-Wise Processing + +**FFN processes each position independently:** + +``` +Position 0 ("Hello"): FFN → Transformed representation +Position 1 ("World"): FFN → Transformed representation +``` + +**Each word gets its own transformation!** + +--- + +## 3.6 Complete Feed-Forward Formula + +### Mathematical Expression + +``` +FFN(X) = GELU(X Ɨ W₁ + b₁) Ɨ Wā‚‚ + bā‚‚ +``` + +**Breaking it down:** + +**Part 1: First Linear Transformation** +``` +H = X Ɨ W₁ + b₁ +``` +- Expands from 512 to 2048 dimensions + +**Part 2: Non-Linear Activation** +``` +H' = GELU(H) +``` +- Applies non-linear transformation + +**Part 3: Second Linear Transformation** +``` +O = H' Ɨ Wā‚‚ + bā‚‚ +``` +- Compresses from 2048 back to 512 dimensions + +**Complete:** +``` +FFN(X) = O +``` + +--- + +## 3.7 Visual Representation + +### Feed-Forward Pipeline + +``` +Input Vector (512D) + │ + │ [0.146, 0.108, 0.192, ..., 0.11] + ↓ +ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” +│ Linear Layer 1 │ +│ (512 → 2048 expansion) │ +│ │ +│ H = X Ɨ W₁ │ +ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + │ + │ [0.21, -0.15, 0.28, ..., 0.19] (2048D) + ↓ +ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” +│ GELU Activation │ +│ (Non-linear transformation) │ +│ │ +│ H' = GELU(H) │ +ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + │ + │ [0.115, -0.058, 0.168, ..., 0.109] (2048D) + ↓ +ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” +│ Linear Layer 2 │ +│ (2048 → 512 compression) │ +│ │ +│ O = H' Ɨ Wā‚‚ │ +ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + │ + │ [0.18, -0.12, 0.24, ..., 0.14] (512D) + ↓ + Output Vector (512D) +``` + +### Dimension Flow + +``` +512 ──→ [Expand] ──→ 2048 ──→ [Transform] ──→ 2048 ──→ [Compress] ──→ 512 +``` + +**Like a funnel:** Expand → Transform → Compress + +--- + +## 3.8 Why Expand and Compress? + +### The Expansion-Compression Strategy + +**Why not stay at 512 dimensions?** + +**Answer:** Expansion provides "working space" + +**Analogy:** +- Think of doing math on paper +- Small paper (512D) = limited space +- Large paper (2048D) = room to work +- Then copy results back to small paper (512D) + +**Benefits:** +1. **More capacity:** 2048 dimensions = more parameters to learn +2. **Better transformations:** More space = more complex functions +3. **Feature refinement:** Transformation happens in expanded space + +**Why compress back?** + +**Answer:** Maintain consistent size throughout the model + +- All layers use 512 dimensions +- Consistent size enables stacking layers +- Easier to manage and optimize + +--- + +## 3.9 Key Takeaways: Feed-Forward + +āœ… **FFN transforms features through expansion and compression** +āœ… **Expands to larger dimension for processing** +āœ… **Applies non-linear transformation (GELU)** +āœ… **Compresses back to original dimension** +āœ… **Processes each position independently** + +--- + +*This document provides a step-by-step explanation of feed-forward networks, the component that transforms and refines features in transformer models.* + diff --git a/docs/GENERATION_EXPLAINED.md b/docs/GENERATION_EXPLAINED.md new file mode 100644 index 0000000..f78c02c --- /dev/null +++ b/docs/GENERATION_EXPLAINED.md @@ -0,0 +1,748 @@ +# What is Generation? Step-by-Step Explanation + +Complete step-by-step explanation of text generation: how models generate text using autoregressive generation, sampling, and decoding strategies. + +## Table of Contents + +1. [What is Generation?](#91-what-is-generation) +2. [Autoregressive Generation](#92-autoregressive-generation) +3. [Sampling Strategies](#93-sampling-strategies) +4. [Temperature](#94-temperature) +5. [Top-k Sampling](#95-top-k-sampling) +6. [Top-p (Nucleus) Sampling](#96-top-p-nucleus-sampling) +7. [Step-by-Step Generation Process](#97-step-by-step-generation-process) +8. [Exercise: Complete Generation Example](#98-exercise-complete-generation-example) +9. [Key Takeaways](#99-key-takeaways) + +--- + +## 9.1 What is Generation? + +### Simple Definition + +**Generation** (text generation) is the process of using a trained model to produce new text, one token at a time, based on a given prompt. + +### Visual Analogy + +**Think of generation like writing a story:** + +``` +Prompt: "Once upon a time" + +Model generates: + "Once upon a time" → "there" + "Once upon a time there" → "was" + "Once upon a time there was" → "a" + "Once upon a time there was a" → "princess" + ... + +Final: "Once upon a time there was a princess..." +``` + +**Model predicts next word, one at a time!** + +### What Generation Does + +**Generation:** +1. **Takes** a prompt (starting text) +2. **Predicts** next token probabilities +3. **Samples** a token from distribution +4. **Appends** token to sequence +5. **Repeats** until complete + +**Result:** Generated text continuation! + +--- + +## 9.2 Autoregressive Generation + +### What is Autoregressive? + +**Autoregressive** means the model uses its own previous outputs as inputs for the next prediction. + +### How It Works + +**Step 1: Initial Prompt** +``` +Prompt: "Hello" +Sequence: ["Hello"] +``` + +**Step 2: First Prediction** +``` +Input: ["Hello"] +Model output: Probabilities for next token + "World": 0.4 + "there": 0.3 + "friend": 0.2 + ... +``` + +**Step 3: Sample Token** +``` +Sample: "World" (selected) +Sequence: ["Hello", "World"] +``` + +**Step 4: Second Prediction** +``` +Input: ["Hello", "World"] +Model output: Probabilities for next token + "!": 0.5 + ".": 0.3 + ",": 0.1 + ... +``` + +**Step 5: Continue** +``` +Sample: "!" +Sequence: ["Hello", "World", "!"] +Continue until max length or stop token... +``` + +### Mathematical Formulation + +**For prompt $\mathbf{P} = [p_1, ..., p_k]$:** + +**Initialization:** +```math +\mathbf{T}_0 = \mathbf{P} +``` + +**For each step $t \geq k+1$:** + +1. **Forward pass:** + ```math + \mathbf{L}_t = \text{Model}(\mathbf{T}_{t-1}) + ``` + +2. **Get next token probabilities:** + ```math + \mathbf{p}_t = \text{softmax}(\mathbf{L}_t[:, -1, :]) + ``` + +3. **Sample token:** + ```math + t_t \sim \text{Categorical}(\mathbf{p}_t) + ``` + +4. **Append token:** + ```math + \mathbf{T}_t = [\mathbf{T}_{t-1}, t_t] + ``` + +**Repeat until stop condition!** + +--- + +## 9.3 Sampling Strategies + +### Deterministic vs Stochastic + +**Deterministic (Greedy):** +``` +Always pick highest probability: + "World": 0.4 ← Highest + "there": 0.3 + "friend": 0.2 + +→ Always picks "World" +→ Same output every time +``` + +**Stochastic (Sampling):** +``` +Sample from distribution: + "World": 0.4 (40% chance) + "there": 0.3 (30% chance) + "friend": 0.2 (20% chance) + +→ Different output each time +→ More diverse generations +``` + +### Why Sampling? + +**Greedy (Deterministic):** +- Same output every time +- Can be repetitive +- Less creative + +**Sampling:** +- Different outputs each time +- More diverse +- More creative +- Better for creative tasks + +--- + +## 9.4 Temperature + +### What is Temperature? + +**Temperature** controls the randomness of sampling by scaling the logits before applying softmax. + +### Formula + +```math +\mathbf{p}_t = \text{softmax}\left(\frac{\mathbf{l}_t}{T}\right) +``` + +**Where:** +- $\mathbf{l}_t$ = logits (raw scores) +- $T$ = temperature +- $\mathbf{p}_t$ = probabilities + +### How Temperature Works + +**T = 0.5 (Low Temperature - More Deterministic):** +``` +Logits: [2.0, 1.0, 0.5] +After scaling: [4.0, 2.0, 1.0] +After softmax: [0.88, 0.11, 0.01] +→ Sharp distribution (one token dominates) +→ More deterministic +``` + +**T = 1.0 (Standard Temperature):** +``` +Logits: [2.0, 1.0, 0.5] +After scaling: [2.0, 1.0, 0.5] +After softmax: [0.66, 0.24, 0.10] +→ Moderate distribution +→ Balanced +``` + +**T = 2.0 (High Temperature - More Random):** +``` +Logits: [2.0, 1.0, 0.5] +After scaling: [1.0, 0.5, 0.25] +After softmax: [0.52, 0.31, 0.17] +→ Flat distribution (more uniform) +→ More random +``` + +### Visual Comparison + +``` +Probability + │ + 1.0│ T=0.5: ā— + │ + 0.8│ + │ + 0.6│ T=1.0: ā— + │ + 0.4│ + │ + 0.2│ T=2.0: ā— + │ + 0.0ā”œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ Token + "World" "there" "friend" +``` + +**Lower T = Sharper distribution = More deterministic** +**Higher T = Flatter distribution = More random** + +### When to Use Different Temperatures + +**Low Temperature (T < 1.0):** +- Factual tasks +- Reproducible outputs +- When you want consistent results + +**Standard Temperature (T = 1.0):** +- Default setting +- Balanced behavior +- Good for most tasks + +**High Temperature (T > 1.0):** +- Creative writing +- Diverse outputs +- When you want variety + +--- + +## 9.5 Top-k Sampling + +### What is Top-k? + +**Top-k sampling** limits the sampling to only the top k most likely tokens. + +### How It Works + +**Step 1: Get Probabilities** +``` +All tokens: + "World": 0.4 + "there": 0.3 + "friend": 0.2 + "hello": 0.05 + "cat": 0.03 + "dog": 0.02 + ... +``` + +**Step 2: Select Top-k (e.g., k=3)** +``` +Top 3: + "World": 0.4 + "there": 0.3 + "friend": 0.2 +``` + +**Step 3: Remove Others** +``` +Set others to 0: + "World": 0.4 + "there": 0.3 + "friend": 0.2 + "hello": 0.0 + "cat": 0.0 + "dog": 0.0 + ... +``` + +**Step 4: Renormalize** +``` +Sum = 0.4 + 0.3 + 0.2 = 0.9 +Renormalize: + "World": 0.4/0.9 = 0.44 + "there": 0.3/0.9 = 0.33 + "friend": 0.2/0.9 = 0.22 +``` + +**Step 5: Sample from Top-k** +``` +Sample from these 3 tokens only +``` + +### Mathematical Formulation + +**Given probabilities $\mathbf{p}_t$ and top-k:** + +```math +\mathbf{p}_t^{topk}[v] = \begin{cases} +\frac{\mathbf{p}_t[v]}{\sum_{u \in \text{top-k}} \mathbf{p}_t[u]} & \text{if } v \in \text{top-k} \\ +0 & \text{otherwise} +\end{cases} +``` + +### Why Top-k? + +**Benefits:** +- Removes low-probability tokens +- Focuses on likely candidates +- Reduces randomness from unlikely tokens +- Better quality generations + +**Example:** +``` +Without top-k: Might sample "xyz" (very unlikely) +With top-k=50: Only samples from top 50 tokens +→ Better quality! +``` + +--- + +## 9.6 Top-p (Nucleus) Sampling + +### What is Top-p? + +**Top-p (nucleus) sampling** keeps the smallest set of tokens whose cumulative probability is at least p. + +### How It Works + +**Step 1: Sort Probabilities** +``` +Sorted (descending): + "World": 0.4 + "there": 0.3 + "friend": 0.2 + "hello": 0.05 + "cat": 0.03 + "dog": 0.02 + ... +``` + +**Step 2: Compute Cumulative Probabilities** +``` +Cumulative: + "World": 0.4 + "there": 0.7 (0.4 + 0.3) + "friend": 0.9 (0.7 + 0.2) + "hello": 0.95 (0.9 + 0.05) + "cat": 0.98 (0.95 + 0.03) + ... +``` + +**Step 3: Find Nucleus (e.g., p=0.9)** +``` +Find smallest set where sum ≄ 0.9: + "World": 0.4 + "there": 0.3 + "friend": 0.2 + Cumulative: 0.9 āœ“ + +→ Keep these 3 tokens +``` + +**Step 4: Remove Others** +``` +Keep: + "World": 0.4 + "there": 0.3 + "friend": 0.2 + Others: 0.0 +``` + +**Step 5: Renormalize and Sample** +``` +Renormalize and sample +``` + +### Mathematical Formulation + +**Given probabilities $\mathbf{p}_t$ and top-p:** + +**Find smallest set S:** +```math +S = \arg\min \{ |S'| : \sum_{v \in S'} \mathbf{p}_t[v] \geq p \} +``` + +**Then:** +```math +\mathbf{p}_t^{topp}[v] = \begin{cases} +\frac{\mathbf{p}_t[v]}{\sum_{u \in S} \mathbf{p}_t[u]} & \text{if } v \in S \\ +0 & \text{otherwise} +\end{cases} +``` + +### Why Top-p? + +**Benefits:** +- Adapts to distribution shape +- Keeps relevant tokens dynamically +- Better than fixed k in some cases +- More flexible than top-k + +**Example:** +``` +Sharp distribution: Top-p=0.9 might keep 3 tokens +Flat distribution: Top-p=0.9 might keep 50 tokens +→ Adapts automatically! +``` + +--- + +## 9.7 Step-by-Step Generation Process + +### Complete Process + +**Given prompt: "Hello"** + +#### Step 1: Encode Prompt + +``` +Prompt: "Hello" +Token IDs: [72] +``` + +#### Step 2: Forward Pass + +``` +Input: [72] +Model processes through layers +Output: Logits for all tokens + Token 72: 5.2 + Token 87: 4.8 ← "World" + Token 101: 3.2 ← "there" + Token 108: 2.1 ← "friend" + ... +``` + +#### Step 3: Apply Temperature + +``` +Temperature: T = 1.0 +Scaled logits: Same as above +``` + +#### Step 4: Apply Top-k (Optional) + +``` +Top-k: k = 50 +Keep top 50 tokens, remove others +``` + +#### Step 5: Apply Top-p (Optional) + +``` +Top-p: p = 0.95 +Keep tokens with cumulative prob ≄ 0.95 +``` + +#### Step 6: Compute Probabilities + +``` +Apply softmax: + "World": 0.4 + "there": 0.3 + "friend": 0.2 + ... +``` + +#### Step 7: Sample Token + +``` +Sample from distribution: +Selected: "World" (token 87) +``` + +#### Step 8: Append Token + +``` +Sequence: [72, 87] +Text: "Hello World" +``` + +#### Step 9: Repeat + +``` +Input: [72, 87] +→ Predict next token +→ Sample +→ Append +→ Repeat... +``` + +--- + +## 9.8 Exercise: Complete Generation Example + +### Problem + +**Given:** +- Prompt: "The" +- Model logits for next token: `[10.0, 8.0, 5.0, 2.0, 1.0, 0.5, ...]` (for tokens: "cat", "dog", "car", "house", "tree", "book", ...) +- Temperature: T = 1.0 +- Top-k: k = 3 +- Top-p: p = 0.9 + +**Generate the next token step-by-step.** + +### Step-by-Step Solution + +#### Step 1: Initial Setup + +**Prompt:** +``` +"The" +Token IDs: [32] (assuming "The" = token 32) +``` + +**Logits:** +``` +Token "cat": 10.0 +Token "dog": 8.0 +Token "car": 5.0 +Token "house": 2.0 +Token "tree": 1.0 +Token "book": 0.5 +... +``` + +#### Step 2: Apply Temperature + +**Temperature: T = 1.0** + +**Scaled logits (divide by T):** +``` +Token "cat": 10.0 / 1.0 = 10.0 +Token "dog": 8.0 / 1.0 = 8.0 +Token "car": 5.0 / 1.0 = 5.0 +Token "house": 2.0 / 1.0 = 2.0 +Token "tree": 1.0 / 1.0 = 1.0 +Token "book": 0.5 / 1.0 = 0.5 +``` + +**No change (T=1.0 is identity)** + +#### Step 3: Apply Top-k Filtering + +**Top-k: k = 3** + +**Select top 3 tokens:** +``` +Top 3: + "cat": 10.0 + "dog": 8.0 + "car": 5.0 +``` + +**Set others to -āˆž:** +``` +Token "cat": 10.0 +Token "dog": 8.0 +Token "car": 5.0 +Token "house": -āˆž +Token "tree": -āˆž +Token "book": -āˆž +``` + +#### Step 4: Apply Top-p Filtering + +**First, compute probabilities from top-k tokens:** + +**Apply softmax:** +``` +exp(10.0) = 22026.47 +exp(8.0) = 2980.96 +exp(5.0) = 148.41 +Sum = 25155.84 + +P("cat") = 22026.47 / 25155.84 ā‰ˆ 0.875 +P("dog") = 2980.96 / 25155.84 ā‰ˆ 0.119 +P("car") = 148.41 / 25155.84 ā‰ˆ 0.006 +``` + +**Cumulative probabilities:** +``` +"cat": 0.875 +"dog": 0.994 (0.875 + 0.119) +"car": 1.000 (0.994 + 0.006) +``` + +**Find smallest set where sum ≄ 0.9:** +``` +"cat": 0.875 < 0.9 +"cat" + "dog": 0.994 ≄ 0.9 āœ“ + +→ Keep "cat" and "dog" +→ Remove "car" +``` + +**Result:** +``` +Token "cat": 10.0 +Token "dog": 8.0 +Token "car": -āˆž (removed) +``` + +#### Step 5: Compute Final Probabilities + +**Apply softmax to remaining tokens:** +``` +exp(10.0) = 22026.47 +exp(8.0) = 2980.96 +Sum = 25007.43 + +P("cat") = 22026.47 / 25007.43 ā‰ˆ 0.881 +P("dog") = 2980.96 / 25007.43 ā‰ˆ 0.119 +``` + +#### Step 6: Sample Token + +**Sample from distribution:** +``` +Random number: 0.75 + +Cumulative: + "cat": 0.881 ← 0.75 falls here + "dog": 1.000 + +→ Selected: "cat" +``` + +### Answer + +**Generated token: "cat"** + +**Final sequence:** +``` +Prompt: "The" +Generated: "cat" +Full text: "The cat" +``` + +### Summary + +| Step | Operation | Result | +|------|-----------|--------| +| 1 | Initial logits | [10.0, 8.0, 5.0, 2.0, ...] | +| 2 | Apply temperature (T=1.0) | [10.0, 8.0, 5.0, 2.0, ...] | +| 3 | Top-k filtering (k=3) | Keep top 3: [10.0, 8.0, 5.0] | +| 4 | Top-p filtering (p=0.9) | Keep cumulative ≄0.9: [10.0, 8.0] | +| 5 | Compute probabilities | [0.881, 0.119] | +| 6 | Sample | "cat" selected | + +**The model generated "cat" following "The"!** + +--- + +## 9.9 Key Takeaways + +### Generation + +āœ… **Generation produces text one token at a time** +āœ… **Autoregressive: uses previous outputs as inputs** +āœ… **Iterative process: predict → sample → append → repeat** + +### Sampling Strategies + +āœ… **Temperature: Controls randomness (lower = deterministic, higher = random)** +āœ… **Top-k: Limits to top k tokens** +āœ… **Top-p: Keeps smallest set with cumulative probability ≄ p** +āœ… **Combined: Often use temperature + top-k or top-p** + +### Why Important + +āœ… **Enables text generation from trained models** +āœ… **Different strategies produce different outputs** +āœ… **Essential for language model deployment** + +--- + +## Mathematical Summary + +### Generation Process + +**Initialization:** +```math +\mathbf{T}_0 = \mathbf{P} +``` + +**For each step $t$:** +```math +\mathbf{l}_t = \text{Model}(\mathbf{T}_{t-1})[:, -1, :] +``` + +```math +\mathbf{l}_t' = \frac{\mathbf{l}_t}{T} \quad \text{(temperature)} +``` + +```math +\mathbf{l}_t'' = \text{Top-k}(\mathbf{l}_t') \quad \text{(optional)} +``` + +```math +\mathbf{l}_t''' = \text{Top-p}(\mathbf{l}_t'') \quad \text{(optional)} +``` + +```math +\mathbf{p}_t = \text{softmax}(\mathbf{l}_t''') +``` + +```math +t_t \sim \text{Categorical}(\mathbf{p}_t) +``` + +```math +\mathbf{T}_t = [\mathbf{T}_{t-1}, t_t] +``` + +--- + +*This document provides a comprehensive explanation of text generation, including autoregressive generation, sampling strategies, temperature, top-k, and top-p with mathematical formulations and solved exercises.* + diff --git a/docs/MATHEMATICS.md b/docs/MATHEMATICS.md new file mode 100644 index 0000000..431ff2d --- /dev/null +++ b/docs/MATHEMATICS.md @@ -0,0 +1,1087 @@ +# SheepOp LLM - Complete Mathematical Formulation + +Complete mathematical derivation and step-by-step solutions for every component of the SheepOp Language Model. + +## Table of Contents + +1. [Data Processing and Tokenization](#1-data-processing-and-tokenization) +2. [Token Embedding](#2-token-embedding) +3. [Positional Encoding](#3-positional-encoding) +4. [Multi-Head Self-Attention](#4-multi-head-self-attention) +5. [Feed-Forward Network](#5-feed-forward-network) +6. [Layer Normalization](#6-layer-normalization) +7. [Transformer Block](#7-transformer-block) +8. [Complete Forward Pass](#8-complete-forward-pass) +9. [Loss Computation](#9-loss-computation) +10. [Backpropagation](#10-backpropagation) +11. [AdamW Optimizer Update](#11-adamw-optimizer-update) +12. [Learning Rate Scheduling](#12-learning-rate-scheduling) +13. [Text Generation](#13-text-generation) + +--- + +## 1. Data Processing and Tokenization + +### 1.1 Text Extraction + +Given a text file with lines, we extract text samples: + +**Input:** Raw text files, PDFs, images, code files +**Output:** List of text strings $S = \{s_1, s_2, \ldots, s_N\}$ where each $s_i$ is a text line + +**Example:** +``` +Input: "Hello world\nMachine learning is cool." +Output: S = ["Hello world", "Machine learning is cool."] +``` + +### 1.2 Character-Level Tokenization + +**Vocabulary Construction:** + +For character-level tokenization, we create a vocabulary $V$ mapping characters to token IDs: + +```math +V = \{(\text{}, 0), (\text{}, 1), (\text{}, 2), (\text{}, 3), (\text{space}, 4), (\text{!}, 5), \ldots, (\text{z}, 129)\} +``` + +Or more formally: + +```math +V: \mathcal{C} \rightarrow \mathbb{N}, \quad V(c) = \begin{cases} +0 & \text{if } c = \text{} \\ +1 & \text{if } c = \text{} \\ +2 & \text{if } c = \text{} \\ +3 & \text{if } c = \text{} \\ +4 & \text{if } c = \text{space} \\ +\vdots & \\ +129 & \text{if } c = \text{z} +\end{cases} +``` + +where $\mathcal{C}$ is the set of all characters in the vocabulary. + +**Encoding Function:** + +For a text string $s = c_1 c_2 \ldots c_n$ where $c_i$ are characters: + +```math +\text{encode}(s) = [V[c_1], V[c_2], \ldots, V[c_n]] +``` + +**Example:** +``` +Input: "Hi" +s = ['H', 'i'] +V = {'H': 72, 'i': 105} # ASCII values +encode("Hi") = [72, 105] +``` + +**Decoding Function:** + +```math +\text{decode}([t_1, t_2, ..., t_n]) = V^{-1}[t_1] \cdot V^{-1}[t_2] \cdot \ldots \cdot V^{-1}[t_n] +``` + +where $V^{-1}$ is the inverse mapping from token IDs to characters. + +### 1.3 Sequence Chunking + +For a token sequence $T = [t_1, t_2, ..., t_L]$ and maximum length $M$: + +**Chunking:** + +```math +\text{chunks} = \{[t_{i\cdot S}, t_{i\cdot S+1}, ..., t_{\min(i\cdot S+M, L)}] : i \in \{0, 1, ..., \lfloor\frac{L-M}{S}\rfloor\}\} +``` + +where $S$ is the stride (default $S = M$). + +**Padding:** + +For a chunk $C$ with length $|C| < M$: + +```math +\text{padded}(C) = C \oplus [\text{pad\_token}]^{(M - |C|)} +``` + +**Example:** +``` +M = 5, S = 5 +T = [72, 105, 44, 32, 119, 111, 114, 108, 100] +Chunk 1: [72, 105, 44, 32, 119] +Chunk 2: [111, 114, 108, 100, ] +``` + +--- + +## 2. Token Embedding + +### 2.1 Embedding Matrix + +We have an embedding matrix $E \in \mathbb{R}^{V \times d}$ where: +- $V$ = vocabulary size +- $d$ = embedding dimension (d_model) + +### 2.2 Embedding Lookup + +For input token IDs $\mathbf{t} = [t_1, t_2, ..., t_n]$: + +```math +\mathbf{X} = E[\mathbf{t}] = \begin{bmatrix} E[t_1] \\ E[t_2] \\ \vdots \\ E[t_n] \end{bmatrix} \in \mathbb{R}^{n \times d} +``` + +**Example:** +``` +V = 128, d = 512 +t = [72, 105] +E[72] = [0.1, -0.2, ..., 0.05] (512-dim vector) +E[105] = [-0.1, 0.3, ..., 0.02] (512-dim vector) + +X = [[0.1, -0.2, ..., 0.05], + [-0.1, 0.3, ..., 0.02]] +``` + +**Batch Processing:** + +For batch size $B$: + +```math +\mathbf{X} = E[\mathbf{T}] \in \mathbb{R}^{B \times n \times d} +``` + +where $\mathbf{T} \in \mathbb{N}^{B \times n}$ is the batch of token IDs. + +--- + +## 3. Positional Encoding + +### 3.1 Sinusoidal Positional Encoding + +For position $pos$ and dimension $i$: + +```math +PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d}}\right) +``` + +```math +PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d}}\right) +``` + +**Origin of the 10000 Constant:** + +The constant $10000$ is a **hyperparameter** introduced in the original "Attention Is All You Need" paper (Vaswani et al., 2017). This value controls the **frequency** (or wavelength) of the sinusoidal functions used for positional encoding. + +**What 10000 Controls:** + +The term $10000^{2i/d}$ creates a **geometric progression** of frequencies across different dimensions: + +- **Lower dimensions** (small $i$): Higher frequencies (faster oscillation) +- **Higher dimensions** (large $i$): Lower frequencies (slower oscillation) + +**Mathematical Interpretation:** + +The wavelength $\lambda_i$ for dimension pair $(2i, 2i+1)$ is: + +```math +\lambda_i = 2\pi \cdot 10000^{2i/d} +``` + +This means: +- When $i = 0$: $\lambda_0 = 2\pi \cdot 10000^{0} = 2\pi \approx 6.28$ (short wavelength) +- When $i = d/2 - 1$: $\lambda_{d/2-1} = 2\pi \cdot 10000^{(d-2)/d} \approx 2\pi \cdot 10000$ (long wavelength) + +**Why 10000?** + +1. **Scale Balance**: It provides a good balance between: + - Being large enough to create distinguishable patterns across positions + - Being small enough to prevent numerical issues + +2. **Empirical Choice**: The authors found this value works well for typical sequence lengths (up to ~5000 tokens) + +3. **Frequency Range**: For $d = 512$: + - Lowest frequency: $\frac{1}{10000^{512/512}} = \frac{1}{10000} = 0.0001$ cycles per position + - Highest frequency: $\frac{1}{10000^{0/512}} = 1$ cycle per position + - This covers a wide range allowing the model to capture both local and long-range positional patterns + +**What Happens if We Change It?** + +- **Smaller values** (e.g., 100): Higher frequencies overall → better for short sequences, but may cause aliasing for long sequences +- **Larger values** (e.g., 100000): Lower frequencies overall → better for very long sequences, but may lose fine-grained positional information +- **Different values** are sometimes used: Some models use 10000, others use 5000 or 20000 depending on their typical sequence lengths + +**Example Frequency Analysis:** + +For $d = 512$: + +``` +i = 0: 10000^(0/512) = 1.0 → wavelength ā‰ˆ 6.28 positions +i = 64: 10000^(128/512) = 10 → wavelength ā‰ˆ 62.8 positions +i = 128: 10000^(256/512) = 100 → wavelength ā‰ˆ 628 positions +i = 256: 10000^(512/512) = 10000 → wavelength ā‰ˆ 62,832 positions +``` + +This creates a **multi-scale representation** where different dimensions encode positional information at different resolutions. + +**Simplified Form:** + +```math +PE_{(pos, 2i)} = \sin\left(pos \cdot \exp\left(-\frac{2i \log(10000)}{d}\right)\right) +``` + +```math +PE_{(pos, 2i+1)} = \cos\left(pos \cdot \exp\left(-\frac{2i \log(10000)}{d}\right)\right) +``` + +### 3.2 Positional Encoding Matrix + +For sequence length $n$ and model dimension $d$: + +```math +PE = \begin{bmatrix} +PE_{(0,0)} & PE_{(0,1)} & \cdots & PE_{(0,d-1)} \\ +PE_{(1,0)} & PE_{(1,1)} & \cdots & PE_{(1,d-1)} \\ +\vdots & \vdots & \ddots & \vdots \\ +PE_{(n-1,0)} & PE_{(n-1,1)} & \cdots & PE_{(n-1,d-1)} +\end{bmatrix} \in \mathbb{R}^{n \times d} +``` + +### 3.3 Adding Positional Encoding + +```math +\mathbf{X}' = \mathbf{X} + PE +``` + +**Example Calculation:** + +``` +d = 512, pos = 0, i = 0: +PE(0,0) = sin(0 / 10000^(0/512)) = sin(0) = 0 +PE(0,1) = cos(0 / 10000^(0/512)) = cos(0) = 1 + +pos = 0, i = 1: +PE(0,2) = sin(0 / 10000^(2/512)) = sin(0) = 0 +PE(0,3) = cos(0 / 10000^(2/512)) = cos(0) = 1 + +pos = 1, i = 0: +PE(1,0) = sin(1 / 10000^(0/512)) = sin(1) ā‰ˆ 0.8415 +PE(1,1) = cos(1 / 10000^(0/512)) = cos(1) ā‰ˆ 0.5403 +``` + +### 3.4 Dropout Application + +```math +\mathbf{X}'' = \text{Dropout}(\mathbf{X}', p) +``` + +where $p$ is the dropout probability (typically 0.1). + +--- + +## 4. Multi-Head Self-Attention + +### 4.1 Query, Key, Value Projections + +For input $\mathbf{X} \in \mathbb{R}^{B \times n \times d}$: + +```math +\mathbf{Q} = \mathbf{X} W_Q, \quad \mathbf{K} = \mathbf{X} W_K, \quad \mathbf{V} = \mathbf{X} W_V +``` + +where: +- $W_Q, W_K, W_V \in \mathbb{R}^{d \times d}$ are learnable weight matrices +- $\mathbf{Q}, \mathbf{K}, \mathbf{V} \in \mathbb{R}^{B \times n \times d}$ + +**Example:** +``` +B = 2, n = 5, d = 512 +X shape: [2, 5, 512] +W_Q shape: [512, 512] +Q = X @ W_Q → [2, 5, 512] +``` + +### 4.2 Multi-Head Splitting + +For $h$ heads: + +```math +d_k = \frac{d}{h} +``` + +```math +\mathbf{Q}_i = \mathbf{Q}[:, :, i \cdot d_k : (i+1) \cdot d_k] \in \mathbb{R}^{B \times n \times d_k} +``` + +```math +\mathbf{K}_i = \mathbf{K}[:, :, i \cdot d_k : (i+1) \cdot d_k] \in \mathbb{R}^{B \times n \times d_k} +``` + +```math +\mathbf{V}_i = \mathbf{V}[:, :, i \cdot d_k : (i+1) \cdot d_k] \in \mathbb{R}^{B \times n \times d_k} +``` + +**Reshaping:** + +```math +\mathbf{Q}_i \in \mathbb{R}^{B \times h \times n \times d_k} +``` + +**Example:** +``` +d = 512, h = 8, d_k = 64 +Q shape: [2, 5, 512] +After reshape: [2, 8, 5, 64] +``` + +### 4.3 Scaled Dot-Product Attention + +**Attention Scores:** + +```math +\mathbf{S} = \frac{\mathbf{Q}_i \mathbf{K}_i^T}{\sqrt{d_k}} \in \mathbb{R}^{B \times h \times n \times n} +``` + +**Example Calculation:** + +``` +For head i, one example: +Q_i[0,0] = [0.1, -0.2, 0.3, ..., 0.05] (64-dim) +K_i[0,0] = [0.2, 0.1, -0.1, ..., 0.1] (64-dim) + +Dot product: Q_i[0,0] Ā· K_i[0,0] = 0.1Ɨ0.2 + (-0.2)Ɨ0.1 + ... = 0.15 +Scale: 0.15 / √64 = 0.15 / 8 = 0.01875 + +Score matrix S[i,j] = Q_i[i] Ā· K_i[j] / √d_k +``` + +### 4.4 Causal Masking + +For causal (autoregressive) attention: + +```math +M_{causal} = \begin{bmatrix} +1 & -\infty & -\infty & \cdots \\ +1 & 1 & -\infty & \cdots \\ +1 & 1 & 1 & \cdots \\ +\vdots & \vdots & \vdots & \ddots +\end{bmatrix} +``` + +```math +\mathbf{S}_{masked} = \mathbf{S} + M_{causal} +``` + +**Example:** +``` +n = 3 +M_causal = [[0, -inf, -inf], + [0, 0, -inf], + [0, 0, 0]] + +S = [[0.2, 0.1, 0.3], + [0.1, 0.4, 0.2], + [0.3, 0.2, 0.5]] + +S_masked = [[0.2, -inf, -inf], + [0.1, 0.4, -inf], + [0.3, 0.2, 0.5]] +``` + +### 4.5 Softmax Normalization + +```math +\mathbf{A} = \text{softmax}(\mathbf{S}_{masked}) = \frac{\exp(\mathbf{S}_{masked})}{\sum_{j=1}^{n} \exp(\mathbf{S}_{masked}[i,j])} +``` + +**Element-wise:** + +```math +A_{ij} = \frac{\exp(S_{masked,ij})}{\sum_{k=1}^{n} \exp(S_{masked,ik})} +``` + +**Example:** +``` +S_masked = [[0.2, -inf, -inf], + [0.1, 0.4, -inf], + [0.3, 0.2, 0.5]] + +For row 0: +exp(0.2) = 1.221, exp(-inf) = 0, exp(-inf) = 0 +sum = 1.221 +A[0,0] = 1.221/1.221 = 1.0 +A[0,1] = 0/1.221 = 0 +A[0,2] = 0/1.221 = 0 + +For row 1: +exp(0.1) = 1.105, exp(0.4) = 1.492, exp(-inf) = 0 +sum = 2.597 +A[1,0] = 1.105/2.597 ā‰ˆ 0.426 +A[1,1] = 1.492/2.597 ā‰ˆ 0.574 +A[1,2] = 0/2.597 = 0 + +A = [[1.0, 0.0, 0.0], + [0.426, 0.574, 0.0], + [0.268, 0.263, 0.469]] +``` + +### 4.6 Attention Application + +```math +\mathbf{O}_i = \mathbf{A}_i \mathbf{V}_i \in \mathbb{R}^{B \times h \times n \times d_k} +``` + +**Example:** +``` +A[0] = [1.0, 0.0, 0.0] +V[0] = [[0.1, 0.2, ...], + [0.3, 0.4, ...], + [0.5, 0.6, ...]] + +O[0] = 1.0Ɨ[0.1,0.2,...] + 0.0Ɨ[0.3,0.4,...] + 0.0Ɨ[0.5,0.6,...] + = [0.1, 0.2, ...] +``` + +### 4.7 Concatenation and Output Projection + +**Concatenate heads:** + +```math +\mathbf{O} = \text{Concat}(\mathbf{O}_1, \mathbf{O}_2, ..., \mathbf{O}_h) \in \mathbb{R}^{B \times n \times d} +``` + +**Output projection:** + +```math +\text{Attention}(\mathbf{X}) = \mathbf{O} W_O \in \mathbb{R}^{B \times n \times d} +``` + +where $W_O \in \mathbb{R}^{d \times d}$ is the output projection weight matrix. + +--- + +## 5. Feed-Forward Network + +### 5.1 Feed-Forward Computation + +```math +\text{FFN}(\mathbf{X}) = \text{ReLU}(\mathbf{X} W_1 + \mathbf{b}_1) W_2 + \mathbf{b}_2 +``` + +Using GELU activation (default): + +```math +\text{GELU}(x) = x \cdot \Phi(x) = x \cdot \frac{1}{2}\left(1 + \text{erf}\left(\frac{x}{\sqrt{2}}\right)\right) +``` + +where $\Phi(x)$ is the standard normal CDF. + +**Approximation:** + +```math +\text{GELU}(x) \approx 0.5x\left(1 + \tanh\left(\sqrt{\frac{2}{\pi}}\left(x + 0.044715x^3\right)\right)\right) +``` + +**Complete FFN:** + +```math +\mathbf{H} = \mathbf{X} W_1 \in \mathbb{R}^{B \times n \times d_{ff}} +``` + +```math +\mathbf{H}' = \text{GELU}(\mathbf{H}) \in \mathbb{R}^{B \times n \times d_{ff}} +``` + +```math +\mathbf{H}'' = \text{Dropout}(\mathbf{H}', p) +``` + +```math +\text{FFN}(\mathbf{X}) = \mathbf{H}'' W_2 \in \mathbb{R}^{B \times n \times d} +``` + +**Example:** +``` +d = 512, d_ff = 2048 +X shape: [2, 5, 512] +W1 shape: [512, 2048] +H = X @ W1 → [2, 5, 2048] +H' = GELU(H) → [2, 5, 2048] +H'' = Dropout(H', 0.1) → [2, 5, 2048] +W2 shape: [2048, 512] +FFN(X) = H'' @ W2 → [2, 5, 512] +``` + +--- + +## 6. Layer Normalization + +### 6.1 Layer Normalization Formula + +For input $\mathbf{x} \in \mathbb{R}^d$: + +```math +\mu = \frac{1}{d} \sum_{i=1}^{d} x_i +``` + +```math +\sigma^2 = \frac{1}{d} \sum_{i=1}^{d} (x_i - \mu)^2 +``` + +```math +\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} +``` + +```math +\text{LayerNorm}(\mathbf{x}) = \gamma \odot \hat{\mathbf{x}} + \beta +``` + +where: +- $\epsilon$ = small constant (default 1e-5) +- $\gamma$ = learnable scale parameter +- $\beta$ = learnable shift parameter +- $\odot$ = element-wise multiplication + +**Example:** +``` +x = [1.0, 2.0, 3.0, 4.0] +d = 4 +μ = (1.0 + 2.0 + 3.0 + 4.0) / 4 = 2.5 +σ² = ((1-2.5)² + (2-2.5)² + (3-2.5)² + (4-2.5)²) / 4 + = (2.25 + 0.25 + 0.25 + 2.25) / 4 = 1.25 +σ = √1.25 ā‰ˆ 1.118 + +ε = 1e-5 +xĢ‚ = [(1-2.5)/(1.118+1e-5), (2-2.5)/(1.118+1e-5), ...] + = [-1.341, -0.447, 0.447, 1.341] + +γ = [1.0, 1.0, 1.0, 1.0] (initialized) +β = [0.0, 0.0, 0.0, 0.0] (initialized) +LayerNorm(x) = γ āŠ™ xĢ‚ + β = xĢ‚ +``` + +--- + +## 7. Transformer Block + +### 7.1 Pre-Norm Architecture + +**Self-Attention Block:** + +```math +\mathbf{X}_1 = \mathbf{X} + \text{Dropout}(\text{Attention}(\text{LayerNorm}(\mathbf{X})), p) +``` + +**Feed-Forward Block:** + +```math +\mathbf{X}_2 = \mathbf{X}_1 + \text{Dropout}(\text{FFN}(\text{LayerNorm}(\mathbf{X}_1)), p) +``` + +**Complete Transformer Block:** + +```math +\mathbf{X}_{out} = \text{TransformerBlock}(\mathbf{X}_{in}) +``` + +**Step-by-step:** + +```math +1. \mathbf{X}_{norm1} = \text{LayerNorm}(\mathbf{X}_{in}) +2. \mathbf{X}_{attn} = \text{Attention}(\mathbf{X}_{norm1}) +3. \mathbf{X}_{attn\_drop} = \text{Dropout}(\mathbf{X}_{attn}, p) +4. \mathbf{X}_1 = \mathbf{X}_{in} + \mathbf{X}_{attn\_drop}$ (residual connection) +5. \mathbf{X}_{norm2} = \text{LayerNorm}(\mathbf{X}_1) +6. \mathbf{X}_{ffn} = \text{FFN}(\mathbf{X}_{norm2}) +7. \mathbf{X}_{ffn\_drop} = \text{Dropout}(\mathbf{X}_{ffn}, p) +8. \mathbf{X}_{out} = \mathbf{X}_1 + \mathbf{X}_{ffn\_drop}$ (residual connection) +``` + +--- + +## 8. Complete Forward Pass + +### 8.1 Full Model Forward Pass + +Given input token IDs $\mathbf{T} \in \mathbb{N}^{B \times n}$: + +**Step 1: Token Embedding** +```math +\mathbf{X}_0 = E[\mathbf{T}] \in \mathbb{R}^{B \times n \times d} +``` + +**Step 2: Positional Encoding** +```math +\mathbf{X}_1 = \mathbf{X}_0 + PE \in \mathbb{R}^{B \times n \times d} +``` +```math +\mathbf{X}_2 = \text{Dropout}(\mathbf{X}_1, p) +``` + +**Step 3: Transformer Layers** + +For $L$ layers: + +```math +\mathbf{X}_{l+1} = \text{TransformerBlock}_l(\mathbf{X}_l), \quad l = 2, 3, ..., L+1 +``` + +**Step 4: Final Layer Norm** + +```math +\mathbf{X}_{final} = \text{LayerNorm}(\mathbf{X}_{L+1}) +``` + +**Step 5: Output Projection** + +```math +\mathbf{L} = \mathbf{X}_{final} W_{out} \in \mathbb{R}^{B \times n \times V} +``` + +where $W_{out} \in \mathbb{R}^{d \times V}$ is the output projection matrix. + +**Output logits:** + +```math +\text{logits}[b, t, v] = \text{log probability of token } v \text{ at position } t \text{ in batch } b +``` + +--- + +## 9. Loss Computation + +### 9.1 Cross-Entropy Loss + +For logits $\mathbf{L} \in \mathbb{R}^{B \times n \times V}$ and labels $\mathbf{Y} \in \mathbb{N}^{B \times n}$: + +**Reshape for loss:** + +```math +\mathbf{L}_{flat} = \mathbf{L}.view(B \cdot n, V) \in \mathbb{R}^{(B \cdot n) \times V} +``` + +```math +\mathbf{Y}_{flat} = \mathbf{Y}.view(B \cdot n) \in \mathbb{N}^{B \cdot n} +``` + +**Softmax probabilities:** + +```math +p_i = \frac{\exp(L_{flat}[i, y_i])}{\sum_{v=1}^{V} \exp(L_{flat}[i, v])} +``` + +**Cross-entropy loss:** + +```math +\mathcal{L} = -\frac{1}{N} \sum_{i=1}^{N} \log(p_i) +``` + +where $N$ is the number of valid (non-padding) tokens. + +**Masked loss (ignoring padding):** + +```math +\mathcal{L} = -\frac{1}{N} \sum_{i: y_i \neq \text{pad\_id}} \log(p_i) +``` + +**Example:** +``` +B = 2, n = 3, V = 128 +L shape: [2, 3, 128] +Y = [[72, 105, -100], [44, 32, 119]] (-100 is padding) + +L_flat shape: [6, 128] +Y_flat = [72, 105, -100, 44, 32, 119] + +For i=0 (y_i=72): + logits = L_flat[0] = [0.1, -0.2, ..., 0.5, ...] (128 values) + p_0 = exp(0.5) / sum(exp(logits)) ā‰ˆ 0.8 (assuming 0.5 was max) + log(p_0) = log(0.8) ā‰ˆ -0.223 + +For i=2 (y_i=-100): + Skip (padding token) + +Total loss = -1/5 * (log(p_0) + log(p_1) + log(p_3) + log(p_4) + log(p_5)) +``` + +### 9.2 Perplexity + +```math +\text{Perplexity} = \exp(\mathcal{L}) = \exp\left(-\frac{1}{N} \sum_{i=1}^{N} \log(p_i)\right) +``` + +**Example:** +``` +If L = 2.0, then Perplexity = exp(2.0) ā‰ˆ 7.39 +``` + +--- + +## 10. Backpropagation + +### 10.1 Gradient Flow + +**Loss gradient:** + +```math +\frac{\partial \mathcal{L}}{\partial \mathbf{L}_{flat}} = \frac{\partial}{\partial \mathbf{L}_{flat}} \left(-\frac{1}{N} \sum_{i=1}^{N} \log(p_i)\right) +``` + +**Chain rule through output projection:** + +```math +\frac{\partial \mathcal{L}}{\partial W_{out}} = \frac{\partial \mathcal{L}}{\partial \mathbf{L}} \cdot \frac{\partial \mathbf{L}}{\partial W_{out}} +``` + +```math +\frac{\partial \mathcal{L}}{\partial \mathbf{X}_{final}} = \frac{\partial \mathcal{L}}{\partial \mathbf{L}} \cdot W_{out}^T +``` + +**Through transformer layers (backward):** + +For layer $l$ from $L$ to $1$: + +```math +\frac{\partial \mathcal{L}}{\partial \mathbf{X}_l} = \frac{\partial \mathcal{L}}{\partial \mathbf{X}_{l+1}} \cdot \frac{\partial \mathbf{X}_{l+1}}{\partial \mathbf{X}_l} +``` + +**Residual connection gradient:** + +```math +\frac{\partial \mathcal{L}}{\partial \mathbf{X}_{in}} = \frac{\partial \mathcal{L}}{\partial \mathbf{X}_{out}} + \frac{\partial \mathcal{L}}{\partial \mathbf{X}_{residual}} +``` + +### 10.2 Attention Gradients + +**Attention weight gradients:** + +```math +\frac{\partial \mathcal{L}}{\partial \mathbf{A}} = \frac{\partial \mathcal{L}}{\partial \mathbf{O}} \cdot \mathbf{V}^T +``` + +**Query, Key, Value gradients:** + +```math +\frac{\partial \mathcal{L}}{\partial \mathbf{Q}} = \frac{\partial \mathcal{L}}{\partial \mathbf{S}} \cdot \mathbf{K} \cdot \frac{1}{\sqrt{d_k}} +``` + +```math +\frac{\partial \mathcal{L}}{\partial \mathbf{K}} = \frac{\partial \mathcal{L}}{\partial \mathbf{S}} \cdot \mathbf{Q}^T \cdot \frac{1}{\sqrt{d_k}} +``` + +```math +\frac{\partial \mathcal{L}}{\partial \mathbf{V}} = \mathbf{A}^T \cdot \frac{\partial \mathcal{L}}{\partial \mathbf{O}} +``` + +### 10.3 Gradient Clipping + +**Gradient norm:** + +```math +||\mathbf{g}|| = \sqrt{\sum_{i} g_i^2} +``` + +**Clipped gradient:** + +```math +\mathbf{g}_{clipped} = \begin{cases} +\mathbf{g} & \text{if } ||\mathbf{g}|| \leq \theta \\ +\mathbf{g} \cdot \frac{\theta}{||\mathbf{g}||} & \text{if } ||\mathbf{g}|| > \theta +\end{cases} +``` + +where $\theta$ is the max gradient norm (default 1.0). + +**Example:** +``` +g = [0.5, 0.8, 1.2] +||g|| = √(0.5² + 0.8² + 1.2²) = √(0.25 + 0.64 + 1.44) = √2.33 ā‰ˆ 1.526 +Īø = 1.0 +Since ||g|| > Īø: +g_clipped = g Ɨ (1.0 / 1.526) = [0.328, 0.524, 0.786] +``` + +--- + +## 11. AdamW Optimizer Update + +### 11.1 AdamW Algorithm + +For parameter $\theta_t$ at step $t$: + +**Momentum update:** + +```math +m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t +``` + +```math +v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 +``` + +**Bias correction:** + +```math +\hat{m}_t = \frac{m_t}{1 - \beta_1^t} +``` + +```math +\hat{v}_t = \frac{v_t}{1 - \beta_2^t} +``` + +**Parameter update:** + +```math +\theta_t = \theta_{t-1} - \eta_t \left(\frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda \theta_{t-1}\right) +``` + +where: +- $\beta_1 = 0.9$ (momentum decay) +- $\beta_2 = 0.999$ (variance decay) +- $\eta_t$ = learning rate at step $t$ +- $\lambda$ = weight decay coefficient (default 0.01) +- $\epsilon = 10^{-8}$ (numerical stability) + +### 11.2 Step-by-Step Example + +**Initialization:** +``` +t = 0 +Īøā‚€ = 0.5 (initial parameter value) +mā‚€ = 0, vā‚€ = 0 +β₁ = 0.9, β₂ = 0.999 +Ī· = 0.001 (learning rate) +Ī» = 0.01 (weight decay) +ε = 1e-8 +``` + +**Step 1:** +``` +t = 1 +g₁ = 0.3 (gradient) + +m₁ = 0.9 Ɨ 0 + 0.1 Ɨ 0.3 = 0.03 +v₁ = 0.999 Ɨ 0 + 0.001 Ɨ 0.3² = 0.001 Ɨ 0.09 = 0.00009 + +m̂₁ = 0.03 / (1 - 0.9¹) = 0.03 / 0.1 = 0.3 +v̂₁ = 0.00009 / (1 - 0.999¹) = 0.00009 / 0.001 = 0.09 + +θ₁ = 0.5 - 0.001 Ɨ (0.3 / (√0.09 + 1e-8) + 0.01 Ɨ 0.5) + = 0.5 - 0.001 Ɨ (0.3 / 0.3 + 0.005) + = 0.5 - 0.001 Ɨ (1.005) + = 0.5 - 0.001005 + = 0.498995 +``` + +**Step 2:** +``` +t = 2 +gā‚‚ = -0.2 + +mā‚‚ = 0.9 Ɨ 0.03 + 0.1 Ɨ (-0.2) = 0.027 - 0.02 = 0.007 +vā‚‚ = 0.999 Ɨ 0.00009 + 0.001 Ɨ (-0.2)² = 0.00008991 + 0.00004 = 0.00012991 + +m̂₂ = 0.007 / (1 - 0.9²) = 0.007 / 0.19 = 0.0368 +v̂₂ = 0.00012991 / (1 - 0.999²) = 0.00012991 / 0.001999 ā‰ˆ 0.06496 + +Īøā‚‚ = 0.498995 - 0.001 Ɨ (0.0368 / (√0.06496 + 1e-8) + 0.01 Ɨ 0.498995) + = 0.498995 - 0.001 Ɨ (0.0368 / 0.2549 + 0.00498995) + = 0.498995 - 0.001 Ɨ (0.1444 + 0.00498995) + = 0.498995 - 0.001 Ɨ 0.1494 + = 0.498995 - 0.0001494 + = 0.498846 +``` + +### 11.3 AdamW vs Adam + +The key difference is the weight decay term: + +**Adam:** +```math +\theta_t = \theta_{t-1} - \eta_t \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} +``` + +Then separately apply weight decay: +```math +\theta_t = \theta_t (1 - \lambda) +``` + +**AdamW:** +```math +\theta_t = \theta_{t-1} - \eta_t \left(\frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda \theta_{t-1}\right) +``` + +AdamW decouples weight decay from gradient-based updates, leading to better generalization. + +--- + +## 12. Learning Rate Scheduling + +### 12.1 Cosine Annealing Schedule + +```math +\eta_t = \eta_{min} + (\eta_{max} - \eta_{min}) \cdot \frac{1 + \cos(\pi \cdot \frac{t}{T_{max}})}{2} +``` + +where: +- $\eta_{max}$ = initial learning rate +- $\eta_{min}$ = minimum learning rate (default 0) +- $T_{max}$ = total number of steps +- $t$ = current step + +**Example:** +``` +Ī·_max = 0.001 +Ī·_min = 0 +T_max = 10000 +t = 0: Ī·ā‚€ = 0 + (0.001 - 0) Ɨ (1 + cos(0)) / 2 = 0.001 Ɨ 1 = 0.001 +t = 2500: Ī· = 0 + 0.001 Ɨ (1 + cos(Ļ€/4)) / 2 = 0.001 Ɨ (1 + 0.707) / 2 ā‰ˆ 0.000854 +t = 5000: Ī· = 0 + 0.001 Ɨ (1 + cos(Ļ€/2)) / 2 = 0.001 Ɨ (1 + 0) / 2 = 0.0005 +t = 7500: Ī· = 0 + 0.001 Ɨ (1 + cos(3Ļ€/4)) / 2 ā‰ˆ 0.000146 +t = 10000: Ī· = 0 + 0.001 Ɨ (1 + cos(Ļ€)) / 2 = 0.001 Ɨ (1 + (-1)) / 2 = 0 +``` + +### 12.2 Learning Rate Schedule Visualization + +The cosine annealing schedule creates a smooth decay from maximum to minimum learning rate following a cosine curve. + +--- + +## 13. Text Generation + +### 13.1 Autoregressive Generation + +Given prompt tokens $\mathbf{P} = [p_1, p_2, ..., p_k]$: + +**Initialization:** +```math +\mathbf{T}_0 = \mathbf{P} +``` + +**For each generation step $t$ from $k+1$ to $k+n$:** + +1. **Forward pass:** +```math +\mathbf{L}_t = \text{Model}(\mathbf{T}_{t-1}) +``` + +2. **Get next token logits:** +```math +\mathbf{l}_t = \mathbf{L}_t[:, -1, :] \in \mathbb{R}^{B \times V} +``` + +3. **Apply temperature:** +```math +\mathbf{l}_t' = \frac{\mathbf{l}_t}{T} +``` + where $T$ is the temperature (default 1.0). + +4. **Top-k filtering (optional):** +```math +\mathbf{l}_t''[v] = \begin{cases} +\mathbf{l}_t'[v] & \text{if } v \in \text{top-k}(\mathbf{l}_t') \\ +-\infty & \text{otherwise} +\end{cases} +``` + +5. **Top-p (nucleus) sampling (optional):** + - Sort tokens by probability + - Find smallest set $S$ where $\sum_{v \in S} p(v) \geq p$ + - Set probabilities outside $S$ to 0 + +6. **Sample token:** +```math +p_t = \text{softmax}(\mathbf{l}_t'') \in \mathbb{R}^V +t_t \sim \text{Categorical}(p_t) +``` + +7. **Append token:** +```math +\mathbf{T}_t = [\mathbf{T}_{t-1}, t_t] +``` + +### 13.2 Generation Example + +**Input:** +``` +Prompt: "Hello" +P = [72, 101, 108, 108, 111] ("Hello") +``` + +**Step 1:** +``` +Tā‚€ = [72, 101, 108, 108, 111] +Forward pass → L₁ shape: [1, 5, 128] +l₁ = L₁[0, -1, :] = [0.1, -0.2, ..., 0.8, ...] (logits for next token) + +Apply temperature T=1.0: +l₁' = l₁ / 1.0 = l₁ + +Softmax: +p₁ = softmax(l₁) = [0.001, 0.0005, ..., 0.15, ...] + +Sample (let's say token 32 = ' '): +t₁ = 32 +T₁ = [72, 101, 108, 108, 111, 32] +``` + +**Step 2:** +``` +T₁ = [72, 101, 108, 108, 111, 32] +Forward pass → Lā‚‚ shape: [1, 6, 128] +lā‚‚ = Lā‚‚[0, -1, :] + +Continue until max_length reached... +``` + +### 13.3 Top-k Sampling + +**Example:** +``` +V = 128, k = 50 +l = [0.5, 0.3, ..., -0.1, ...] (128 logits) + +Sort and get top 50: +top_k_indices = [0, 5, 12, ..., 87] (50 tokens) + +l' = [-inf, -inf, ..., 0.5, -inf, ..., 0.3, ...] + (only top-k kept, others set to -inf) +``` + +### 13.4 Top-p (Nucleus) Sampling + +**Example:** +``` +p = 0.95 (threshold) +p_sorted = [0.3, 0.2, 0.15, 0.1, 0.05, 0.03, ...] (sorted probabilities) + +Cumulative: [0.3, 0.5, 0.65, 0.75, 0.8, 0.83, ...] + +Find where cumulative ≄ 0.95: +At index 20: cumulative = 0.96 ≄ 0.95 +Keep first 20 tokens, set others to 0 +``` + +--- + +## Summary + +This document provides complete mathematical formulations for: + +1. **Data Processing**: Tokenization, chunking, padding +2. **Embeddings**: Token embeddings and positional encodings +3. **Attention**: Multi-head self-attention with scaling and masking +4. **Feed-Forward**: GELU activation and linear transformations +5. **Normalization**: Layer normalization with learnable parameters +6. **Training**: Loss computation, backpropagation, gradient clipping +7. **Optimization**: AdamW update rule with momentum and variance tracking +8. **Scheduling**: Cosine annealing learning rate schedule +9. **Generation**: Autoregressive sampling with temperature, top-k, and top-p + +Each section includes: +- Mathematical formulations +- Step-by-step calculations +- Worked examples with numerical values +- Implementation details + +All equations are directly implementable in PyTorch and match the actual implementation in the SheepOp codebase. + diff --git a/docs/MULTI_FORMAT_DATA_GUIDE.md b/docs/MULTI_FORMAT_DATA_GUIDE.md new file mode 100644 index 0000000..ecbc282 --- /dev/null +++ b/docs/MULTI_FORMAT_DATA_GUIDE.md @@ -0,0 +1,195 @@ +# Multi-Format Data Processing Guide + +## Overview + +The training script now supports processing multiple file types from your `data/` directory: + +- **Text files**: `.txt`, `.md`, `.rst`, `.log`, `.csv`, `.json`, `.jsonl`, `.xml`, `.html`, `.htm` +- **Code files**: `.py`, `.js`, `.ts`, `.java`, `.cpp`, `.c`, `.go`, `.rs`, `.rb`, `.php`, `.swift`, and many more +- **PDF files**: `.pdf` (requires PyPDF2 or pdfplumber) +- **Images**: `.png`, `.jpg`, `.jpeg`, `.gif`, `.bmp`, `.tiff`, `.webp` (requires pytesseract for OCR) + +## Basic Usage + +Simply point the training script to your data directory: + +```bash +python train.py --data /path/to/your/data/directory +``` + +The script will automatically: +1. Scan the directory (recursively by default) +2. Extract text from all supported file types +3. Process and tokenize the text +4. Train the model on all extracted content + +## Installation + +### Core Dependencies + +The core dependencies are already in `requirements.txt`. Install them with: + +```bash +pip install -r requirements.txt +``` + +### Optional Dependencies for PDF and Image Processing + +If you want to process PDFs or images, install the optional dependencies: + +```bash +# For PDF processing (choose one): +pip install PyPDF2 +# OR +pip install pdfplumber # Alternative, often better for complex PDFs + +# For image OCR: +pip install pytesseract Pillow + +# Also install Tesseract OCR engine: +# macOS: brew install tesseract +# Ubuntu/Debian: sudo apt-get install tesseract-ocr +# Windows: Download from https://github.com/UB-Mannheim/tesseract/wiki +``` + +## How It Works + +### 1. Text Files + +Text files are read line by line. Each non-empty line becomes a training sample. + +### 2. Code Files + +Code files are processed as text. Each line of code becomes a training sample. This allows the model to learn code patterns and syntax. + +### 3. PDF Files + +PDFs are processed page by page: +- Text is extracted from each page +- Split into lines +- Filtered to remove very short lines +- Each line becomes a training sample + +**Note**: PDF extraction works best with text-based PDFs. Scanned PDFs (images) should use OCR instead. + +### 4. Image Files + +Images are processed using OCR (Optical Character Recognition): +- Images are opened using PIL/Pillow +- pytesseract extracts text from the image +- Text is split into lines +- Each line becomes a training sample + +**Note**: OCR quality depends on image quality. For best results: +- Use high-resolution images +- Ensure good contrast between text and background +- Avoid images with complex layouts + +## Configuration Options + +You can customize the data processing behavior: + +```python +from pathlib import Path +from data import DataProcessor + +processor = DataProcessor( + use_ocr=True, # Enable OCR for images + use_pdf_extraction=True # Enable PDF extraction +) + +# Process directory +texts = processor.process_to_list( + directory=Path("data/"), + recursive=True, # Process subdirectories + min_length=10, # Minimum line length + max_samples=None, # Limit number of samples (None = all) +) +``` + +## Examples + +### Example 1: Process all files in directory + +```bash +python train.py --data /mnt/storage/sheepOp/data +``` + +### Example 2: Process single file + +```bash +python train.py --data /mnt/storage/sheepOp/data/document.pdf +``` + +### Example 3: Using Python API + +```python +from pathlib import Path +from data import extract_text_from_directory + +# Extract text from all supported files +texts = extract_text_from_directory( + directory=Path("data/"), + recursive=True, + use_ocr=True, + use_pdf_extraction=True, + min_length=10, +) + +print(f"Extracted {len(texts)} text samples") +``` + +## Supported File Types Summary + +| Category | Extensions | Requirements | +|----------|-----------|--------------| +| Text | `.txt`, `.md`, `.rst`, `.log`, `.csv`, `.json`, `.jsonl`, `.xml`, `.html`, `.htm` | None | +| Code | `.py`, `.js`, `.ts`, `.java`, `.cpp`, `.c`, `.go`, `.rs`, `.rb`, `.php`, `.swift`, and 30+ more | None | +| PDF | `.pdf` | PyPDF2 or pdfplumber | +| Images | `.png`, `.jpg`, `.jpeg`, `.gif`, `.bmp`, `.tiff`, `.webp` | pytesseract + Pillow + Tesseract OCR | + +## Troubleshooting + +### PDF extraction not working + +- Install PyPDF2: `pip install PyPDF2` +- Or install pdfplumber (better for complex PDFs): `pip install pdfplumber` +- If PDFs are scanned images, use OCR instead + +### OCR not working + +1. Install pytesseract: `pip install pytesseract Pillow` +2. Install Tesseract OCR engine (see installation instructions above) +3. On some systems, you may need to set the tesseract path: + ```python + import pytesseract + pytesseract.pytesseract.tesseract_cmd = '/usr/local/bin/tesseract' # macOS example + ``` + +### No text extracted + +- Check that files are in supported formats +- Verify file permissions +- Check logs for error messages +- Try processing a single file first to debug + +## Performance Tips + +1. **Large directories**: Processing can take time for large directories. Progress is logged every 100 files. + +2. **Parallel processing**: Consider processing files in parallel if you have many large files. + +3. **Filtering**: Use `min_length` to filter out very short lines that may not be useful for training. + +4. **Caching**: For repeated processing, consider saving extracted text to a file first. + +## Next Steps + +Once your data is processed: + +1. The training script will automatically tokenize the text +2. Create training batches +3. Train your model + +For more information on training, see `RETRAINING_GUIDE.md`. + diff --git a/docs/NEURAL_NETWORK_EXPLAINED.md b/docs/NEURAL_NETWORK_EXPLAINED.md new file mode 100644 index 0000000..9129c5b --- /dev/null +++ b/docs/NEURAL_NETWORK_EXPLAINED.md @@ -0,0 +1,948 @@ +# What is a Neural Network? Step-by-Step Explanation + +Complete step-by-step explanation of neural networks: what neurons are, what weights are, how calculations work, why they're important, with mathematical derivations and solved exercises. + +## Table of Contents + +1. [What is a Neural Network?](#61-what-is-a-neural-network) +2. [What is a Neuron?](#62-what-is-a-neuron) +3. [What are Weights?](#63-what-are-weights) +4. [How Neurons Calculate](#64-how-neurons-calculate) +5. [Why Weights are Important](#65-why-weights-are-important) +6. [Complete Mathematical Formulation](#66-complete-mathematical-formulation) +7. [Multi-Layer Neural Networks](#67-multi-layer-neural-networks) +8. [Exercise 1: Single Neuron Calculation](#68-exercise-1-single-neuron-calculation) +9. [Exercise 2: Multi-Layer Network](#69-exercise-2-multi-layer-network) +10. [Exercise 3: Learning Weights](#610-exercise-3-learning-weights) +11. [Key Takeaways](#611-key-takeaways) + +--- + +## 6.1 What is a Neural Network? + +### Simple Definition + +A **neural network** is a computational model inspired by biological neurons that processes information through interconnected nodes (neurons) to make predictions or decisions. + +### Visual Analogy + +**Think of a neural network like a factory:** + +``` +Input → Worker 1 → Worker 2 → Worker 3 → Output +``` + +**Neural Network:** + +``` +Input → Neuron 1 → Neuron 2 → Neuron 3 → Output +``` + +**Each worker (neuron) does a specific job, and they work together to produce the final result.** + +### Basic Structure + +``` +Input Layer Hidden Layer Output Layer + ā— ā— ā— + ā— ā— ā— + ā— ā— ā— + ā— ā— +``` + +**Key Components:** + +- **Input Layer:** Receives data +- **Hidden Layers:** Process information +- **Output Layer:** Produces predictions +- **Connections:** Weights between neurons + +--- + +## 6.2 What is a Neuron? + +### Simple Definition + +A **neuron** (also called a node or unit) is the basic processing unit of a neural network. It receives inputs, performs calculations, and produces an output. + +### Biological Inspiration + +**Biological Neuron:** + +``` +Dendrites → Cell Body → Axon → Synapses +(inputs) (process) (output) (connections) +``` + +**Artificial Neuron:** + +``` +Inputs → Weighted Sum → Activation → Output +``` + +### Structure of a Neuron + +``` +Input 1 (x₁) ────┐ + │ +Input 2 (xā‚‚) ────┼──→ [Ī£] ─→ [f] ─→ Output (y) + │ +Input 3 (xā‚ƒ) ā”€ā”€ā”€ā”€ā”˜ +``` + +**Components:** + +1. **Inputs:** Values fed into the neuron +2. **Weights:** Strength of connections +3. **Weighted Sum:** Sum of inputs Ɨ weights +4. **Bias:** Added constant +5. **Activation Function:** Applies nonlinearity +6. **Output:** Final result + +### Visual Representation + +``` +Neuron: + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ Inputs: x₁, xā‚‚, xā‚ƒ │ + │ Weights: w₁, wā‚‚, wā‚ƒā”‚ + │ │ + │ z = Ī£(xįµ¢ Ɨ wįµ¢) + b │ + │ y = f(z) │ + │ │ + │ Output: y │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ +``` + +**Where:** + +- `z` = weighted sum (before activation) +- `f` = activation function +- `y` = output (after activation) + +--- + +## 6.3 What are Weights? + +### Simple Definition + +**Weights** are numerical values that determine the strength of connections between neurons. They control how much each input contributes to the output. + +### Visual Analogy + +**Think of weights like volume controls:** + +``` +Music Source 1 ──[Volume: 0.8]──→ Speakers +Music Source 2 ──[Volume: 0.3]──→ Speakers +Music Source 3 ──[Volume: 0.5]──→ Speakers +``` + +**Higher weight = Louder contribution** + +**Neural Network:** + +``` +Input 1 ──[Weight: 0.8]──→ Neuron +Input 2 ──[Weight: 0.3]──→ Neuron +Input 3 ──[Weight: 0.5]──→ Neuron +``` + +**Higher weight = Stronger influence** + +### What Weights Do + +**Weights determine:** + +1. **How much each input matters** +2. **The relationship between inputs and outputs** +3. **What patterns the neuron learns** + +**Example:** + +**Weight = 0.1:** + +- Input has small influence +- Weak connection + +**Weight = 5.0:** + +- Input has large influence +- Strong connection + +**Weight = -2.0:** + +- Input has negative influence +- Inverts the relationship + +**Weight = 0.0:** + +- Input has no influence +- Connection is cut + +### Weight Matrix + +**In a layer with multiple neurons:** + +``` +Input Layer Weights Matrix Output Layer +x₁ ───────────────────┐ + │ w₁₁ w₁₂ y₁ +xā‚‚ ───────────────────┼─ w₂₁ wā‚‚ā‚‚ ──── yā‚‚ + │ wā‚ƒā‚ wā‚ƒā‚‚ +xā‚ƒ ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ +``` + +**Weight Matrix:** + +``` +W = [w₁₁ w₁₂] + [w₂₁ wā‚‚ā‚‚] + [wā‚ƒā‚ wā‚ƒā‚‚] +``` + +**Each row:** Connections from one input +**Each column:** Connections to one output + +--- + +## 6.4 How Neurons Calculate + +### Step-by-Step Calculation + +#### Step 1: Weighted Sum + +**Multiply each input by its weight:** + +```math +z = x_1 \times w_1 + x_2 \times w_2 + x_3 \times w_3 + ... + b +``` + +**Or in vector form:** + +```math +z = \mathbf{x} \cdot \mathbf{w} + b = \sum_{i=1}^{n} x_i w_i + b +``` + +**Where:** + +- $x_i$ = input value +- $w_i$ = weight for input $i$ +- $b$ = bias (constant) +- $n$ = number of inputs + +#### Step 2: Add Bias + +**Bias shifts the activation:** + +```math +z = \sum_{i=1}^{n} x_i w_i + b +``` + +**Bias allows the neuron to:** + +- Shift activation threshold +- Learn patterns independent of inputs +- Adjust baseline output + +#### Step 3: Apply Activation Function + +**Apply nonlinear function:** + +```math +y = f(z) +``` + +**Common activation functions:** + +**ReLU (Rectified Linear Unit):** + +```math +f(z) = \max(0, z) +``` + +**Sigmoid:** + +```math +f(z) = \frac{1}{1 + e^{-z}} +``` + +**Tanh:** + +```math +f(z) = \tanh(z) = \frac{e^z - e^{-z}}{e^z + e^{-z}} +``` + +**GELU (used in transformers):** + +```math +f(z) = z \cdot \Phi(z) +``` + +**Where $\Phi(z)$ is the CDF of standard normal distribution** + +### Complete Example + +**Given:** + +- Inputs: $x_1 = 0.5, x_2 = 0.3, x_3 = 0.8$ +- Weights: $w_1 = 0.6, w_2 = 0.4, w_3 = 0.2$ +- Bias: $b = 0.1$ +- Activation: ReLU + +**Step 1: Weighted Sum** + +``` +z = (0.5 Ɨ 0.6) + (0.3 Ɨ 0.4) + (0.8 Ɨ 0.2) + 0.1 + = 0.3 + 0.12 + 0.16 + 0.1 + = 0.68 +``` + +**Step 2: Apply Activation** + +``` +y = ReLU(0.68) + = max(0, 0.68) + = 0.68 +``` + +**Result:** Output = 0.68 + +--- + +## 6.5 Why Weights are Important + +### Reason 1: They Determine What the Neuron Learns + +**Different weights = Different patterns:** + +**Pattern 1: Emphasis on Input 1** + +``` +w₁ = 5.0, wā‚‚ = 0.1, wā‚ƒ = 0.1 +→ Neuron cares mostly about input 1 +``` + +**Pattern 2: Balanced Weights** + +``` +w₁ = 0.5, wā‚‚ = 0.5, wā‚ƒ = 0.5 +→ Neuron treats all inputs equally +``` + +**Pattern 3: Inverted Relationship** + +``` +w₁ = -2.0, wā‚‚ = 1.0, wā‚ƒ = 1.0 +→ Neuron inverses input 1's effect +``` + +### Reason 2: They Enable Learning + +**Training adjusts weights:** + +**Before Training:** + +``` +Weights: Random values +→ Random predictions +``` + +**After Training:** + +``` +Weights: Learned values +→ Accurate predictions +``` + +**Weights are what the model learns!** + +### Reason 3: They Control Information Flow + +**High weights:** Information flows easily +**Low weights:** Information flows weakly +**Zero weights:** Information blocked +**Negative weights:** Information inverted + +### Reason 4: They Enable Complex Patterns + +**Multiple neurons with different weights:** + +``` +Neuron 1: w₁ = 1.0, wā‚‚ = 0.0 → Detects pattern A +Neuron 2: w₁ = 0.0, wā‚‚ = 1.0 → Detects pattern B +Neuron 3: w₁ = 0.5, wā‚‚ = 0.5 → Detects pattern C +``` + +**Together:** Model learns complex relationships! + +--- + +## 6.6 Complete Mathematical Formulation + +### Single Neuron Formula + +**Complete neuron calculation:** + +```math +z = \sum_{i=1}^{n} x_i w_i + b +``` + +```math +y = f(z) +``` + +**Where:** + +- $\mathbf{x} = [x_1, x_2, ..., x_n]$ = input vector +- $\mathbf{w} = [w_1, w_2, ..., w_n]$ = weight vector +- $b$ = bias (scalar) +- $f$ = activation function +- $z$ = weighted sum (before activation) +- $y$ = output (after activation) + +### Matrix Formulation + +**For multiple neurons:** + +```math +\mathbf{z} = \mathbf{X} \mathbf{W} + \mathbf{b} +``` + +```math +\mathbf{Y} = f(\mathbf{z}) +``` + +**Where:** + +- $\mathbf{X} \in \mathbb{R}^{B \times n}$ = input matrix (B samples, n features) +- $\mathbf{W} \in \mathbb{R}^{n \times m}$ = weight matrix (n inputs, m neurons) +- $\mathbf{b} \in \mathbb{R}^{1 \times m}$ = bias vector +- $\mathbf{z} \in \mathbb{R}^{B \times m}$ = weighted sums +- $\mathbf{Y} \in \mathbb{R}^{B \times m}$ = outputs + +**Example:** + +**Input Matrix:** + +``` +X = [x₁₁ x₁₂] (2 samples, 2 features) + [x₂₁ xā‚‚ā‚‚] +``` + +**Weight Matrix:** + +``` +W = [w₁₁ w₁₂] (2 inputs, 2 neurons) + [w₂₁ wā‚‚ā‚‚] +``` + +**Bias Vector:** + +``` +b = [b₁ bā‚‚] (2 neurons) +``` + +**Calculation:** + +``` +z = X Ɨ W + b + +z₁₁ = x₁₁×w₁₁ + x₁₂×w₂₁ + b₁ +z₁₂ = x₁₁×w₁₂ + x₁₂×wā‚‚ā‚‚ + bā‚‚ +z₂₁ = x₂₁×w₁₁ + xā‚‚ā‚‚Ć—w₂₁ + b₁ +zā‚‚ā‚‚ = x₂₁×w₁₂ + xā‚‚ā‚‚Ć—wā‚‚ā‚‚ + bā‚‚ +``` + +--- + +## 6.7 Multi-Layer Neural Networks + +### Structure + +``` +Input Layer → Hidden Layer 1 → Hidden Layer 2 → Output Layer + x₁ h₁₁ h₂₁ y₁ + xā‚‚ h₁₂ hā‚‚ā‚‚ yā‚‚ + xā‚ƒ hā‚ā‚ƒ hā‚‚ā‚ƒ +``` + +### Forward Pass + +**Layer 1:** + +```math +\mathbf{h}_1 = f_1(\mathbf{X} \mathbf{W}_1 + \mathbf{b}_1) +``` + +**Layer 2:** + +```math +\mathbf{h}_2 = f_2(\mathbf{h}_1 \mathbf{W}_2 + \mathbf{b}_2) +``` + +**Output Layer:** + +```math +\mathbf{Y} = f_3(\mathbf{h}_2 \mathbf{W}_3 + \mathbf{b}_3) +``` + +**Chained together:** + +```math +\mathbf{Y} = f_3(f_2(f_1(\mathbf{X} \mathbf{W}_1 + \mathbf{b}_1) \mathbf{W}_2 + \mathbf{b}_2) \mathbf{W}_3 + \mathbf{b}_3) +``` + +**Each layer transforms the input!** + +--- + +## 6.8 Exercise 1: Single Neuron Calculation + +### Problem + +**Given a single neuron with:** + +- Inputs: $x_1 = 2.0, x_2 = -1.0, x_3 = 0.5$ +- Weights: $w_1 = 0.5, w_2 = -0.3, w_3 = 0.8$ +- Bias: $b = 0.2$ +- Activation function: ReLU $f(z) = \max(0, z)$ + +**Calculate the output of this neuron.** + +### Step-by-Step Solution + +#### Step 1: Weighted Sum + +**Compute:** + +```math +z = \sum_{i=1}^{3} x_i w_i + b +``` + +**Substitute values:** + +```math +z = (2.0 \times 0.5) + (-1.0 \times -0.3) + (0.5 \times 0.8) + 0.2 +``` + +**Calculate each term:** + +```math +z = (1.0) + (0.3) + (0.4) + 0.2 +``` + +**Sum:** + +```math +z = 1.0 + 0.3 + 0.4 + 0.2 = 1.9 +``` + +#### Step 2: Apply Activation Function + +**Apply ReLU:** + +```math +y = \text{ReLU}(z) = \max(0, z) = \max(0, 1.9) = 1.9 +``` + +### Answer + +**The output of the neuron is $y = 1.9$.** + +### Verification + +**Check calculation:** + +- Input contribution 1: $2.0 \times 0.5 = 1.0$ +- Input contribution 2: $-1.0 \times -0.3 = 0.3$ +- Input contribution 3: $0.5 \times 0.8 = 0.4$ +- Bias: $0.2$ +- Total: $1.0 + 0.3 + 0.4 + 0.2 = 1.9$ āœ“ +- ReLU(1.9) = 1.9 āœ“ + +--- + +## 6.9 Exercise 2: Multi-Layer Network + +### Problem + +**Given a neural network with 2 layers:** + +**Layer 1:** + +- Inputs: $x_1 = 1.0, x_2 = 0.5$ +- Weights: $W_1 = \begin{bmatrix} 0.6 & 0.4 \\ 0.2 & 0.8 \end{bmatrix}$ +- Bias: $b_1 = [0.1, -0.1]$ +- Activation: ReLU + +**Layer 2:** + +- Inputs: Outputs from Layer 1 +- Weights: $W_2 = \begin{bmatrix} 0.5 \\ 0.7 \end{bmatrix}$ +- Bias: $b_2 = 0.2$ +- Activation: ReLU + +**Calculate the final output.** + +### Step-by-Step Solution + +#### Step 1: Layer 1 - Weighted Sum + +**Input vector:** + +```math +\mathbf{x} = [1.0, 0.5] +``` + +**Weight matrix:** + +```math +\mathbf{W}_1 = \begin{bmatrix} 0.6 & 0.4 \\ 0.2 & 0.8 \end{bmatrix} +``` + +**Bias vector:** + +```math +\mathbf{b}_1 = [0.1, -0.1] +``` + +**Calculate:** + +```math +\mathbf{z}_1 = \mathbf{x} \mathbf{W}_1 + \mathbf{b}_1 +``` + +**Matrix multiplication:** + +```math +\mathbf{z}_1 = [1.0, 0.5] \begin{bmatrix} 0.6 & 0.4 \\ 0.2 & 0.8 \end{bmatrix} + [0.1, -0.1] +``` + +**Compute:** + +```math +z_{1,1} = 1.0 \times 0.6 + 0.5 \times 0.2 + 0.1 = 0.6 + 0.1 + 0.1 = 0.8 +``` + +```math +z_{1,2} = 1.0 \times 0.4 + 0.5 \times 0.8 + (-0.1) = 0.4 + 0.4 - 0.1 = 0.7 +``` + +```math +\mathbf{z}_1 = [0.8, 0.7] +``` + +#### Step 2: Layer 1 - Apply Activation + +**Apply ReLU:** + +```math +\mathbf{h}_1 = \text{ReLU}(\mathbf{z}_1) = [\max(0, 0.8), \max(0, 0.7)] = [0.8, 0.7] +``` + +#### Step 3: Layer 2 - Weighted Sum + +**Input (from Layer 1):** + +```math +\mathbf{h}_1 = [0.8, 0.7] +``` + +**Weight matrix:** + +```math +\mathbf{W}_2 = \begin{bmatrix} 0.5 \\ 0.7 \end{bmatrix} +``` + +**Bias:** + +```math +b_2 = 0.2 +``` + +**Calculate:** + +```math +z_2 = \mathbf{h}_1 \mathbf{W}_2 + b_2 +``` + +**Matrix multiplication:** + +```math +z_2 = [0.8, 0.7] \begin{bmatrix} 0.5 \\ 0.7 \end{bmatrix} + 0.2 +``` + +**Compute:** + +```math +z_2 = 0.8 \times 0.5 + 0.7 \times 0.7 + 0.2 = 0.4 + 0.49 + 0.2 = 1.09 +``` + +#### Step 4: Layer 2 - Apply Activation + +**Apply ReLU:** + +```math +y = \text{ReLU}(z_2) = \max(0, 1.09) = 1.09 +``` + +### Answer + +**The final output is $y = 1.09$.** + +### Summary Table + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
LayerInputWeightsBiasWeighted SumActivationOutput
1[1.0, 0.5]$$\begin{bmatrix} 0.6 & 0.4 \\ 0.2 & 0.8 \end{bmatrix}$$[0.1, -0.1][0.8, 0.7]ReLU[0.8, 0.7]
2[0.8, 0.7]$$\begin{bmatrix} 0.5 \\ 0.7 \end{bmatrix}$$0.21.09ReLU1.09
+--- + +## 6.10 Exercise 3: Learning Weights + +### Problem + +**Given a neuron that should output 1.0 when inputs are [1.0, 1.0] and output 0.0 when inputs are [0.0, 0.0], find appropriate weights and bias.** + +**Use:** + +- Activation: Sigmoid $f(z) = \frac{1}{1 + e^{-z}}$ +- Desired behavior: AND gate (output 1 only when both inputs are 1) + +### Step-by-Step Solution + +#### Step 1: Set Up Equations + +**For input [1.0, 1.0], desired output ā‰ˆ 1.0:** + +```math +f(w_1 \times 1.0 + w_2 \times 1.0 + b) = 1.0 +``` + +**For input [0.0, 0.0], desired output ā‰ˆ 0.0:** + +```math +f(w_1 \times 0.0 + w_2 \times 0.0 + b) = 0.0 +``` + +**Note:** Sigmoid outputs range from 0 to 1, so: + +- $f(z) \approx 1.0$ when $z \gg 0$ (e.g., $z > 5$) +- $f(z) \approx 0.0$ when $z \ll 0$ (e.g., $z < -5$) + +#### Step 2: Solve for Bias + +**From equation 2:** + +```math +f(b) = 0.0 +``` + +**For sigmoid to output ā‰ˆ 0:** + +```math +b < -5 +``` + +**Let's use:** + +```math +b = -10 +``` + +#### Step 3: Solve for Weights + +**From equation 1:** + +```math +f(w_1 + w_2 - 10) = 1.0 +``` + +**For sigmoid to output ā‰ˆ 1:** + +```math +w_1 + w_2 - 10 > 5 +``` + +```math +w_1 + w_2 > 15 +``` + +**Let's use equal weights:** + +```math +w_1 = w_2 = 8.0 +``` + +**Check:** + +```math +w_1 + w_2 = 8.0 + 8.0 = 16.0 > 15 \quad āœ“ +``` + +#### Step 4: Verify Solution + +**Test Case 1: Input [1.0, 1.0]** + +```math +z = 1.0 \times 8.0 + 1.0 \times 8.0 + (-10) = 8.0 + 8.0 - 10 = 6.0 +``` + +```math +y = \frac{1}{1 + e^{-6.0}} = \frac{1}{1 + 0.0025} \approx 0.9975 \approx 1.0 \quad āœ“ +``` + +**Test Case 2: Input [0.0, 0.0]** + +```math +z = 0.0 \times 8.0 + 0.0 \times 8.0 + (-10) = -10 +``` + +```math +y = \frac{1}{1 + e^{10}} = \frac{1}{1 + 22026} \approx 0.00005 \approx 0.0 \quad āœ“ +``` + +**Test Case 3: Input [1.0, 0.0]** + +```math +z = 1.0 \times 8.0 + 0.0 \times 8.0 + (-10) = 8.0 - 10 = -2.0 +``` + +```math +y = \frac{1}{1 + e^{2.0}} = \frac{1}{1 + 7.39} \approx 0.12 < 0.5 \quad āœ“ +``` + +**Test Case 4: Input [0.0, 1.0]** + +```math +z = 0.0 \times 8.0 + 1.0 \times 8.0 + (-10) = 8.0 - 10 = -2.0 +``` + +```math +y = \frac{1}{1 + e^{2.0}} \approx 0.12 < 0.5 \quad āœ“ +``` + +### Answer + +**Appropriate weights and bias:** + +- $w_1 = 8.0$ +- $w_2 = 8.0$ +- $b = -10.0$ + +**The neuron implements an AND gate correctly!** + +### Key Insight + +**This demonstrates learning:** + +- Training finds weights that produce desired behavior +- Different weights = Different logic functions +- Learning algorithms (like backpropagation) automatically find these weights from data! + +--- + +## 6.11 Key Takeaways + +### Neurons + +āœ… **Neurons are the basic processing units** +āœ… **Receive inputs, compute weighted sum, apply activation** +āœ… **Output is the result of activation function** + +### Weights + +āœ… **Weights control connection strength** +āœ… **Determine what patterns neurons learn** +āœ… **Are what the model learns during training** +āœ… **Enable complex pattern recognition** + +### Calculation + +āœ… **Weighted sum: $z = \sum x_i w_i + b$** +āœ… **Activation: $y = f(z)$** +āœ… **Matrix form enables efficient computation** + +### Importance + +āœ… **Weights enable learning** +āœ… **Control information flow** +āœ… **Enable complex pattern recognition** +āœ… **Are adjusted during training to minimize error** + +### Neural Networks + +āœ… **Multiple neurons form layers** +āœ… **Multiple layers form networks** +āœ… **Each layer transforms the input** +āœ… **Deep networks learn hierarchical features** + +--- + +## Mathematical Summary + +### Single Neuron + +```math +z = \sum_{i=1}^{n} x_i w_i + b +``` + +```math +y = f(z) +``` + +### Multiple Neurons (Matrix Form) + +```math +\mathbf{z} = \mathbf{X} \mathbf{W} + \mathbf{b} +``` + +```math +\mathbf{Y} = f(\mathbf{z}) +``` + +### Multi-Layer Network + +```math +\mathbf{h}_1 = f_1(\mathbf{X} \mathbf{W}_1 + \mathbf{b}_1) +``` + +```math +\mathbf{h}_2 = f_2(\mathbf{h}_1 \mathbf{W}_2 + \mathbf{b}_2) +``` + +```math +\mathbf{Y} = f_3(\mathbf{h}_2 \mathbf{W}_3 + \mathbf{b}_3) +``` + +--- + +_This document provides a comprehensive explanation of neural networks, neurons, weights, and calculations with mathematical derivations and solved exercises._ diff --git a/docs/NORMALIZATION_EXPLAINED.md b/docs/NORMALIZATION_EXPLAINED.md new file mode 100644 index 0000000..c4e4e7a --- /dev/null +++ b/docs/NORMALIZATION_EXPLAINED.md @@ -0,0 +1,640 @@ +# What is Normalization? Step-by-Step Explanation + +Complete step-by-step explanation of normalization in transformer models: how normalization stabilizes training and improves model performance. + +## Table of Contents + +1. [The Problem Normalization Solves](#41-the-problem-normalization-solves) +2. [What is Normalization?](#42-what-is-normalization) +3. [How Layer Normalization Works: Step-by-Step](#43-how-layer-normalization-works-step-by-step) +4. [Complete Example: Normalizing a Vector](#44-complete-example-normalizing-a-vector) +5. [Why Normalization Matters](#45-why-normalization-matters) +6. [Pre-Norm vs Post-Norm Architecture](#46-pre-norm-vs-post-norm-architecture) +7. [Visual Representation](#47-visual-representation) +8. [Key Takeaways](#48-key-takeaways) + +--- + +## 4.1 The Problem Normalization Solves + +### The Challenge + +**During training, activations can become unstable:** + +**Problem 1: Varying Activations** +``` +Layer 1 output: [0.1, 0.2, 0.3, ...] (small values) +Layer 2 output: [10.5, 20.3, 15.8, ...] (large values) +Layer 3 output: [0.01, 0.02, 0.03, ...] (very small values) +``` + +**Problem 2: Internal Covariate Shift** +- Activations change distribution as weights update +- Later layers struggle to adapt to changing inputs +- Training becomes slower and less stable + +**Problem 3: Gradient Problems** +``` +Large activations → Large gradients → Exploding gradients +Small activations → Small gradients → Vanishing gradients +``` + +### The Solution: Normalization + +**Normalization standardizes activations to have consistent statistics (mean zero, variance one), making training stable and efficient.** + +--- + +## 4.2 What is Normalization? + +### Simple Definition + +**Normalization** is a technique that transforms activations to have: +- **Mean of zero** (centered) +- **Variance of one** (standardized scale) + +**Think of it like standardization:** +- Converts any distribution to a standard form +- Makes values comparable across different scales +- Helps the model learn faster and more reliably + +### Visual Analogy + +**Imagine weights on a scale:** + +**Before Normalization:** +``` +Bronze weight: 1 kg +Silver weight: 100 kg +Gold weight: 0.001 kg +→ Hard to compare! +``` + +**After Normalization:** +``` +All weights standardized to mean 0, variance 1 +→ Easy to compare and work with! +``` + +### Types of Normalization + +**In transformers, we use Layer Normalization:** + +- **Layer Normalization:** Normalizes across features (dimensions) for each sample +- **Batch Normalization:** Normalizes across samples in a batch (not used in transformers) +- **Instance Normalization:** Normalizes each sample independently + +**Why Layer Normalization?** +- Works well with variable sequence lengths +- Doesn't depend on batch size +- Suitable for autoregressive models + +--- + +## 4.3 How Layer Normalization Works: Step-by-Step + +### High-Level Overview + +``` +Step 1: Compute mean of activations +Step 2: Compute variance of activations +Step 3: Normalize (subtract mean, divide by std) +Step 4: Scale and shift (learnable parameters) +``` + +### Detailed Step-by-Step + +#### Step 1: Compute Mean + +**Calculate the average value across all dimensions:** + +```math +\mu = \frac{1}{d} \sum_{i=1}^{d} x_i +``` + +**Example:** + +**Input vector:** +``` +x = [1.0, 2.0, 3.0, 4.0] +d = 4 (number of dimensions) +``` + +**Compute mean:** +``` +μ = (1.0 + 2.0 + 3.0 + 4.0) / 4 + = 10.0 / 4 + = 2.5 +``` + +**Meaning:** The center of the distribution is at 2.5 + +#### Step 2: Compute Variance + +**Measure how spread out the values are:** + +```math +\sigma^2 = \frac{1}{d} \sum_{i=1}^{d} (x_i - \mu)^2 +``` + +**Example:** + +**Using the same input:** +``` +x = [1.0, 2.0, 3.0, 4.0] +μ = 2.5 +``` + +**Compute variance:** +``` +σ² = [(1.0 - 2.5)² + (2.0 - 2.5)² + (3.0 - 2.5)² + (4.0 - 2.5)²] / 4 + = [(-1.5)² + (-0.5)² + (0.5)² + (1.5)²] / 4 + = [2.25 + 0.25 + 0.25 + 2.25] / 4 + = 5.0 / 4 + = 1.25 +``` + +**Compute standard deviation:** +``` +σ = √σ² = √1.25 ā‰ˆ 1.118 +``` + +**Meaning:** Values are spread out with standard deviation of 1.118 + +#### Step 3: Normalize + +**Subtract mean and divide by standard deviation:** + +```math +\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} +``` + +**Where:** +- $\epsilon$ is a small constant (default: 1e-5) to prevent division by zero + +**Example:** + +**Using the same input:** +``` +x = [1.0, 2.0, 3.0, 4.0] +μ = 2.5 +σ ā‰ˆ 1.118 +ε = 0.00001 +``` + +**Normalize each element:** +``` +x̂₁ = (1.0 - 2.5) / (1.118 + 0.00001) ā‰ˆ -1.341 +x̂₂ = (2.0 - 2.5) / (1.118 + 0.00001) ā‰ˆ -0.447 +xĢ‚ā‚ƒ = (3.0 - 2.5) / (1.118 + 0.00001) ā‰ˆ 0.447 +x̂₄ = (4.0 - 2.5) / (1.118 + 0.00001) ā‰ˆ 1.341 +``` + +**Result:** +``` +xĢ‚ = [-1.341, -0.447, 0.447, 1.341] +``` + +**Check:** +- Mean ā‰ˆ 0 āœ“ +- Standard deviation ā‰ˆ 1 āœ“ + +**Meaning:** Values are now standardized! + +#### Step 4: Scale and Shift + +**Apply learnable parameters:** + +```math +\text{LayerNorm}(x) = \gamma \odot \hat{x} + \beta +``` + +**Where:** +- $\gamma$ = learnable scale parameter (initialized to 1) +- $\beta$ = learnable shift parameter (initialized to 0) +- $\odot$ = element-wise multiplication + +**Example:** + +**Normalized vector:** +``` +xĢ‚ = [-1.341, -0.447, 0.447, 1.341] +``` + +**Learnable parameters (initialized):** +``` +γ = [1.0, 1.0, 1.0, 1.0] (scale) +β = [0.0, 0.0, 0.0, 0.0] (shift) +``` + +**Apply scale and shift:** +``` +Output = γ āŠ™ xĢ‚ + β + = [1.0, 1.0, 1.0, 1.0] āŠ™ [-1.341, -0.447, 0.447, 1.341] + [0.0, 0.0, 0.0, 0.0] + = [-1.341, -0.447, 0.447, 1.341] + [0.0, 0.0, 0.0, 0.0] + = [-1.341, -0.447, 0.447, 1.341] +``` + +**Initially, normalization is identity!** +**During training, γ and β learn optimal scale and shift.** + +--- + +## 4.4 Complete Example: Normalizing a Vector + +### Input + +``` +Word embedding after attention: [0.146, 0.108, 0.192, 0.155, ..., 0.11] +Dimension: 512 +``` + +### Step-by-Step Processing + +#### Step 1: Compute Mean + +**Input:** +``` +x = [0.146, 0.108, 0.192, ..., 0.11] (512 numbers) +``` + +**Compute mean:** +``` +μ = (0.146 + 0.108 + 0.192 + ... + 0.11) / 512 + ā‰ˆ 0.135 +``` + +**Visualization:** +``` +Values: [0.146, 0.108, 0.192, ..., 0.11] + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + Mean: 0.135 (center point) +``` + +#### Step 2: Compute Variance + +**Compute variance:** +``` +σ² = [(0.146 - 0.135)² + (0.108 - 0.135)² + (0.192 - 0.135)² + ... + (0.11 - 0.135)²] / 512 + ā‰ˆ 0.0023 +``` + +**Compute standard deviation:** +``` +σ = √0.0023 ā‰ˆ 0.048 +``` + +**Visualization:** +``` +Values: [0.146, 0.108, 0.192, ..., 0.11] +Spread: └───────── σ ā‰ˆ 0.048 ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ +``` + +#### Step 3: Normalize + +**Normalize each element:** +``` +x̂₁ = (0.146 - 0.135) / (0.048 + 0.00001) ā‰ˆ 0.229 +x̂₂ = (0.108 - 0.135) / (0.048 + 0.00001) ā‰ˆ -0.562 +xĢ‚ā‚ƒ = (0.192 - 0.135) / (0.048 + 0.00001) ā‰ˆ 1.188 +... +x̂₅₁₂ = (0.11 - 0.135) / (0.048 + 0.00001) ā‰ˆ -0.521 +``` + +**Result:** +``` +xĢ‚ = [0.229, -0.562, 1.188, ..., -0.521] +``` + +**Properties:** +- Mean ā‰ˆ 0 āœ“ +- Standard deviation ā‰ˆ 1 āœ“ + +#### Step 4: Scale and Shift + +**Apply learnable parameters:** +``` +γ = [1.0, 1.0, ..., 1.0] (512 values, may change during training) +β = [0.0, 0.0, ..., 0.0] (512 values, may change during training) +``` + +**Output:** +``` +Output = γ āŠ™ xĢ‚ + β + = [0.229, -0.562, 1.188, ..., -0.521] +``` + +**After training, γ and β adapt to optimal values!** + +--- + +## 4.5 Why Normalization Matters + +### Benefit 1: Stable Training + +**Without Normalization:** +``` +Layer 1: activations = [0.1, 0.2, ...] +Layer 2: activations = [50.0, 100.0, ...] ← Exploding! +Layer 3: activations = [0.001, 0.002, ...] ← Vanishing! +``` + +**With Normalization:** +``` +Layer 1: activations = [0.1, -0.2, ...] (normalized) +Layer 2: activations = [0.3, -0.1, ...] (normalized) +Layer 3: activations = [0.2, 0.4, ...] (normalized) +→ Consistent scale throughout! +``` + +### Benefit 2: Better Gradient Flow + +**Normalization helps gradients flow better:** + +**Without Normalization:** +``` +Gradient 1: 0.0001 (too small, vanishing) +Gradient 2: 1000.0 (too large, exploding) +Gradient 3: 0.001 (too small) +``` + +**With Normalization:** +``` +Gradient 1: 0.01 (reasonable) +Gradient 2: 0.02 (reasonable) +Gradient 3: 0.015 (reasonable) +→ Stable gradients! +``` + +### Benefit 3: Faster Convergence + +**Normalized activations allow:** +- Higher learning rates +- Faster weight updates +- Quicker convergence to good solutions + +**Analogy:** +- **Without normalization:** Walking on rough terrain (slow progress) +- **With normalization:** Walking on smooth path (fast progress) + +### Benefit 4: Regularization Effect + +**Normalization acts as a form of regularization:** +- Reduces internal covariate shift +- Makes optimization easier +- Helps prevent overfitting + +--- + +## 4.6 Pre-Norm vs Post-Norm Architecture + +### Post-Norm (Original Transformer) + +**Order:** +``` +Input → Attention → LayerNorm → Output +``` + +**Equation:** +``` +x_out = LayerNorm(x + Attention(x)) +``` + +**Problems:** +- Can be unstable with many layers +- Gradient flow can be difficult +- Harder to train deep networks + +### Pre-Norm (Modern Approach) + +**Order:** +``` +Input → LayerNorm → Attention → Output +``` + +**Equation:** +``` +x_out = x + Attention(LayerNorm(x)) +``` + +**Benefits:** +- More stable training +- Better gradient flow +- Easier to train deep networks + +**Visual Comparison:** + +**Post-Norm:** +``` +Input + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ Attention │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ LayerNorm │ ← Normalization after + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + Output +``` + +**Pre-Norm:** +``` +Input + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ LayerNorm │ ← Normalization before + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ Attention │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + ↓ + Output +``` + +**Our Model Uses Pre-Norm!** + +--- + +## 4.7 Visual Representation + +### Normalization Process + +``` +Input Vector + │ + │ [1.0, 2.0, 3.0, 4.0] + ↓ +ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” +│ Step 1: Compute Mean │ +│ μ = 2.5 │ +ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + │ + ↓ +ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” +│ Step 2: Compute Variance │ +│ σ² = 1.25, σ ā‰ˆ 1.118 │ +ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + │ + ↓ +ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” +│ Step 3: Normalize │ +│ xĢ‚ = (x - μ) / σ │ +│ [-1.341, -0.447, 0.447, 1.341] │ +ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + │ + ↓ +ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” +│ Step 4: Scale and Shift │ +│ Output = γ āŠ™ xĢ‚ + β │ +ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + │ + ↓ + Output Vector +``` + +### Distribution Transformation + +**Before Normalization:** +``` +Distribution: + │ + 0.4│ ā— + │ ā— ā— + 0.3│ ā— ā— + │ ā— ā— + 0.2│ ā— ā— + ā”‚ā— ā— + 0.1│ ā— + │ + 0.0ā”œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ + 0 1 2 3 4 5 + Mean: 2.5, Std: 1.118 +``` + +**After Normalization:** +``` +Distribution: + │ + 0.4│ ā— + │ ā— ā— + 0.3│ ā— ā— + │ ā— ā— + 0.2│ ā— ā— + ā”‚ā— ā— + 0.1│ ā— + │ + 0.0ā”œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ + -2 -1 0 1 2 3 + Mean: 0, Std: 1 +``` + +**Standardized!** + +### Gradient Flow Visualization + +**Without Normalization:** +``` +Gradient Magnitude: + │ +1000│ ā— + │ + 100│ + │ + 10│ + │ + 1│ ā— + │ + 0.1│ ā— + │ +0.01│ + └──────────────────────── Layer + 1 2 3 4 5 + (Unstable, varying magnitudes) +``` + +**With Normalization:** +``` +Gradient Magnitude: + │ +1000│ + │ + 100│ + │ + 10│ + │ ā— ā— ā— ā— ā— + 1│ + │ + 0.1│ + │ +0.01│ + └──────────────────────── Layer + 1 2 3 4 5 + (Stable, consistent magnitudes) +``` + +--- + +## 4.8 Key Takeaways: Normalization + +āœ… **Normalization standardizes activations to mean 0, variance 1** +āœ… **Stabilizes training by preventing exploding/vanishing gradients** +āœ… **Enables faster convergence and higher learning rates** +āœ… **Pre-norm architecture is preferred for deep networks** +āœ… **Learnable parameters (γ, β) allow optimal scaling** + +--- + +## Complete Mathematical Formula + +### Layer Normalization Formula + +For input $\mathbf{x} \in \mathbb{R}^d$: + +```math +\mu = \frac{1}{d} \sum_{i=1}^{d} x_i +``` + +```math +\sigma^2 = \frac{1}{d} \sum_{i=1}^{d} (x_i - \mu)^2 +``` + +```math +\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} +``` + +```math +\text{LayerNorm}(\mathbf{x}) = \gamma \odot \hat{\mathbf{x}} + \beta +``` + +**Where:** +- $\epsilon$ = small constant (default: 1e-5) to prevent division by zero +- $\gamma$ = learnable scale parameter (initialized to 1) +- $\beta$ = learnable shift parameter (initialized to 0) +- $\odot$ = element-wise multiplication +- $d$ = number of dimensions + +### In Transformer Block + +**Pre-Norm Architecture:** + +```math +\mathbf{x}_{norm} = \text{LayerNorm}(\mathbf{x}_{in}) +``` + +```math +\mathbf{x}_{attn} = \text{Attention}(\mathbf{x}_{norm}) +``` + +```math +\mathbf{x}_{out} = \mathbf{x}_{in} + \mathbf{x}_{attn} \quad \text{(residual connection)} +``` + +**Normalization happens before attention and feed-forward!** + +--- + +*This document provides a step-by-step explanation of normalization, the critical component that stabilizes training and enables efficient learning in transformer models.* + diff --git a/docs/OPTIMIZATIONS.md b/docs/OPTIMIZATIONS.md new file mode 100644 index 0000000..62a9357 --- /dev/null +++ b/docs/OPTIMIZATIONS.md @@ -0,0 +1,223 @@ +# Optimizations for Production RAG Systems + +This document describes the optimizations implemented based on the paper "Optimizing LLM Inference and Retrieval: Novel Data Structures and Algorithms for Production RAG Systems". + +## Implemented Optimizations + +### 1. KV Cache for Efficient Autoregressive Generation + +**Location**: `models/optimized_attention.py` + +The KV (Key-Value) cache mechanism stores computed keys and values from previous tokens during autoregressive generation, eliminating redundant computations. + +**Benefits**: +- Reduces computational cost from O(n²) to O(n) for each new token +- Significantly faster generation for long sequences +- Lower memory bandwidth usage + +**Usage**: +```python +from models import TransformerModel, OptimizedInference + +model = TransformerModel(...) +optimizer = model.get_optimized_inference() + +# Generate with KV caching +generated = optimizer.generate_with_cache( + input_ids=input_ids, + max_length=100, + temperature=0.8, +) +``` + +### 2. Optimized Attention Computation + +**Location**: `models/optimized_attention.py` + +Implements optimized attention computation using PyTorch's `scaled_dot_product_attention` when available (similar to Flash Attention). + +**Features**: +- Uses PyTorch's optimized attention implementation +- Supports causal masking efficiently +- Reduces memory usage during attention computation + +**Usage**: +```python +from models import TransformerModel +from models.blocks import TransformerBlock + +# Use optimized attention in transformer blocks +block = TransformerBlock( + d_model=512, + num_heads=8, + use_optimized_attention=True, # Enable optimized attention +) +``` + +### 3. Retrieval Cache for Similar Queries + +**Location**: `models/optimized_attention.py` + +Implements approximate caching for retrieval results, reducing expensive vector database lookups by caching similar queries. + +**Features**: +- Cosine similarity-based cache lookup +- Configurable similarity threshold +- Automatic cache eviction when full + +**Usage**: +```python +from models.optimized_attention import RetrievalCache + +cache = RetrievalCache(max_size=1000, similarity_threshold=0.9) + +# Store retrieval results +cache.set(query_hash, query_embedding, retrieved_docs) + +# Retrieve cached results +results = cache.get(query_hash, query_embedding) +``` + +### 4. Prefetching Mechanisms + +**Location**: `models/prefetching.py` + +#### 4.1 PrefetchDataLoader +Prefetches batches in background threads, reducing GPU idle time. + +**Usage**: +```python +from models.prefetching import PrefetchDataLoader +from data import create_dataloader + +dataloader = create_dataloader(...) +prefetch_loader = PrefetchDataLoader( + dataloader=dataloader, + prefetch_factor=2, + device=device, +) +``` + +#### 4.2 LookaheadRetriever +Prefetches retrieval results for anticipated queries. + +**Usage**: +```python +from models.prefetching import LookaheadRetriever + +def retrieve(query: str): + # Your retrieval function + return documents + +retriever = LookaheadRetriever( + retrieval_fn=retrieve, + lookahead_window=3, +) + +# Start prefetching +retriever.start_prefetching(query_stream) + +# Get results (checks cache first) +results = retriever.get(query) +``` + +#### 4.3 BatchPrefetcher +Groups queries into batches for efficient batch retrieval. + +**Usage**: +```python +from models.prefetching import BatchPrefetcher + +def batch_retrieve(queries: List[str]): + # Batch retrieval function + return [documents for each query] + +prefetcher = BatchPrefetcher( + batch_retrieval_fn=batch_retrieve, + batch_size=8, +) + +prefetcher.start_prefetching(query_stream) +results = prefetcher.get(query) +``` + +### 5. Optimized Batch Inference + +**Location**: `models/optimized_attention.py` + +The `OptimizedInference` class provides batch generation utilities for processing multiple prompts efficiently. + +**Features**: +- Batch processing for multiple prompts +- Automatic padding and batching +- Efficient memory usage + +**Usage**: +```python +from models import OptimizedInference + +optimizer = model.get_optimized_inference() + +# Generate for multiple prompts in batches +results = optimizer.batch_generate( + input_ids_list=[prompt1_ids, prompt2_ids, ...], + max_length=100, + batch_size=8, +) +``` + +## Performance Improvements + +These optimizations provide the following benefits: + +1. **Faster Inference**: KV caching reduces generation time by 2-5x for long sequences +2. **Reduced Latency**: Prefetching reduces end-to-end latency by overlapping computation and I/O +3. **Lower Costs**: Retrieval caching reduces expensive vector database calls +4. **Better Throughput**: Batch processing increases throughput for multiple requests + +## Integration + +### Using Optimized Inference in Production + +1. **Enable optimized attention** (for inference only): +```python +model = TransformerModel( + ..., + use_optimized_attention=True, # Set in TransformerBlock +) +``` + +2. **Use optimized inference utility**: +```python +optimizer = model.get_optimized_inference() +generated = optimizer.generate_with_cache(...) +``` + +3. **Enable prefetching**: +```python +prefetch_loader = PrefetchDataLoader(dataloader, prefetch_factor=2) +``` + +### CLI Usage + +Use the `--optimized` flag when running inference: + +```bash +python inference.py \ + --checkpoint checkpoints/best_checkpoint.pt \ + --prompt "Your prompt here" \ + --optimized \ + --max-length 100 +``` + +## Example Script + +See `example_optimized.py` for complete examples of all optimizations. + +## References + +Based on optimizations from: +- "Optimizing LLM Inference and Retrieval: Novel Data Structures and Algorithms for Production RAG Systems" +- TeleRAG: Lookahead Retrieval Mechanism +- Flash Attention optimization techniques + diff --git a/docs/OPTIMIZATION_EXPLAINED.md b/docs/OPTIMIZATION_EXPLAINED.md new file mode 100644 index 0000000..91019b3 --- /dev/null +++ b/docs/OPTIMIZATION_EXPLAINED.md @@ -0,0 +1,848 @@ +# What is Optimization? Step-by-Step Explanation + +Complete step-by-step explanation of optimization in neural networks: how optimizers update weights to minimize loss. + +## Table of Contents + +1. [What is Optimization?](#71-what-is-optimization) +2. [The Optimization Problem](#72-the-optimization-problem) +3. [Gradient Descent](#73-gradient-descent) +4. [AdamW Optimizer](#74-adamw-optimizer) +5. [Why Optimization Matters](#75-why-optimization-matters) +6. [Complete Mathematical Formulation](#76-complete-mathematical-formulation) +7. [Exercise: Optimizer Step-by-Step](#77-exercise-optimizer-step-by-step) +8. [Key Takeaways](#78-key-takeaways) + +--- + +## 7.1 What is Optimization? + +### Simple Definition + +**Optimization** is the process of finding the best set of weights (parameters) that minimize the loss function and make the model's predictions as accurate as possible. + +### Visual Analogy + +**Think of optimization like finding the lowest point in a valley:** + +``` +Loss Landscape: + + High Loss + │ + │ ā— (current position) + │ ╱│╲ + │ ╱ │ ╲ + │ ╱ │ ╲ + │╱ │ ╲ + │ ā–¼ │ + │ (goal)│ + │ │ + Low Loss ā”€ā”€ā”€ā”€ā”€ā”˜ +``` + +**Goal:** Find the bottom of the valley (minimum loss) + +**Optimizer:** Your guide down the mountain + +### What Optimization Does + +**Optimization:** +1. **Measures** how wrong the model is (loss) +2. **Calculates** direction to improve (gradients) +3. **Updates** weights to reduce loss +4. **Repeats** until convergence + +**Result:** Model learns to make accurate predictions! + +### Optimization Process Flow + +```mermaid +graph TB + Start[Training Start] --> Init[Initialize Weights
Random Values] + Init --> Loop[Training Loop] + + Loop --> Forward[Forward Pass
Model Prediction] + Forward --> Loss["Compute Loss
L = loss(pred, target)"] + Loss --> Check{Converged?} + + Check -->|Yes| End[Training Complete] + Check -->|No| Gradient["Compute Gradients
āˆ‡L = āˆ‚L/āˆ‚Īø"] + + Gradient --> Optimize[Optimizer
Update Weights] + Optimize --> Update["New Weights
Īø = Īø - update"] + Update --> Loop + + style Start fill:#e1f5ff + style End fill:#e1ffe1 + style Optimize fill:#fff4e1 + style Check fill:#ffe1f5 +``` + +--- + +## 7.2 The Optimization Problem + +### The Objective + +**We want to minimize:** + +```math +L(\theta) = \frac{1}{N} \sum_{i=1}^{N} \ell(y_i, f(x_i; \theta)) +``` + +**Where:** +- $\theta$ = all model parameters (weights) +- $L$ = total loss +- $\ell$ = loss function (e.g., cross-entropy) +- $y_i$ = correct answer +- $f(x_i; \theta)$ = model prediction +- $N$ = number of examples + +### The Challenge + +**Problem:** Loss function is complex and high-dimensional + +**Solution:** Use iterative optimization algorithms + +**Process:** +``` +Initialize weights randomly +Repeat: + 1. Compute loss + 2. Compute gradients + 3. Update weights +Until convergence +``` + +### Optimization Problem Flowchart + +```mermaid +graph LR + subgraph "Optimization Problem" + A["Loss Function
L(Īø)"] --> B["Find Minimum
min L(Īø)"] + B --> C["Optimal Weights
Īø*"] + end + + subgraph "Solution Approach" + D["Initialize Īø"] --> E[Iterative Updates] + E --> F[Compute Loss] + F --> G[Compute Gradient] + G --> H[Update Weights] + H --> I{Converged?} + I -->|No| E + I -->|Yes| C + end + + A -.-> F + C -.-> C + + style A fill:#ffcccc + style B fill:#ffffcc + style C fill:#ccffcc + style E fill:#cce5ff +``` + +--- + +## 7.3 Gradient Descent + +### What is Gradient Descent? + +**Gradient Descent** is a basic optimization algorithm that updates weights by moving in the direction of steepest descent. + +### How It Works + +**Step 1: Compute Gradient** + +```math +\nabla_\theta L = \frac{\partial L}{\partial \theta} +``` + +**Gradient tells us:** +- Direction: Which way to go +- Magnitude: How steep the slope + +**Step 2: Update Weights** + +```math +\theta_{t+1} = \theta_t - \eta \nabla_\theta L +``` + +**Where:** +- $\theta_t$ = current weights +- $\eta$ = learning rate (step size) +- $\nabla_\theta L$ = gradient + +**Meaning:** Move weights in direction opposite to gradient + +### Visual Example + +``` +Loss Landscape (2D): + + Gradient + Direction + ↓ + ā— ──────┼───── → Lower Loss + │ + │ +``` + +**Move in direction of negative gradient!** + +### Gradient Descent Flowchart + +```mermaid +graph TB + subgraph "Gradient Descent Algorithm" + Start["Start: Initialize Īøā‚€"] --> Loop["For each iteration t"] + + Loop --> Forward[Forward Pass
Compute Predictions] + Forward --> Loss["Compute Loss
L(Īøā‚œ)"] + Loss --> Grad["Compute Gradient
g = āˆ‡L(Īøā‚œ)"] + + Grad --> Direction["Determine Direction
-g points to minimum"] + Direction --> Step["Take Step
Ī· Ɨ g"] + Step --> Update["Update Weights
Īøā‚œā‚Šā‚ = Īøā‚œ - Ī·g"] + + Update --> Check{"Converged?
|g| < ε"} + Check -->|No| Loop + Check -->|Yes| End["Found Minimum
Īø*"] + end + + subgraph "Gradient Information" + GradInfo["Gradient g contains:
- Direction: Which way to go
- Magnitude: How steep"] + end + + Grad -.-> GradInfo + + style Start fill:#e1f5ff + style End fill:#e1ffe1 + style Grad fill:#fff4e1 + style Check fill:#ffe1f5 + style Update fill:#ccffcc +``` + +### Types of Gradient Descent + +**1. Batch Gradient Descent:** +- Uses all training examples +- Most accurate gradients +- Slow for large datasets + +**2. Stochastic Gradient Descent (SGD):** +- Uses one example at a time +- Fast but noisy +- Can bounce around + +**3. Mini-Batch Gradient Descent:** +- Uses small batch of examples +- Balance of speed and accuracy +- Most commonly used + +### Gradient Descent Types Comparison + +```mermaid +graph TB + subgraph "Batch Gradient Descent" + B1[All Training Data] --> B2[Compute Gradient
on Full Dataset] + B2 --> B3[Single Update
Most Accurate] + B3 --> B4["Slow: O(N)"] + end + + subgraph "Stochastic Gradient Descent" + S1[Single Example] --> S2[Compute Gradient
on One Sample] + S2 --> S3[Many Updates
Fast but Noisy] + S3 --> S4["Fast: O(1)"] + end + + subgraph "Mini-Batch Gradient Descent" + M1[Small Batch
32-256 samples] --> M2[Compute Gradient
on Batch] + M2 --> M3[Balanced Updates
Good Accuracy] + M3 --> M4["Fast: O(batch_size)"] + end + + style B3 fill:#ccffcc + style S3 fill:#ffcccc + style M3 fill:#fff4e1 +``` + +--- + +## 7.4 AdamW Optimizer + +### What is AdamW? + +**AdamW** (Adam with Weight Decay) is an advanced optimizer that combines: +- **Adaptive learning rates** (like Adam) +- **Weight decay** (regularization) + +**Why AdamW?** +- Per-parameter learning rates +- Handles sparse gradients well +- Works great for transformers + +### How AdamW Works + +**Step 1: Compute Gradient** + +```math +g_t = \nabla_\theta L(\theta_t) +``` + +**Step 2: Update Momentum** + +```math +m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t +``` + +**Where:** +- $\beta_1 = 0.9$ (momentum decay) +- $m_t$ = first moment estimate + +**Meaning:** Moving average of gradients + +**Step 3: Update Variance** + +```math +v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 +``` + +**Where:** +- $\beta_2 = 0.999$ (variance decay) +- $v_t$ = second moment estimate + +**Meaning:** Moving average of squared gradients + +**Step 4: Bias Correction** + +```math +\hat{m}_t = \frac{m_t}{1 - \beta_1^t} +``` + +```math +\hat{v}_t = \frac{v_t}{1 - \beta_2^t} +``` + +**Why?** Corrects bias in early iterations + +**Step 5: Update Weights** + +```math +\theta_{t+1} = \theta_t - \eta \left( \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda \theta_t \right) +``` + +**Where:** +- $\eta$ = learning rate +- $\epsilon = 10^{-8}$ (small constant) +- $\lambda$ = weight decay coefficient + +**Key Points:** +- $\frac{\hat{m}_t}{\sqrt{\hat{v}_t}}$ = adaptive learning rate per parameter +- $\lambda \theta_t$ = weight decay (regularization) + +### AdamW Optimizer Flowchart + +```mermaid +graph TB + subgraph "AdamW Optimization Process" + Start["Start: Initialize
Īøā‚€, mā‚€=0, vā‚€=0"] --> Loop["For each iteration t"] + + Loop --> Forward["Forward Pass
Compute Loss L(Īøā‚œ)"] + Forward --> Grad["Step 1: Compute Gradient
gā‚œ = āˆ‡L(Īøā‚œ)"] + + Grad --> Mom["Step 2: Update Momentum
mā‚œ = β₁mā‚œā‚‹ā‚ + (1-β₁)gā‚œ"] + Mom --> Var["Step 3: Update Variance
vā‚œ = β₂vā‚œā‚‹ā‚ + (1-β₂)gā‚œĀ²"] + + Var --> Bias["Step 4: Bias Correction
mĢ‚ā‚œ = mā‚œ/(1-β₁ᵗ)
vĢ‚ā‚œ = vā‚œ/(1-β₂ᵗ)"] + + Bias --> Adapt["Step 5: Adaptive LR
LR = Ī·/(√vĢ‚ā‚œ + ε)"] + + Adapt --> Decay["Step 6: Weight Decay
Ī»Īøā‚œ"] + + Decay --> Update["Step 7: Update Weights
Īøā‚œā‚Šā‚ = Īøā‚œ - LRƗmĢ‚ā‚œ - Ī»Īøā‚œ"] + + Update --> Check{Converged?} + Check -->|No| Loop + Check -->|Yes| End["Optimal Weights Īø*"] + end + + subgraph "Key Components" + C1["Momentum mā‚œ
Moving avg of gradients"] + C2["Variance vā‚œ
Moving avg of g²"] + C3["Adaptive LR
Per-parameter learning rate"] + C4["Weight Decay
Regularization"] + end + + Mom -.-> C1 + Var -.-> C2 + Adapt -.-> C3 + Decay -.-> C4 + + style Start fill:#e1f5ff + style End fill:#e1ffe1 + style Grad fill:#fff4e1 + style Adapt fill:#ccffcc + style Update fill:#ccffcc + style Check fill:#ffe1f5 +``` + +### AdamW Detailed Subgraph + +```mermaid +graph LR + subgraph "Input" + I1["Gradient gā‚œ"] + I2["Previous Momentum mā‚œā‚‹ā‚"] + I3["Previous Variance vā‚œā‚‹ā‚"] + I4["Current Weights Īøā‚œ"] + end + + subgraph "Momentum Update" + M1["Multiply: β₁mā‚œā‚‹ā‚"] --> M2["Combine: β₁mā‚œā‚‹ā‚ + (1-β₁)gā‚œ"] + I2 --> M1 + I1 --> M2 + end + + subgraph "Variance Update" + V1["Square: gā‚œĀ²"] --> V2["Combine: β₂vā‚œā‚‹ā‚ + (1-β₂)gā‚œĀ²"] + I3 --> V2 + I1 --> V1 + end + + subgraph "Bias Correction" + M2 --> BC1["mĢ‚ā‚œ = mā‚œ/(1-β₁ᵗ)"] + V2 --> BC2["vĢ‚ā‚œ = vā‚œ/(1-β₂ᵗ)"] + end + + subgraph "Adaptive Learning Rate" + BC2 --> ALR["LR = Ī·/(√vĢ‚ā‚œ + ε)"] + end + + subgraph "Weight Update" + BC1 --> WU1["Adaptive Step: LR Ɨ mĢ‚ā‚œ"] + ALR --> WU1 + I4 --> WU2["Decay Step: Ī»Īøā‚œ"] + WU1 --> WU3["Update: Īøā‚œā‚Šā‚ = Īøā‚œ - LRƗmĢ‚ā‚œ - Ī»Īøā‚œ"] + WU2 --> WU3 + end + + style M2 fill:#e1f5ff + style V2 fill:#e1f5ff + style BC1 fill:#fff4e1 + style BC2 fill:#fff4e1 + style ALR fill:#ccffcc + style WU3 fill:#ccffcc +``` + +### Why AdamW is Better + +**Compared to SGD:** + +**SGD:** +``` +Same learning rate for all parameters +→ Slow convergence +→ Manual tuning needed +``` + +**AdamW:** +``` +Adaptive learning rate per parameter +→ Faster convergence +→ Less manual tuning +``` + +**Benefits:** +1. **Adaptive:** Each parameter gets its own learning rate +2. **Robust:** Works well with noisy gradients +3. **Efficient:** Converges faster than SGD +4. **Regularized:** Weight decay prevents overfitting + +### SGD vs AdamW Comparison + +```mermaid +graph TB + subgraph "Stochastic Gradient Descent" + SGD1["Gradient gā‚œ"] --> SGD2["Fixed Learning Rate Ī·"] + SGD2 --> SGD3["Update: Īøā‚œā‚Šā‚ = Īøā‚œ - Ī·gā‚œ"] + SGD3 --> SGD4["All params same LR"] + SGD4 --> SGD5["Slow Convergence
Manual Tuning"] + end + + subgraph "AdamW Optimizer" + AD1["Gradient gā‚œ"] --> AD2["Momentum mā‚œ"] + AD1 --> AD3["Variance vā‚œ"] + AD2 --> AD4[Bias Correction] + AD3 --> AD4 + AD4 --> AD5["Adaptive LR per param"] + AD5 --> AD6["Update: Īøā‚œā‚Šā‚ = Īøā‚œ - LRƗmĢ‚ā‚œ - Ī»Īøā‚œ"] + AD6 --> AD7["Fast Convergence
Less Tuning"] + end + + subgraph "Comparison" + Comp1["Same Model
Same Data"] + Comp1 --> Comp2["SGD: Loss = 2.5
After 100 epochs"] + Comp1 --> Comp3["AdamW: Loss = 1.8
After 100 epochs"] + Comp3 --> Comp4[AdamW is Better!] + end + + SGD5 -.-> Comp2 + AD7 -.-> Comp3 + + style SGD5 fill:#ffcccc + style AD7 fill:#ccffcc + style Comp4 fill:#e1ffe1 +``` + +--- + +## 7.5 Why Optimization Matters + +### Reason 1: Without Optimization + +**Random weights:** +``` +Weights: Random values +Loss: Very high +Predictions: Random +Model: Useless +``` + +### Reason 2: With Optimization + +**Learned weights:** +``` +Weights: Optimized values +Loss: Low +Predictions: Accurate +Model: Useful +``` + +### Reason 3: Determines Learning Speed + +**Good optimizer:** +- Fast convergence +- Stable training +- Good final performance + +**Poor optimizer:** +- Slow convergence +- Unstable training +- Poor final performance + +### Reason 4: Affects Final Performance + +**Same model, different optimizers:** + +``` +SGD: Loss = 2.5 (after 100 epochs) +AdamW: Loss = 1.8 (after 100 epochs) +``` + +**Better optimizer = Better model!** + +### Optimization Impact Visualization + +```mermaid +graph LR + subgraph "Without Optimization" + WO1[Random Weights] --> WO2["High Loss
L ā‰ˆ 8-10"] + WO2 --> WO3[Random Predictions] + WO3 --> WO4[Model Useless] + end + + subgraph "With Optimization" + W1[Random Weights] --> W2[Optimization Loop] + W2 --> W3[Update Weights] + W3 --> W4["Low Loss
L ā‰ˆ 1-2"] + W4 --> W5[Accurate Predictions] + W5 --> W6[Model Useful] + end + + subgraph "Optimizer Quality" + O1["Poor Optimizer
SGD, Bad LR"] --> O2["Slow Convergence
Loss = 2.5"] + O3["Good Optimizer
AdamW, Proper LR"] --> O4["Fast Convergence
Loss = 1.8"] + end + + W2 -.-> O1 + W2 -.-> O3 + + style WO4 fill:#ffcccc + style W6 fill:#ccffcc + style O2 fill:#ffcccc + style O4 fill:#ccffcc +``` + +--- + +## 7.6 Complete Mathematical Formulation + +### Optimization Problem + +```math +\theta^* = \arg\min_{\theta} L(\theta) +``` + +**Where $\theta^*$ is the optimal set of weights** + +### Gradient Descent Update + +```math +\theta_{t+1} = \theta_t - \eta \nabla_\theta L(\theta_t) +``` + +### AdamW Update (Complete) + +**For each parameter $\theta_i$:** + +**Gradient:** +```math +g_{t,i} = \frac{\partial L}{\partial \theta_{t,i}} +``` + +**Momentum:** +```math +m_{t,i} = \beta_1 m_{t-1,i} + (1 - \beta_1) g_{t,i} +``` + +**Variance:** +```math +v_{t,i} = \beta_2 v_{t-1,i} + (1 - \beta_2) g_{t,i}^2 +``` + +**Bias Correction:** +```math +\hat{m}_{t,i} = \frac{m_{t,i}}{1 - \beta_1^t} +``` + +```math +\hat{v}_{t,i} = \frac{v_{t,i}}{1 - \beta_2^t} +``` + +**Update:** +```math +\theta_{t+1,i} = \theta_{t,i} - \eta \left( \frac{\hat{m}_{t,i}}{\sqrt{\hat{v}_{t,i}} + \epsilon} + \lambda \theta_{t,i} \right) +``` + +**Where:** +- $\beta_1 = 0.9$ +- $\beta_2 = 0.999$ +- $\epsilon = 10^{-8}$ +- $\lambda$ = weight decay (e.g., 0.01) + +### Complete Mathematical Flow + +```mermaid +graph TB + subgraph "Optimization Problem" + OP1["Loss Function L(Īø)"] --> OP2["Find: Īø* = argmin L(Īø)"] + end + + subgraph "Gradient Computation" + GC1[Forward Pass] --> GC2[Compute Loss L] + GC2 --> GC3[Backpropagation] + GC3 --> GC4["Gradient gįµ¢ = āˆ‚L/āˆ‚Īøįµ¢"] + end + + subgraph "AdamW Update Steps" + GC4 --> AU1["Momentum: mįµ¢ = β₁m + (1-β₁)gįµ¢"] + AU1 --> AU2["Variance: vįµ¢ = β₂v + (1-β₂)gᵢ²"] + AU2 --> AU3["Bias Correction:
mĢ‚ = m/(1-β₁ᵗ), vĢ‚ = v/(1-β₂ᵗ)"] + AU3 --> AU4["Adaptive LR: Ī·/(√vĢ‚ + ε)"] + AU4 --> AU5["Update: Īøįµ¢ = Īøįµ¢ - LRƗmĢ‚ - λθᵢ"] + end + + subgraph "Convergence Check" + AU5 --> CC1{Converged?} + CC1 -->|No| GC1 + CC1 -->|Yes| CC2["Optimal Weights Īø*"] + end + + OP2 -.-> GC1 + CC2 -.-> OP2 + + style OP2 fill:#ffffcc + style GC4 fill:#fff4e1 + style AU5 fill:#ccffcc + style CC2 fill:#e1ffe1 +``` + +--- + +## 7.7 Exercise: Optimizer Step-by-Step + +### Problem + +**Given:** +- Current weight: $\theta_0 = 2.0$ +- Loss function: $L(\theta) = (\theta - 1)^2$ +- Learning rate: $\eta = 0.1$ +- Use AdamW with $\beta_1 = 0.9$, $\beta_2 = 0.999$, $\lambda = 0.01$ +- Initial moments: $m_0 = 0$, $v_0 = 0$ + +**Calculate the weight update for step 1.** + +### Step-by-Step Solution + +#### Step 1: Compute Gradient + +**Loss function:** +```math +L(\theta) = (\theta - 1)^2 +``` + +**Gradient:** +```math +g_1 = \frac{\partial L}{\partial \theta} = 2(\theta - 1) +``` + +**At $\theta_0 = 2.0$:** +```math +g_1 = 2(2.0 - 1) = 2(1.0) = 2.0 +``` + +#### Step 2: Update Momentum + +```math +m_1 = \beta_1 m_0 + (1 - \beta_1) g_1 +``` + +```math +m_1 = 0.9 \times 0 + (1 - 0.9) \times 2.0 = 0 + 0.1 \times 2.0 = 0.2 +``` + +#### Step 3: Update Variance + +```math +v_1 = \beta_2 v_0 + (1 - \beta_2) g_1^2 +``` + +```math +v_1 = 0.999 \times 0 + (1 - 0.999) \times (2.0)^2 = 0 + 0.001 \times 4.0 = 0.004 +``` + +#### Step 4: Bias Correction + +```math +\hat{m}_1 = \frac{m_1}{1 - \beta_1^1} = \frac{0.2}{1 - 0.9} = \frac{0.2}{0.1} = 2.0 +``` + +```math +\hat{v}_1 = \frac{v_1}{1 - \beta_2^1} = \frac{0.004}{1 - 0.999} = \frac{0.004}{0.001} = 4.0 +``` + +#### Step 5: Compute Update + +```math +\Delta \theta_1 = \eta \left( \frac{\hat{m}_1}{\sqrt{\hat{v}_1} + \epsilon} + \lambda \theta_0 \right) +``` + +```math +\Delta \theta_1 = 0.1 \left( \frac{2.0}{\sqrt{4.0} + 10^{-8}} + 0.01 \times 2.0 \right) +``` + +```math +\Delta \theta_1 = 0.1 \left( \frac{2.0}{2.0 + 0.00000001} + 0.02 \right) +``` + +```math +\Delta \theta_1 = 0.1 \left( \frac{2.0}{2.0} + 0.02 \right) = 0.1 (1.0 + 0.02) = 0.1 \times 1.02 = 0.102 +``` + +#### Step 6: Update Weight + +```math +\theta_1 = \theta_0 - \Delta \theta_1 = 2.0 - 0.102 = 1.898 +``` + +### Answer + +**After one step:** +- Old weight: $\theta_0 = 2.0$ +- New weight: $\theta_1 = 1.898$ +- Update: $\Delta \theta_1 = -0.102$ + +**The weight moved closer to the optimal value (1.0)!** + +### Verification + +**Check loss:** +- Old loss: $L(2.0) = (2.0 - 1)^2 = 1.0$ +- New loss: $L(1.898) = (1.898 - 1)^2 = 0.806$ + +**Loss decreased! āœ“** + +### Exercise Solution Flowchart + +```mermaid +graph TB + subgraph "Given Values" + G1["Īøā‚€ = 2.0"] --> Start + G2["mā‚€ = 0"] --> Start + G3["vā‚€ = 0"] --> Start + G4["Ī· = 0.1, β₁ = 0.9
β₂ = 0.999, Ī» = 0.01"] --> Start + end + + Start[Start] --> Step1["Step 1: Compute Gradient
L(θ) = (θ-1)²
g₁ = 2(Īøā‚€-1) = 2.0"] + + Step1 --> Step2["Step 2: Update Momentum
m₁ = 0.9Ɨ0 + 0.1Ɨ2.0
m₁ = 0.2"] + + Step2 --> Step3["Step 3: Update Variance
v₁ = 0.999Ɨ0 + 0.001Ɨ4.0
v₁ = 0.004"] + + Step3 --> Step4["Step 4: Bias Correction
m̂₁ = 0.2/(1-0.9) = 2.0
v̂₁ = 0.004/(1-0.999) = 4.0"] + + Step4 --> Step5["Step 5: Compute Update
Δθ₁ = 0.1Ɨ(2.0/√4.0 + 0.01Ɨ2.0)
Δθ₁ = 0.102"] + + Step5 --> Step6["Step 6: Update Weight
θ₁ = 2.0 - 0.102
θ₁ = 1.898"] + + Step6 --> Verify["Verification:
L(2.0) = 1.0 → L(1.898) = 0.806
Loss Decreased!"] + + Verify --> End["Result: θ₁ = 1.898
Closer to optimum Īø* = 1.0"] + + style Start fill:#e1f5ff + style Step1 fill:#fff4e1 + style Step2 fill:#fff4e1 + style Step3 fill:#fff4e1 + style Step4 fill:#fff4e1 + style Step5 fill:#fff4e1 + style Step6 fill:#ccffcc + style Verify fill:#e1ffe1 + style End fill:#e1ffe1 +``` + +--- + +## 7.8 Key Takeaways + +### Optimization + +āœ… **Optimization finds best weights to minimize loss** +āœ… **Uses gradients to determine update direction** +āœ… **Iterative process: compute → update → repeat** + +### Gradient Descent + +āœ… **Basic algorithm: move opposite to gradient** +āœ… **Learning rate controls step size** +āœ… **Can be slow for complex problems** + +### AdamW + +āœ… **Advanced optimizer with adaptive learning rates** +āœ… **Each parameter gets its own learning rate** +āœ… **Combines momentum and variance estimates** +āœ… **Includes weight decay for regularization** +āœ… **Works great for transformers** + +### Why Important + +āœ… **Determines how fast model learns** +āœ… **Affects final model performance** +āœ… **Essential for training neural networks** + +--- + +*This document provides a comprehensive explanation of optimization in neural networks, including gradient descent and AdamW optimizer with mathematical formulations and solved exercises.* + diff --git a/docs/PAIN_POINTS_AND_OPPORTUNITIES.md b/docs/PAIN_POINTS_AND_OPPORTUNITIES.md new file mode 100644 index 0000000..5b40467 --- /dev/null +++ b/docs/PAIN_POINTS_AND_OPPORTUNITIES.md @@ -0,0 +1,706 @@ +# LLM Pain Points & Market Opportunities + +A comprehensive analysis of the main challenges in language models and emerging opportunities in the market. + +## Table of Contents + +1. [Main Pain Points](#main-pain-points) +2. [Market Opportunities](#market-opportunities) +3. [Technical Solutions](#technical-solutions) +4. [Market Segments](#market-segments) +5. [Future Trends](#future-trends) + +--- + +## Main Pain Points + +### 1. Training Costs & Resource Requirements + +**The Problem:** +- **Extremely expensive**: Training GPT-3 cost ~$4.6M, GPT-4 likely $100M+ +- **Massive compute requirements**: Requires thousands of GPUs for months +- **High barrier to entry**: Only large corporations can afford training from scratch +- **Lengthy development cycles**: Months to years to train and iterate + +**Impact:** +``` +Small Companies: Cannot compete +Researchers: Limited access to resources +Innovation: Slowed by cost barriers +``` + +**Numbers:** +- GPT-3: 300B tokens, $4.6M training cost +- GPT-4: Estimated $100M+ training cost +- Training time: 3-6 months on thousands of GPUs +- Infrastructure: Data centers with specialized hardware + +### 2. Inference Latency & Speed + +**The Problem:** +- **Slow generation**: High-quality models generate 10-50 tokens/second +- **High latency**: 500ms-5s response time for queries +- **Poor scalability**: Linear scaling with number of users +- **Real-time constraints**: Difficult to achieve interactive speeds + +**Impact:** +``` +User Experience: Frustrating delays +Applications: Limited to batch processing +Real-time Use: Not feasible for many cases +Cost: More compute = slower response +``` + +**Current Performance:** +- Standard inference: 10-50 tokens/sec +- High-end GPUs: 100-200 tokens/sec +- With optimizations: 200-500 tokens/sec +- Target for real-time: 1000+ tokens/sec + +### 3. Memory Consumption + +**The Problem:** +- **Massive memory requirements**: + - GPT-3 175B: ~350GB GPU memory + - GPT-4: Estimated ~700GB+ memory +- **Inefficient memory usage**: Attention matrices scale quadratically +- **Limited device support**: Cannot run on consumer hardware +- **High infrastructure costs**: Requires expensive GPUs + +**Impact:** +``` +Deployment: Expensive server infrastructure +Accessibility: Limited to cloud providers +Edge Devices: Impossible without optimization +Cost: High memory = high server costs +``` + +**Memory Breakdown:** +- Model weights: 50-70% of memory +- KV cache: 20-30% during inference +- Activations: 10-20% during forward pass +- Overhead: 5-10% for framework + +### 4. Energy Consumption & Environmental Impact + +**The Problem:** +- **Extremely high energy usage**: + - GPT-3 training: ~3,287 MWh (~$1.4M electricity) + - Continuous inference: High carbon footprint +- **Environmental concerns**: Equivalent to significant CO2 emissions +- **Sustainability issues**: Unsustainable scaling + +**Impact:** +``` +Environment: Significant carbon footprint +Cost: High electricity bills +Regulation: Increasing environmental regulations +Public Perception: Growing concern about AI's impact +``` + +**Numbers:** +- Training GPT-3: ~552 metric tons CO2 equivalent +- Daily inference: Thousands of MWh per day globally +- Cost: Electricity is major operational expense + +### 5. Data Dependency & Quality + +**The Problem:** +- **Massive data requirements**: Billions of tokens needed +- **Data quality issues**: Garbage in, garbage out +- **Bias in training data**: Models inherit societal biases +- **Copyright concerns**: Training on copyrighted material +- **Data scarcity**: High-quality data is limited + +**Impact:** +``` +Quality: Poor data = poor models +Bias: Perpetuates existing biases +Legal: Copyright and licensing issues +Cost: Data acquisition is expensive +``` + +**Requirements:** +- GPT-3: 300B tokens (~45TB of text) +- Data cleaning: 70-80% of data preparation time +- Quality control: Critical but expensive +- Diversity: Need diverse, representative data + +### 6. Hallucination & Reliability + +**The Problem:** +- **Factual inaccuracies**: Models generate plausible but false information +- **Inconsistent outputs**: Same prompt can give different answers +- **Difficulty verifying**: Hard to distinguish truth from hallucination +- **Confidence estimation**: Models don't know when they're wrong + +**Impact:** +``` +Trust: Users lose confidence +Applications: Cannot use for critical tasks +Verification: Requires human oversight +Legal: Liability concerns +``` + +**Examples:** +- Medical advice: Could be dangerous +- Financial information: Could cause losses +- Legal documents: Could have serious consequences +- Scientific facts: Could mislead researchers + +### 7. Fine-tuning & Customization Complexity + +**The Problem:** +- **Time-consuming**: Days to weeks for fine-tuning +- **Expensive**: Requires significant compute resources +- **Technical expertise**: Requires deep ML knowledge +- **Dataset preparation**: Complex and time-consuming +- **Hyperparameter tuning**: Trial and error process + +**Impact:** +``` +Adoption: High barrier for businesses +Iteration: Slow feedback loops +Cost: Expensive experimentation +Expertise: Limited talent pool +``` + +**Challenges:** +- LoRA vs full fine-tuning: Trade-offs unclear +- Data requirements: How much data is needed? +- Evaluation: How to measure success? +- Deployment: Complex integration process + +### 8. Scalability & Infrastructure + +**The Problem:** +- **Horizontal scaling**: Difficult to distribute inference +- **Load balancing**: Complex for stateful models +- **Cost scaling**: Linear cost increase with users +- **Infrastructure management**: Requires DevOps expertise +- **High availability**: Complex to achieve 99.9%+ uptime + +**Impact:** +``` +Growth: Limits ability to scale +Cost: Infrastructure costs grow with usage +Reliability: Complex to maintain +Engineering: Requires significant resources +``` + +**Issues:** +- State management: KV cache complicates scaling +- Batch processing: Inefficient for single requests +- Geographic distribution: Latency vs consistency +- Cost optimization: Balancing performance and cost + +--- + +## Market Opportunities + +### 1. Efficient Training & Fine-tuning Solutions + +**Opportunity:** +- **Problem**: Training is too expensive and slow +- **Solution**: Efficient training methods, LoRA, quantization +- **Market Size**: $2-5B by 2027 +- **Key Players**: Hugging Face, Cohere, Anthropic + +**Technologies:** +- **LoRA (Low-Rank Adaptation)**: 10-100x cheaper fine-tuning +- **Quantization**: 4x-8x memory reduction +- **Gradient checkpointing**: 2x memory savings +- **Distributed training**: Optimize multi-GPU setups + +**Market Segments:** +- Enterprise fine-tuning platforms +- Training optimization tools +- Pre-trained model marketplaces +- Model compression services + +**Revenue Models:** +- SaaS platforms for fine-tuning +- Consulting services +- Model licensing +- Training infrastructure + +### 2. Inference Optimization & Acceleration + +**Opportunity:** +- **Problem**: Inference is too slow and expensive +- **Solution**: KV caching, quantization, model pruning +- **Market Size**: $5-10B by 2027 +- **Key Players**: NVIDIA, TensorRT, vLLM + +**Technologies:** +- **KV Caching**: 2-5x speedup +- **Quantization**: 4x faster inference +- **Model pruning**: 2-4x speedup +- **Specialized hardware**: TPUs, specialized chips + +**Market Segments:** +- Real-time applications +- Edge deployment +- High-throughput services +- Cost-sensitive applications + +**Competitive Advantages:** +- Ease of integration +- Performance improvements +- Cost reduction +- Developer experience + +### 3. Edge & Mobile Deployment + +**Opportunity:** +- **Problem**: Models too large for edge devices +- **Solution**: Model compression, quantization, distillation +- **Market Size**: $3-8B by 2027 +- **Key Players**: Qualcomm, Apple, Google + +**Technologies:** +- **Model distillation**: Smaller, faster models +- **Quantization**: INT8/INT4 inference +- **Pruning**: Remove unnecessary weights +- **On-device ML**: Specialized hardware + +**Market Segments:** +- Smartphones +- IoT devices +- Autonomous vehicles +- AR/VR devices + +**Applications:** +- Voice assistants +- Camera processing +- Real-time translation +- Personalization + +### 4. Domain-Specific Solutions + +**Opportunity:** +- **Problem**: General models underperform in specific domains +- **Solution**: Specialized models for industries +- **Market Size**: $10-20B by 2027 +- **Key Players**: Industry-specific startups + +**Industries:** +- **Healthcare**: Medical diagnosis, drug discovery +- **Finance**: Fraud detection, trading algorithms +- **Legal**: Contract analysis, legal research +- **Education**: Personalized tutoring, content generation +- **Customer Service**: Support automation, chatbots + +**Value Propositions:** +- Higher accuracy in domain +- Regulatory compliance +- Custom integrations +- Expert knowledge built-in + +**Revenue Models:** +- SaaS subscriptions +- Per-query pricing +- Enterprise licenses +- White-label solutions + +### 5. Model Evaluation & Safety Tools + +**Opportunity:** +- **Problem**: Hard to evaluate model quality and safety +- **Solution**: Comprehensive evaluation frameworks +- **Market Size**: $500M-2B by 2027 +- **Key Players**: OpenAI, Anthropic, startup ecosystem + +**Tools Needed:** +- **Evaluation frameworks**: Benchmark suites +- **Bias detection**: Identify and measure bias +- **Safety testing**: Jailbreak detection, adversarial testing +- **Explainability**: Understanding model decisions + +**Market Segments:** +- Enterprise model validation +- Regulatory compliance +- Research institutions +- Government agencies + +**Applications:** +- Pre-deployment testing +- Continuous monitoring +- Regulatory reporting +- Risk assessment + +### 6. Data & Training Infrastructure + +**Opportunity:** +- **Problem**: Data preparation is expensive and time-consuming +- **Solution**: Automated data pipelines and quality tools +- **Market Size**: $2-5B by 2027 +- **Key Players**: Scale AI, Labelbox, Label Studio + +**Solutions:** +- **Data labeling**: Automated and human-in-the-loop +- **Data quality**: Cleaning and validation tools +- **Data pipelines**: ETL for ML workflows +- **Synthetic data**: Generate training data + +**Market Segments:** +- Data labeling services +- Quality assurance tools +- Data pipeline platforms +- Synthetic data generation + +**Value:** +- Faster data preparation +- Higher quality training data +- Reduced costs +- Better model performance + +### 7. Cost Optimization & Infrastructure + +**Opportunity:** +- **Problem**: Infrastructure costs are prohibitive +- **Solution**: Optimized cloud services, cost management +- **Market Size**: $5-15B by 2027 +- **Key Players**: AWS, Google Cloud, Azure, specialized providers + +**Solutions:** +- **GPU optimization**: Better utilization +- **Model serving**: Efficient inference infrastructure +- **Cost monitoring**: Track and optimize spending +- **Multi-cloud**: Avoid vendor lock-in + +**Market Segments:** +- Cloud providers +- Infrastructure optimization +- Cost management tools +- Managed ML services + +**Value:** +- Reduced infrastructure costs +- Better performance +- Easier scaling +- Cost transparency + +### 8. Open Source & Community Models + +**Opportunity:** +- **Problem**: Proprietary models lock users in +- **Solution**: Open source alternatives +- **Market Size**: Growing rapidly +- **Key Players**: Hugging Face, Stability AI, Meta + +**Trends:** +- **Open source models**: Llama, Mistral, Falcon +- **Model sharing**: Hugging Face Hub +- **Community contributions**: Faster innovation +- **Transparency**: Open weights and training data + +**Market Impact:** +- Lower barriers to entry +- Faster innovation +- More competition +- Better accessibility + +**Business Models:** +- Open source with premium features +- Hosting and infrastructure +- Support and consulting +- Enterprise editions + +--- + +## Technical Solutions + +### Current Solutions Addressing Pain Points + +#### 1. Training Optimization + +**LoRA (Low-Rank Adaptation)** +- **Impact**: 10-100x cheaper fine-tuning +- **Use Case**: Customizing models for specific tasks +- **Adoption**: Widespread in research and industry + +**Quantization** +- **Impact**: 4x-8x memory reduction +- **Use Case**: Fitting larger models on smaller GPUs +- **Adoption**: Growing rapidly + +**Gradient Checkpointing** +- **Impact**: 2x memory savings +- **Use Case**: Training larger models +- **Adoption**: Standard practice + +**Distributed Training** +- **Impact**: Faster training, larger models +- **Use Case**: Training billion-parameter models +- **Adoption**: Required for large models + +#### 2. Inference Optimization + +**KV Caching** +- **Impact**: 2-5x speedup +- **Use Case**: Autoregressive generation +- **Adoption**: Standard in production + +**Quantization** +- **Impact**: 4x faster inference +- **Use Case**: Production deployment +- **Adoption**: Common in production + +**Model Pruning** +- **Impact**: 2-4x speedup, smaller models +- **Use Case**: Edge deployment +- **Adoption**: Growing for edge devices + +**Batch Processing** +- **Impact**: Better GPU utilization +- **Use Case**: High-throughput scenarios +- **Adoption**: Standard practice + +#### 3. Memory Optimization + +**Flash Attention** +- **Impact**: 2x memory reduction +- **Use Case**: Long sequences +- **Adoption**: Standard in new models + +**Gradient Checkpointing** +- **Impact**: 2x memory savings +- **Use Case**: Training +- **Adoption**: Common practice + +**Model Sharding** +- **Impact**: Distribute across GPUs +- **Use Case**: Large models +- **Adoption**: Required for large models + +**Quantization** +- **Impact**: 4x-8x memory reduction +- **Use Case**: Inference and training +- **Adoption**: Increasing rapidly + +--- + +## Market Segments + +### 1. Enterprise Software + +**Size**: $10-30B by 2027 +**Characteristics**: +- High willingness to pay +- Enterprise features required +- Compliance and security critical +- Custom integrations needed + +**Key Players**: OpenAI, Anthropic, Google, Microsoft +**Opportunities**: Vertical solutions, integrations, compliance + +### 2. Developer Tools & APIs + +**Size**: $5-15B by 2027 +**Characteristics**: +- Developer-friendly APIs +- Good documentation +- Competitive pricing +- Reliability critical + +**Key Players**: OpenAI, Anthropic, Cohere, Hugging Face +**Opportunities**: Better APIs, developer experience, pricing + +### 3. Consumer Applications + +**Size**: $5-20B by 2027 +**Characteristics**: +- Price-sensitive +- User experience critical +- Scale requirements +- Privacy concerns + +**Key Players**: +- [ChatGPT](https://chat.openai.com) - OpenAI's conversational AI platform +- [Claude](https://claude.ai) - Anthropic's AI assistant +- [Perplexity](https://www.perplexity.ai) - AI-powered search engine +- [Character.AI](https://character.ai) - Conversational AI characters platform + +**Opportunities**: Better UX, lower costs, privacy + +### 4. Research & Academia + +**Size**: $1-3B by 2027 +**Characteristics**: +- Open access preferred +- Reproducibility important +- Educational pricing +- Community support + +**Key Players**: Hugging Face, EleutherAI, Academic institutions +**Opportunities**: Open source, educational tools, grants + +### 5. Infrastructure & Cloud + +**Size**: $10-25B by 2027 +**Characteristics**: +- Scale critical +- Reliability essential +- Cost optimization +- Multi-cloud support + +**Key Players**: AWS, Google Cloud, Azure, specialized providers +**Opportunities**: Better infrastructure, cost optimization + +--- + +## Future Trends + +### 1. Efficiency Improvements + +**Trend**: Continued focus on efficiency +- **Smaller models**: Better performance per parameter +- **Smarter architectures**: More efficient attention mechanisms +- **Hardware optimization**: Specialized chips for LLMs +- **Algorithm improvements**: Better training and inference methods + +**Impact**: Lower costs, better accessibility, faster adoption + +### 2. Edge Deployment + +**Trend**: Moving LLMs to edge devices +- **Model compression**: Smaller, faster models +- **Hardware acceleration**: Specialized mobile chips +- **Hybrid approaches**: Cloud + edge combination +- **Privacy**: On-device processing + +**Impact**: Better privacy, lower latency, new applications + +### 3. Specialized Models + +**Trend**: Domain-specific models +- **Industry focus**: Healthcare, finance, legal, etc. +- **Better performance**: Domain expertise built-in +- **Regulatory compliance**: Built-in compliance features +- **Integration**: Easier integration with existing systems + +**Impact**: Better performance, regulatory compliance, market segmentation + +### 4. Open Source Growth + +**Trend**: Growing open source ecosystem +- **More models**: Better open source alternatives +- **Community innovation**: Faster development +- **Transparency**: Open weights and training data +- **Accessibility**: Lower barriers to entry + +**Impact**: More competition, faster innovation, better accessibility + +### 5. Safety & Alignment + +**Trend**: Focus on safety and alignment +- **Evaluation frameworks**: Better testing tools +- **Safety mechanisms**: Built-in safety features +- **Alignment research**: Better understanding of alignment +- **Regulation**: Increasing regulatory requirements + +**Impact**: Safer models, regulatory compliance, public trust + +### 6. Multimodal Expansion + +**Trend**: Beyond text to images, audio, video +- **Multimodal models**: Text + images + audio +- **New applications**: Creative tools, video generation +- **Unified models**: Single model for multiple modalities +- **Interactions**: Better human-AI interaction + +**Impact**: New applications, larger market, more complexity + +### 7. Personalization + +**Trend**: Highly personalized models +- **Fine-tuning**: Easy personalization +- **User data**: Learning from user interactions +- **Privacy**: Balancing personalization and privacy +- **Customization**: User-controlled customization + +**Impact**: Better user experience, privacy challenges, new applications + +### 8. Cost Reduction + +**Trend**: Continued cost reduction +- **Efficiency**: Better algorithms and hardware +- **Competition**: More providers, lower prices +- **Optimization**: Better resource utilization +- **Accessibility**: Lower costs enable more use cases + +**Impact**: More adoption, new applications, democratization + +--- + +## Summary + +### Key Pain Points + +1. **Training Costs**: Extremely expensive, limiting access +2. **Inference Speed**: Too slow for many applications +3. **Memory Usage**: Too large for most devices +4. **Energy Consumption**: Environmental concerns +5. **Data Dependency**: Need massive, high-quality data +6. **Hallucination**: Reliability and trust issues +7. **Fine-tuning Complexity**: Difficult to customize +8. **Scalability**: Infrastructure challenges + +### Major Opportunities + +1. **Efficient Training**: LoRA, quantization, optimization tools +2. **Inference Optimization**: KV caching, acceleration, compression +3. **Edge Deployment**: Mobile and IoT applications +4. **Domain-Specific Solutions**: Industry verticals +5. **Evaluation Tools**: Safety and quality frameworks +6. **Data Infrastructure**: Automated pipelines and quality tools +7. **Cost Optimization**: Infrastructure and cloud services +8. **Open Source**: Community-driven innovation + +### Market Size + +**Total Addressable Market**: $50-100B+ by 2027 +- Enterprise Software: $10-30B +- Developer Tools: $5-15B +- Consumer Applications: $5-20B +- Infrastructure: $10-25B +- Research & Academia: $1-3B +- Specialized Solutions: $5-10B + +### Competitive Landscape + +**Established Players**: OpenAI, Google, Anthropic, Microsoft +**Rising Stars**: Hugging Face, Cohere, Stability AI +**Infrastructure**: AWS, Google Cloud, Azure, NVIDIA +**Open Source**: Meta, EleutherAI, Community + +### Success Factors + +- **Technical Excellence**: Best performance and efficiency +- **Developer Experience**: Easy to use and integrate +- **Cost Effectiveness**: Competitive pricing +- **Reliability**: Consistent performance +- **Innovation**: Continuous improvement +- **Community**: Strong ecosystem support + +--- + +## Conclusion + +The LLM market presents significant challenges but also enormous opportunities. The main pain points—cost, speed, memory, and reliability—create clear market opportunities for companies that can solve these problems. + +**Key Takeaways:** + +1. **Cost is the primary barrier**: Solutions that reduce training and inference costs will have significant market value +2. **Speed matters**: Real-time applications require optimization +3. **Efficiency is critical**: Better algorithms and hardware unlock new use cases +4. **Specialization wins**: Domain-specific solutions better than general models +5. **Open source is growing**: Community-driven innovation is accelerating +6. **Infrastructure is key**: Better infrastructure enables adoption + +The market is still early, with huge growth potential. Companies focusing on solving real pain points while building sustainable business models will capture significant value in this rapidly growing market. + +--- + +*This document provides a comprehensive overview of the current state of LLMs, their challenges, and the opportunities they present. The market is evolving rapidly, with new solutions and opportunities emerging continuously.* diff --git a/docs/REPOSITORY_DOWNLOAD_GUIDE.md b/docs/REPOSITORY_DOWNLOAD_GUIDE.md new file mode 100644 index 0000000..d5983ea --- /dev/null +++ b/docs/REPOSITORY_DOWNLOAD_GUIDE.md @@ -0,0 +1,656 @@ +# Repository Download Guide + +This guide explains how to automatically download GitHub repositories with open licenses for code training using the repository downloader scripts. + +## Overview + +The repository downloader allows you to automatically find and clone GitHub repositories based on: +- **Categories**: Neovim configs, Lua repos, Bash scripts, Zsh configs, Python repos, ethical hacking tools, security tools, and all open-license repos +- **Languages**: Python, JavaScript, Go, Rust, and 15+ more +- **Licenses**: MIT, Apache, BSD, GPL, and other open source licenses +- **Quality**: Filter by minimum stars (popularity) +- **Size Limits**: Automatic stopping when reaching storage limits (default: 1 TB) + +## Scripts + +There are two scripts available: + +1. **`download_all_repos.py`** - Convenience script to download all common categories at once +2. **`download_repos.py`** - Full-featured script with all options and flexibility + +## Quick Start + +### Download All Categories (Recommended) + +The easiest way to download all repository categories: + +```bash +python3 download_all_repos.py +``` + +This will download: +- šŸ“¦ Neovim configurations and plugins +- šŸ“¦ Lua programming repositories +- šŸ“¦ Bash/shell script repositories +- šŸ“¦ Zsh configuration and plugins +- šŸ“¦ Python programming repositories +- šŸ“¦ Ethical hacking and cybersecurity tools + +**Default settings:** +- Max repos per category: 50 +- Min stars: 100 +- Output directory: `data/repos` +- Size limit: 1 TB (1024 GB) +- Shallow clones (faster, less disk space) + +### Download Specific Categories + +```bash +python3 download_repos.py --categories nvim lua bash zsh python hacking --max-repos 50 +``` + +### Download All Open-License Repos + +Download repositories with any open license (any language): + +```bash +python3 download_repos.py --categories all-open --max-repos 1000 --max-size 1024.0 +``` + +### Download by Language + +```bash +python3 download_repos.py --language python --max-repos 100 +``` + +## Installation + +No additional dependencies required! The script uses: +- Python standard library (`urllib`, `json`, `subprocess`) +- `tqdm` (already in requirements.txt) +- `git` (should be installed on your system) + +## Available Categories + +### Neovim (`nvim`) +Neovim configuration files and plugins written in Lua. + +```bash +python3 download_repos.py --categories nvim --max-repos 100 +``` + +**What it searches for:** +- `neovim OR nvim-config OR neovim-config` +- MIT licensed repositories (default) +- 100+ stars minimum (default) + +### Lua (`lua`) +Lua programming language repositories. + +```bash +python3 download_repos.py --categories lua --max-repos 50 +``` + +**What it searches for:** +- Language: Lua +- MIT licensed repositories (default) +- 100+ stars minimum (default) + +### Bash (`bash`) +Bash and shell script repositories. + +```bash +python3 download_repos.py --categories bash --max-repos 50 +``` + +**What it searches for:** +- Language: Shell +- MIT licensed repositories (default) +- 100+ stars minimum (default) + +### Zsh (`zsh`) +Zsh configuration files and plugins (Oh My Zsh, etc.). + +```bash +python3 download_repos.py --categories zsh --max-repos 50 +``` + +**What it searches for:** +- `zsh-config OR oh-my-zsh OR zsh-plugin` +- MIT licensed repositories (default) +- 100+ stars minimum (default) + +### Python (`python`) +Python programming language repositories. + +```bash +python3 download_repos.py --categories python --max-repos 100 +``` + +**What it searches for:** +- Language: Python +- MIT licensed repositories (default) +- 100+ stars minimum (default) + +### Ethical Hacking (`hacking`) +Ethical hacking and cybersecurity tools. + +```bash +python3 download_repos.py --categories hacking --max-repos 100 +``` + +**What it searches for:** +- `ethical-hacking OR cybersecurity OR penetration-testing OR security-tools OR red-team` +- MIT licensed repositories (default) +- 100+ stars minimum (default) + +### Security (`security`) +General security and cybersecurity repositories. + +```bash +python3 download_repos.py --categories security --max-repos 50 +``` + +**What it searches for:** +- `security-tools OR cybersecurity OR penetration-testing OR red-team OR blue-team` +- MIT licensed repositories (default) +- 100+ stars minimum (default) + +### All Open Licenses (`all-open`) +All repositories with open licenses, any language. This is useful for downloading a diverse set of repositories. + +```bash +python3 download_repos.py --categories all-open --max-repos 1000 --max-size 1024.0 +``` + +**What it searches for:** +- Any open-license repository (no language filter) +- No specific license filter (searches all open licenses) +- 100+ stars minimum (default) + +**Note:** This category searches broadly and may return repositories with various licenses. You can still specify `--license` to filter to a specific license type. + +## Command-Line Options + +### `download_repos.py` Options + +```bash +python3 download_repos.py [OPTIONS] +``` + +**Options:** + +- `--output DIR` - Output directory (default: `data/repos`) +- `--categories CAT1 CAT2 ...` - Categories to download: `nvim`, `lua`, `bash`, `zsh`, `python`, `hacking`, `security`, `all-open` +- `--language LANG` - Single language to filter by +- `--languages LANG1 LANG2 ...` - Multiple languages to download +- `--license LICENSE` - License type (default: `mit`) +- `--min-stars N` - Minimum stars (default: 100) +- `--max-repos N` - Maximum repos per category/language (default: 50) +- `--max-size N` - Maximum total size in GB (stops downloading when reached, e.g., `1024.0` for 1 TB) +- `--full-clone` - Do full clone instead of shallow (slower but includes full history) + +### `download_all_repos.py` Options + +```bash +python3 download_all_repos.py [OPTIONS] +``` + +**Options:** + +- `--max-repos N` - Maximum repos per category (default: 50) +- `--min-stars N` - Minimum stars (default: 100) +- `--output DIR` - Output directory (default: `data/repos`) +- `--max-size N` - Maximum total size in GB (default: 1024.0 = 1 TB) +- `--full-clone` - Do full clone instead of shallow + +**Example:** +```bash +python3 download_all_repos.py --max-repos 100 --min-stars 200 --max-size 2048.0 +``` + +### Available Licenses + +- `mit` (default) +- `apache-2.0` +- `bsd-3-clause` +- `bsd-2-clause` +- `isc` +- `unlicense` +- `mpl-2.0` +- `lgpl-2.1` +- `lgpl-3.0` +- `gpl-2.0` +- `gpl-3.0` + +### Available Languages + +- `python` +- `javascript` +- `typescript` +- `java` +- `cpp` +- `c` +- `go` +- `rust` +- `ruby` +- `php` +- `swift` +- `kotlin` +- `scala` +- `r` +- `sql` +- `lua` +- `shell` (for bash/shell scripts) + +## Examples + +### Example 1: Download All Categories (Simple) + +```bash +python3 download_all_repos.py +``` + +Downloads all categories (nvim, lua, bash, zsh, python, hacking) with default settings and 1 TB size limit. + +### Example 2: Download All Categories with Custom Settings + +```bash +python3 download_all_repos.py --max-repos 100 --min-stars 200 --max-size 2048.0 +``` + +Downloads all categories with: +- 100 repos per category +- Minimum 200 stars +- 2 TB size limit + +### Example 3: Download Specific Categories + +```bash +python3 download_repos.py --categories nvim lua bash zsh python hacking --max-repos 50 +``` + +Downloads specific categories with 50 repos each. + +### Example 4: Download All Open-License Repos with Size Limit + +```bash +python3 download_repos.py --categories all-open --max-repos 1000 --max-size 1024.0 +``` + +Downloads up to 1000 repositories with any open license, stopping at 1 TB. + +### Example 5: Download High-Quality Repos + +```bash +python3 download_repos.py --categories nvim lua bash zsh python hacking --min-stars 1000 --max-repos 20 +``` + +Downloads only highly popular repositories (1000+ stars). + +### Example 6: Download Multiple Languages + +```bash +python3 download_repos.py --languages python javascript go rust --max-repos 50 +``` + +Downloads repositories in multiple programming languages. + +### Example 7: Download with Apache License + +```bash +python3 download_repos.py --categories nvim --license apache-2.0 --max-repos 50 +``` + +Downloads Neovim repos with Apache 2.0 license. + +### Example 8: Custom Output Directory + +```bash +python3 download_repos.py --categories nvim lua bash zsh python hacking --output /path/to/repos +``` + +Saves repositories to a custom directory. + +### Example 9: Full Clone (with History) + +```bash +python3 download_repos.py --categories nvim --full-clone --max-repos 10 +``` + +Does full clone including full git history (slower but more complete). + +### Example 10: Size-Limited Download + +```bash +python3 download_repos.py --categories all-open --max-repos 2000 --max-size 512.0 +``` + +Downloads repositories but stops when reaching 512 GB (0.5 TB). + +## Progress Tracking + +The scripts include visual progress bars showing: + +- **Category progress**: Overall progress across all categories +- **Repository progress**: Progress for each category +- **Real-time statistics**: Current repo, stars, language, cloned/failed counts +- **Size tracking**: Current total size and size limit (when `--max-size` is used) + +**Example output:** + +```text +šŸ“Š Current directory size: 45.23 GB +šŸ“Š Size limit: 1024.00 GB +šŸ“¦ Processing 6 categories... +Category: nvim: 100%|ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 6/6 [15:23<00:00, Size=156.78 GB, Total Cloned=300, Total Failed=2] +Cloning nvim: 45%|ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–Œ | 23/50 [02:15<03:45, Current=awesome-nvim, Stars=5.2k, Lang=Lua, Cloned=22, Failed=1, Size=12.45 GB] +``` + +**Size limit reached:** + +When the size limit is reached, the script will stop downloading and show: + +```text +āš ļø Size limit reached: 1024.00 GB >= 1024.00 GB + Stopping all downloads. +``` + +## GitHub API Rate Limits + +GitHub API has rate limits: +- **Unauthenticated**: 60 requests/hour +- **Authenticated**: 5,000 requests/hour + +### Using a GitHub Token + +To increase rate limits, set a GitHub Personal Access Token: + +```bash +export GITHUB_TOKEN=your_token_here +python3 download_repos.py --categories nvim lua bash hacking +``` + +**How to create a token:** +1. Go to GitHub Settings → Developer settings → Personal access tokens +2. Generate new token (classic) +3. Select scope: `public_repo` (read-only is enough) +4. Copy token and set as environment variable + +## Size Limits + +The repository downloader includes automatic size limit checking to prevent running out of disk space. + +### How It Works + +- **Default limit**: 1 TB (1024 GB) for `download_all_repos.py` +- **Customizable**: Use `--max-size` to set any limit +- **Real-time tracking**: Size is checked before each repository clone +- **Automatic stopping**: Downloads stop when limit is reached +- **Progress display**: Current size shown in progress bars + +### Setting Size Limits + +**With `download_all_repos.py`:** +```bash +# Default 1 TB +python3 download_all_repos.py + +# Custom limit (2 TB) +python3 download_all_repos.py --max-size 2048.0 + +# Smaller limit (500 GB) +python3 download_all_repos.py --max-size 512.0 +``` + +**With `download_repos.py`:** +```bash +# No limit (downloads until max-repos reached) +python3 download_repos.py --categories nvim --max-repos 100 + +# With 1 TB limit +python3 download_repos.py --categories nvim --max-repos 1000 --max-size 1024.0 +``` + +### Size Calculation + +The script calculates total size by: +- Scanning all files in the output directory (`data/repos` by default) +- Summing file sizes recursively +- Checking before each new repository clone +- Displaying human-readable sizes (B, KB, MB, GB, TB) + +**Note:** Size checking happens before cloning, so the actual size may be slightly less than the limit when stopping. + +## Cache and Resuming + +The scripts automatically: + +- **Skips existing repos**: If a repository already exists, it's skipped (no re-download) +- **Resumes downloads**: You can run the script multiple times safely +- **Progress tracking**: Shows what's already downloaded +- **Size awareness**: Accounts for existing repositories when checking size limits + +After downloading repositories, they're automatically processed during training: + +```bash +# Download repos +python3 download_all_repos.py + +# Train with all data (text + code) +python3 train.py --data data/ --config config.json --device cuda +``` + +The training script will: +1. Process all your text data (Wiki, Books, Amazon reviews, etc.) +2. Process all code repositories +3. Combine everything into training data + +## Supported File Types + +The data processor automatically handles code files from repositories: + +- **Text files**: `.txt`, `.md`, `.rst`, `.log`, `.csv`, `.json`, `.jsonl`, `.xml`, `.html`, `.htm` +- **Code files**: `.py`, `.js`, `.ts`, `.java`, `.cpp`, `.c`, `.go`, `.rs`, `.rb`, `.php`, `.swift`, `.lua`, `.sh`, and 30+ more +- **PDF files**: `.pdf` (if pdfplumber is installed) +- **Images**: `.png`, `.jpg`, etc. (if OCR is set up) + +## Troubleshooting + +### Rate Limit Exceeded + +**Error:** `Rate limit exceeded` + +**Solution:** +1. Wait a few minutes and try again +2. Use a GitHub token: `export GITHUB_TOKEN=your_token` +3. Reduce `--max-repos` to download fewer repos per run + +### Repository Clone Fails + +**Error:** `Failed to clone repository` + +**Possible causes:** +- Repository was deleted or made private +- Network issues +- Repository is too large (timeout) + +**Solution:** +- The script continues with other repos +- Failed repos are counted and reported at the end +- You can re-run the script to retry failed repos + +### No Repositories Found + +**Error:** `No repositories found` + +**Possible causes:** +- Search query too restrictive +- License filter too narrow +- Minimum stars too high + +**Solution:** +- Lower `--min-stars` threshold +- Try different `--license` options +- Check if category name is correct + +## Best Practices + +### 1. Start Small + +Test with a small number first: + +```bash +python3 download_repos.py --categories nvim --max-repos 10 +``` + +### 2. Use Size Limits + +Always set a size limit to prevent running out of disk space: + +```bash +# Recommended: 1 TB limit +python3 download_all_repos.py --max-size 1024.0 + +# Or custom limit based on available space +python3 download_repos.py --categories all-open --max-size 512.0 +``` + +### 3. Use Shallow Clones + +Shallow clones are faster and use less disk space: + +```bash +# Default (shallow clone) +python3 download_repos.py --categories nvim + +# Full clone (only if you need history) +python3 download_repos.py --categories nvim --full-clone +``` + +### 4. Filter by Quality + +Use `--min-stars` to get quality repositories: + +```bash +python3 download_repos.py --categories nvim --min-stars 500 --max-repos 50 +``` + +### 5. Use GitHub Token + +For large downloads, use a GitHub token: + +```bash +export GITHUB_TOKEN=your_token_here +python3 download_all_repos.py --max-repos 100 +``` + +### 6. Monitor Disk Space + +Check available disk space before starting: + +```bash +df -h data/repos +``` + +### 7. Use `all-open` Category Wisely + +The `all-open` category downloads broadly. Consider: +- Setting a reasonable `--max-repos` limit +- Using `--min-stars` to filter quality +- Setting `--max-size` to prevent excessive downloads + +```bash +python3 download_repos.py --categories all-open --max-repos 500 --min-stars 200 --max-size 1024.0 +``` + +## Storage Considerations + +### Size Limits and Disk Space Management + +- **Default**: 1 TB (1024 GB) for `download_all_repos.py` +- **Recommended**: Set based on available disk space +- **Monitoring**: Script shows current size vs limit in progress bars + +### Shallow vs Full Clones + +**Shallow clones (default):** +- Faster download +- Less disk space (~10-50% of full clone) +- No git history +- Good for training data + +**Full clones:** +- Slower download +- More disk space (includes full history) +- Includes full git history +- Useful if you need version history + +**Typical sizes (shallow clones):** +- Small repo: 1-10 MB +- Medium repo: 10-100 MB +- Large repo: 100 MB - 1 GB +- Very large repo: 1-10 GB+ + +**Example:** Downloading 300 repositories with shallow clones typically uses 5-30 GB, depending on repository sizes. + +### Estimating Storage Needs + +To estimate how many repositories you can download: + +1. **Check current size:** + ```bash + du -sh data/repos + ``` + +2. **Calculate average repo size:** + - Small repos: ~5 MB average + - Medium repos: ~50 MB average + - Large repos: ~500 MB average + +3. **Estimate:** + - 100 small repos: ~500 MB + - 100 medium repos: ~5 GB + - 100 large repos: ~50 GB + - 1000 mixed repos: ~50-200 GB + +4. **Set appropriate limit:** + ```bash + # For 1 TB available space, use 900 GB limit (leave buffer) + python3 download_all_repos.py --max-size 900.0 + ``` + +## Summary + +The repository downloader makes it easy to: +- āœ… Automatically find high-quality open-source repositories +- āœ… Filter by category, language, license, and popularity +- āœ… Download with progress tracking and size monitoring +- āœ… Set size limits to prevent running out of disk space +- āœ… Integrate seamlessly with training pipeline +- āœ… Resume interrupted downloads + +**Available categories:** +- `nvim` - Neovim configurations and plugins +- `lua` - Lua programming repositories +- `bash` - Bash/shell script repositories +- `zsh` - Zsh configuration and plugins +- `python` - Python programming repositories +- `hacking` - Ethical hacking and cybersecurity tools +- `security` - Security and cybersecurity repositories +- `all-open` - All repositories with open licenses (any language) + +**Quick commands to get started:** + +```bash +# Download all categories with 1 TB limit (recommended) +python3 download_all_repos.py + +# Download specific categories +python3 download_repos.py --categories nvim lua bash zsh python hacking --max-repos 50 + +# Download all open-license repos with size limit +python3 download_repos.py --categories all-open --max-repos 1000 --max-size 1024.0 +``` + +This downloads repositories and prepares them for training! \ No newline at end of file diff --git a/docs/RETRAINING_GUIDE.md b/docs/RETRAINING_GUIDE.md new file mode 100644 index 0000000..cfcb48a --- /dev/null +++ b/docs/RETRAINING_GUIDE.md @@ -0,0 +1,143 @@ +# Retraining Guide for Better Model Performance + +## Current Issues + +Your model only trained for **40 global steps** across 10 epochs, which means: +- Very little training data (~4 batches per epoch) +- Model hasn't learned language patterns +- Model just repeats input and stops + +## Retraining Recommendations + +### 1. **Increase Training Data** + +The model needs much more data. Check your current data: + +```bash +# Check how much data you have +wc -l data/*.txt +``` + +**Recommendations:** +- **Minimum**: 10,000+ text samples +- **Good**: 100,000+ text samples +- **Better**: 1,000,000+ text samples + +### 2. **Update Training Configuration** + +Edit `config.json` for better training: + +```json +{ + "training": { + "batch_size": 32, + "max_epochs": 50, // Increase from 10 to 50+ + "learning_rate": 1e-4, + "weight_decay": 0.01, + "warmup_steps": 1000, + "max_grad_norm": 1.0, + "gradient_accumulation_steps": 4, // Increase to simulate larger batches + "use_amp": true, + "save_dir": "./checkpoints", + "log_interval": 10, // More frequent logging + "eval_interval": 500 // More frequent evaluation + } +} +``` + +### 3. **Add Validation Set** + +Split your data for validation: + +```python +# In train.py, add validation split +from sklearn.model_selection import train_test_split + +train_texts, val_texts = train_test_split(texts, test_size=0.1, random_state=42) +``` + +### 4. **Improve Training Data Quality** + +Ensure your training data: +- āœ… Contains complete sentences/paragraphs +- āœ… Has diverse topics and styles +- āœ… Doesn't have excessive padding +- āœ… Uses proper text formatting + +### 5. **Monitor Training** + +Watch for: +- **Loss decreasing**: Should trend downward +- **Perplexity**: Should decrease (lower is better) +- **Generation quality**: Test periodically during training + +### 6. **Training Command** + +```bash +# Train with more data +python3 train.py \ + --data data/your_training_data.txt \ + --config config.json \ + --output ./checkpoints \ + --device cpu # or cuda/mps +``` + +### 7. **Check Training Progress** + +During training, you should see: +``` +Epoch 1: Train Loss = 8.5 → Epoch 10: Train Loss = 6.0 → Epoch 50: Train Loss = 3.5 +``` + +If loss stops decreasing, the model has converged. + +### 8. **Early Stopping** + +Consider adding early stopping if validation loss plateaus: +- Stop if validation loss doesn't improve for 5 epochs +- Save the best model based on validation loss + +### 9. **Test During Training** + +After each epoch, test generation: + +```bash +python3 inference.py \ + --checkpoint checkpoints/checkpoint_epoch_X.pt \ + --prompt "The future of" \ + --optimized +``` + +Good training should show: +- āœ… Model generates coherent text +- āœ… Model continues beyond input prompt +- āœ… Model doesn't immediately generate padding tokens + +## Quick Start Retraining + +1. **Get more training data** (most important!) +2. **Update config.json** with more epochs +3. **Start training**: + ```bash + python3 train.py --data data/your_data.txt --config config.json + ``` +4. **Monitor loss** - should decrease over time +5. **Test periodically** - check if generation improves + +## Expected Results + +After proper training: +- Loss should decrease from ~8-10 to ~2-4 +- Perplexity should decrease from ~3000 to ~10-50 +- Model should generate 50+ tokens before stopping +- Generated text should be coherent and diverse + +## Next Steps + +1. āœ… Early stopping is now fixed (prevents padding tokens) +2. ā³ **Retrain with more data and epochs** +3. ā³ Monitor training metrics +4. ā³ Test generation quality during training + +Good luck with retraining! šŸš€ + diff --git a/docs/SCHEDULING_EXPLAINED.md b/docs/SCHEDULING_EXPLAINED.md new file mode 100644 index 0000000..fc1c38f --- /dev/null +++ b/docs/SCHEDULING_EXPLAINED.md @@ -0,0 +1,618 @@ +# What is Scheduling? Step-by-Step Explanation + +Complete step-by-step explanation of learning rate scheduling: how scheduling adjusts learning rates during training to improve convergence. + +## Table of Contents + +1. [What is Scheduling?](#81-what-is-scheduling) +2. [Why Do We Need Scheduling?](#82-why-do-we-need-scheduling) +3. [Fixed Learning Rate](#83-fixed-learning-rate) +4. [Cosine Annealing](#84-cosine-annealing) +5. [Other Scheduling Strategies](#85-other-scheduling-strategies) +6. [Why Scheduling Matters](#86-why-scheduling-matters) +7. [Complete Mathematical Formulation](#87-complete-mathematical-formulation) +8. [Exercise: Schedule Calculation](#88-exercise-schedule-calculation) +9. [Key Takeaways](#89-key-takeaways) + +--- + +## 8.1 What is Scheduling? + +### Simple Definition + +**Scheduling** (learning rate scheduling) is the process of adjusting the learning rate during training to improve convergence and final model performance. + +### Visual Analogy + +**Think of scheduling like adjusting speed while driving:** + +``` +Fixed Learning Rate: + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ Speed: 60 mph (constant) │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + → Hard to stop precisely! + +Scheduled Learning Rate: + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + │ Speed: 60 → 40 → 20 → 10 │ + ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + → Smooth deceleration! +``` + +**Scheduling adjusts speed (learning rate) as you approach the destination (convergence)!** + +### What Scheduling Does + +**Scheduling:** +1. **Starts** with higher learning rate (fast learning) +2. **Gradually reduces** learning rate (precise fine-tuning) +3. **Converges** to optimal solution + +**Result:** Better convergence and performance! + +--- + +## 8.2 Why Do We Need Scheduling? + +### The Problem with Fixed Learning Rate + +**High Learning Rate:** +``` +Learning Rate: 0.001 (constant) +→ Fast initial learning āœ“ +→ But overshoots minimum āœ— +→ Bounces around āœ— +→ Poor convergence āœ— +``` + +**Low Learning Rate:** +``` +Learning Rate: 0.0001 (constant) +→ Stable convergence āœ“ +→ But very slow learning āœ— +→ Takes forever to converge āœ— +``` + +**Can't have both!** + +### The Solution: Scheduling + +**Adaptive Learning Rate:** +``` +Start: 0.001 (fast learning) +Middle: 0.0005 (moderate) +End: 0.0001 (fine-tuning) +→ Fast initial learning āœ“ +→ Stable convergence āœ“ +→ Best of both worlds! +``` + +### Benefits of Scheduling + +**1. Faster Convergence** +- High initial rate = Fast progress +- Lower later rate = Precise convergence + +**2. Better Final Performance** +- Fine-tuning at end = Better solution +- Avoids overshooting = More stable + +**3. More Stable Training** +- Gradual reduction = Smooth optimization +- Less oscillation = More reliable + +--- + +## 8.3 Fixed Learning Rate + +### What is Fixed Learning Rate? + +**Learning rate stays constant throughout training:** + +```math +\eta_t = \eta_0 \quad \text{for all } t +``` + +**Where:** +- $\eta_0$ = initial learning rate +- $t$ = training step + +### Example + +**Fixed Rate:** +``` +Step 0: Ī· = 0.001 +Step 100: Ī· = 0.001 +Step 1000: Ī· = 0.001 +Step 10000: Ī· = 0.001 +``` + +**Constant throughout!** + +### Visualization + +``` +Learning Rate + │ +0.001│───────────────────────────────────── + │ + │ + │ + │ + └───────────────────────────────────── Steps +``` + +### Problems + +**1. Too High:** +- Overshoots minimum +- Oscillates around solution +- Never converges precisely + +**2. Too Low:** +- Very slow training +- Takes forever to converge +- May get stuck + +**Solution:** Use scheduling! + +--- + +## 8.4 Cosine Annealing + +### What is Cosine Annealing? + +**Cosine Annealing** reduces the learning rate following a cosine curve from maximum to minimum. + +### Formula + +```math +\eta_t = \eta_{min} + (\eta_{max} - \eta_{min}) \times \frac{1 + \cos\left(\frac{\pi t}{T_{max}}\right)}{2} +``` + +**Where:** +- $\eta_t$ = learning rate at step $t$ +- $\eta_{min}$ = minimum learning rate (default: 0) +- $\eta_{max}$ = initial/maximum learning rate +- $T_{max}$ = total number of steps +- $t$ = current step + +### How It Works + +**Step 1: Calculate Cosine Value** +```math +\cos\left(\frac{\pi t}{T_{max}}\right) +``` + +**Step 2: Shift to [0, 1] Range** +```math +\frac{1 + \cos\left(\frac{\pi t}{T_{max}}\right)}{2} +``` + +**Step 3: Scale to Learning Rate Range** +```math +\eta_{min} + (\eta_{max} - \eta_{min}) \times \text{scale} +``` + +### Example Calculation + +**Given:** +- $\eta_{max} = 0.001$ +- $\eta_{min} = 0$ +- $T_{max} = 10000$ + +**At step $t = 0$:** +```math +\eta_0 = 0 + (0.001 - 0) \times \frac{1 + \cos(0)}{2} = 0.001 \times 1 = 0.001 +``` + +**At step $t = 2500$:** +```math +\eta_{2500} = 0 + 0.001 \times \frac{1 + \cos(\pi/4)}{2} = 0.001 \times \frac{1 + 0.707}{2} \approx 0.000854 +``` + +**At step $t = 5000$:** +```math +\eta_{5000} = 0 + 0.001 \times \frac{1 + \cos(\pi/2)}{2} = 0.001 \times \frac{1 + 0}{2} = 0.0005 +``` + +**At step $t = 7500$:** +```math +\eta_{7500} = 0 + 0.001 \times \frac{1 + \cos(3\pi/4)}{2} = 0.001 \times \frac{1 + (-0.707)}{2} \approx 0.000146 +``` + +**At step $t = 10000$:** +```math +\eta_{10000} = 0 + 0.001 \times \frac{1 + \cos(\pi)}{2} = 0.001 \times \frac{1 + (-1)}{2} = 0 +``` + +### Visualization + +``` +Learning Rate + │ +0.001 ā”‚ā—ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€\ + │ \ + │ \ +0.0005│ \ + │ \ + │ \ + │ \ + │ \ + │ \ + │ \ + 0│ ā—ā”€ā”€ā”€ā”€ā”€ + └───────────────────────────────────── Steps + 0 2500 5000 7500 10000 +``` + +**Smooth cosine curve!** + +### Why Cosine Annealing? + +**Benefits:** +1. **Smooth decay:** No abrupt changes +2. **Gradual reduction:** Better fine-tuning +3. **Works well:** Commonly used in practice +4. **High initial rate:** Fast learning +5. **Low final rate:** Precise convergence + +--- + +## 8.5 Other Scheduling Strategies + +### 1. Step Decay + +**Reduce learning rate at fixed intervals:** + +```math +\eta_t = \eta_0 \times \gamma^{\lfloor t / s \rfloor} +``` + +**Where:** +- $\gamma$ = decay factor (e.g., 0.1) +- $s$ = step size (e.g., every 1000 steps) + +**Example:** +``` +Step 0-999: Ī· = 0.001 +Step 1000-1999: Ī· = 0.0001 (Ɨ0.1) +Step 2000-2999: Ī· = 0.00001 (Ɨ0.1) +``` + +**Visualization:** +``` +Learning Rate + │ +0.001 │───────┐ + │ │ + │ └───────┐ +0.0001│ │ + │ └───────┐ + │ │ + └───────────────────────── Steps +``` + +### 2. Exponential Decay + +**Continuous exponential reduction:** + +```math +\eta_t = \eta_0 \times \gamma^t +``` + +**Where:** +- $\gamma$ = decay rate (e.g., 0.9995) + +**Visualization:** +``` +Learning Rate + │ +0.001ā”‚ā—ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€\ + │ \ + │ \ + │ \ + │ \ + │ \ + │ \ + │ \ + └──────────────────────── Steps +``` + +### 3. Warmup Scheduling + +**Start with low rate, increase, then decrease:** + +**Warmup Phase:** +```math +\eta_t = \eta_{max} \times \frac{t}{T_{warmup}} +``` + +**After Warmup:** +```math +\eta_t = \text{Cosine Annealing or other schedule} +``` + +**Visualization:** +``` +Learning Rate + │ +0.001│ ╱───────\ + │ ╱ \ + │ ╱ \ + │ ╱ \ + │ ╱ \ + │ ╱ \ + │╱ \ + └───────────────────── Steps +``` + +### 4. One Cycle Learning Rate + +**One cycle: increase then decrease:** + +```math +\eta_t = \begin{cases} +\eta_{min} + (\eta_{max} - \eta_{min}) \times \frac{t}{T_1} & t \leq T_1 \\ +\eta_{max} - (\eta_{max} - \eta_{min}) \times \frac{t - T_1}{T_2} & t > T_1 +\end{cases} +``` + +**Visualization:** +``` +Learning Rate + │ +0.001│ ╱─────\ + │ ╱ \ + │ ╱ \ + │ ╱ \ + │ ╱ \ + │ ╱ \ + │╱ \ + └─────────────────── Steps +``` + +--- + +## 8.6 Why Scheduling Matters + +### Benefit 1: Better Convergence + +**Without Scheduling:** +``` +Loss: 3.0 → 2.5 → 2.3 → 2.2 → 2.15 → 2.12 → ... + (slow convergence at end) +``` + +**With Scheduling:** +``` +Loss: 3.0 → 2.5 → 2.3 → 2.2 → 2.1 → 2.05 → ... + (faster convergence, better final loss) +``` + +### Benefit 2: More Stable Training + +**Fixed High Rate:** +``` +Loss: 3.0 → 2.5 → 2.3 → 2.4 → 2.3 → 2.4 → ... + (oscillating, unstable) +``` + +**Scheduled Rate:** +``` +Loss: 3.0 → 2.5 → 2.3 → 2.2 → 2.15 → 2.12 → ... + (smooth, stable) +``` + +### Benefit 3: Better Final Performance + +**Comparison:** +``` +Fixed LR: Final Loss = 2.15 +Scheduled LR: Final Loss = 2.05 + +→ 5% improvement! +``` + +--- + +## 8.7 Complete Mathematical Formulation + +### General Scheduling Formula + +```math +\eta_t = f(t, \eta_0, \eta_{min}, T_{max}, ...) +``` + +**Where $f$ is the scheduling function** + +### Cosine Annealing (Complete) + +```math +\eta_t = \eta_{min} + (\eta_{max} - \eta_{min}) \times \frac{1 + \cos\left(\frac{\pi t}{T_{max}}\right)}{2} +``` + +**Boundary Conditions:** +- At $t = 0$: $\eta_0 = \eta_{max}$ +- At $t = T_{max}$: $\eta_{T_{max}} = \eta_{min}$ + +### Step Decay + +```math +\eta_t = \eta_0 \times \gamma^{\lfloor t / s \rfloor} +``` + +### Exponential Decay + +```math +\eta_t = \eta_0 \times \gamma^t +``` + +### Warmup + Cosine Annealing + +**Warmup Phase ($t \leq T_{warmup}$):** +```math +\eta_t = \eta_{max} \times \frac{t}{T_{warmup}} +``` + +**Annealing Phase ($t > T_{warmup}$):** +```math +\eta_t = \eta_{min} + (\eta_{max} - \eta_{min}) \times \frac{1 + \cos\left(\frac{\pi (t - T_{warmup})}{T_{max} - T_{warmup}}\right)}{2} +``` + +--- + +## 8.8 Exercise: Schedule Calculation + +### Problem + +**Given Cosine Annealing schedule:** + +- $\eta_{max} = 0.002$ +- $\eta_{min} = 0.0001$ +- $T_{max} = 5000$ steps + +**Calculate the learning rate at:** +1. Step $t = 0$ +2. Step $t = 1250$ +3. Step $t = 2500$ +4. Step $t = 3750$ +5. Step $t = 5000$ + +### Step-by-Step Solution + +#### General Formula + +```math +\eta_t = \eta_{min} + (\eta_{max} - \eta_{min}) \times \frac{1 + \cos\left(\frac{\pi t}{T_{max}}\right)}{2} +``` + +**Substitute values:** +```math +\eta_t = 0.0001 + (0.002 - 0.0001) \times \frac{1 + \cos\left(\frac{\pi t}{5000}\right)}{2} +``` + +```math +\eta_t = 0.0001 + 0.0019 \times \frac{1 + \cos\left(\frac{\pi t}{5000}\right)}{2} +``` + +#### Step 1: t = 0 + +```math +\eta_0 = 0.0001 + 0.0019 \times \frac{1 + \cos(0)}{2} +``` + +```math +\eta_0 = 0.0001 + 0.0019 \times \frac{1 + 1}{2} +``` + +```math +\eta_0 = 0.0001 + 0.0019 \times 1 = 0.0001 + 0.0019 = 0.002 +``` + +**Answer:** $\eta_0 = 0.002$ + +#### Step 2: t = 1250 + +```math +\eta_{1250} = 0.0001 + 0.0019 \times \frac{1 + \cos(\pi/4)}{2} +``` + +```math +\eta_{1250} = 0.0001 + 0.0019 \times \frac{1 + 0.707}{2} +``` + +```math +\eta_{1250} = 0.0001 + 0.0019 \times 0.8535 = 0.0001 + 0.001621 = 0.001721 +``` + +**Answer:** $\eta_{1250} \approx 0.001721$ + +#### Step 3: t = 2500 + +```math +\eta_{2500} = 0.0001 + 0.0019 \times \frac{1 + \cos(\pi/2)}{2} +``` + +```math +\eta_{2500} = 0.0001 + 0.0019 \times \frac{1 + 0}{2} +``` + +```math +\eta_{2500} = 0.0001 + 0.0019 \times 0.5 = 0.0001 + 0.00095 = 0.00105 +``` + +**Answer:** $\eta_{2500} = 0.00105$ + +#### Step 4: t = 3750 + +```math +\eta_{3750} = 0.0001 + 0.0019 \times \frac{1 + \cos(3\pi/4)}{2} +``` + +```math +\eta_{3750} = 0.0001 + 0.0019 \times \frac{1 + (-0.707)}{2} +``` + +```math +\eta_{3750} = 0.0001 + 0.0019 \times 0.1465 = 0.0001 + 0.000278 = 0.000378 +``` + +**Answer:** $\eta_{3750} \approx 0.000378$ + +#### Step 5: t = 5000 + +```math +\eta_{5000} = 0.0001 + 0.0019 \times \frac{1 + \cos(\pi)}{2} +``` + +```math +\eta_{5000} = 0.0001 + 0.0019 \times \frac{1 + (-1)}{2} +``` + +```math +\eta_{5000} = 0.0001 + 0.0019 \times 0 = 0.0001 + 0 = 0.0001 +``` + +**Answer:** $\eta_{5000} = 0.0001$ + +### Summary Table + +| Step | Cosine Value | Scale Factor | Learning Rate | +|------|--------------|--------------|---------------| +| 0 | 1.0 | 1.0 | 0.002 | +| 1250 | 0.707 | 0.854 | 0.001721 | +| 2500 | 0.0 | 0.5 | 0.00105 | +| 3750 | -0.707 | 0.146 | 0.000378 | +| 5000 | -1.0 | 0.0 | 0.0001 | + +**Smooth decay from 0.002 to 0.0001!** + +--- + +## 8.9 Key Takeaways + +### Scheduling + +āœ… **Scheduling adjusts learning rate during training** +āœ… **Starts high (fast learning), ends low (fine-tuning)** +āœ… **Improves convergence and final performance** + +### Cosine Annealing + +āœ… **Smooth cosine-based decay** +āœ… **Gradual reduction from max to min** +āœ… **Works well for transformers** + +### Why Important + +āœ… **Faster convergence** +āœ… **More stable training** +āœ… **Better final performance** +āœ… **Essential for optimal training** + +--- + +*This document provides a comprehensive explanation of learning rate scheduling, including cosine annealing and other strategies with mathematical formulations and solved exercises.* + diff --git a/docs/TOKENIZATION_EXPLAINED.md b/docs/TOKENIZATION_EXPLAINED.md new file mode 100644 index 0000000..98af8e9 --- /dev/null +++ b/docs/TOKENIZATION_EXPLAINED.md @@ -0,0 +1,978 @@ +# Tokenization Explained - Mathematical Formulation + +Complete mathematical derivation and step-by-step explanation of tokenization, Byte Pair Encoding (BPE), and UTF-8 encoding used in the SheepOp Language Model. + +## Table of Contents + +1. [Introduction to Tokenization](#1-introduction-to-tokenization) +2. [UTF-8 Encoding](#2-utf-8-encoding) +3. [Byte Pair Encoding Algorithm](#3-byte-pair-encoding-algorithm) +4. [Vocabulary Construction](#4-vocabulary-construction) +5. [Encoding Process](#5-encoding-process) +6. [Decoding Process](#6-decoding-process) +7. [Regex Pattern Splitting](#7-regex-pattern-splitting) +8. [Special Tokens](#8-special-tokens) +9. [Complete Tokenization Pipeline](#9-complete-tokenization-pipeline) +10. [Tokenization Challenges and Solutions](#10-tokenization-challenges-and-solutions) + +--- + +## 1. Introduction to Tokenization + +### 1.1 What is Tokenization? + +**Tokenization** is the process of converting raw text into a sequence of discrete tokens (integers) that can be processed by neural networks. + +**Mathematical Definition:** + +Given a text string $s \in \Sigma^*$ where $\Sigma$ is the character alphabet, tokenization maps: + +```math +\mathcal{T}: \Sigma^* \rightarrow \mathbb{N}^* +``` + +```math +s = c_1 c_2 \ldots c_n \mapsto \mathbf{t} = [t_1, t_2, \ldots, t_m] +``` + +where: + +- $s$ = input text string +- $c_i$ = individual characters +- $\mathbf{t}$ = sequence of token IDs +- $n$ = number of characters +- $m$ = number of tokens (typically $m \leq n$) + +### 1.2 Why Tokenization? + +**Problem:** Neural networks require numerical inputs, not raw text. + +**Solution:** Convert text to token sequences: + +```mermaid +graph LR + A["Raw Text
'Hello world'"] --> B[Tokenizer] + B --> C["Token IDs
[15496, 1917]"] + C --> D[Embedding Layer] + D --> E["Embeddings
ā„^nƗd"] + + style A fill:#e1f5ff + style B fill:#fff4e1 + style C fill:#e1ffe1 + style E fill:#ffe1f5 +``` + +### 1.3 Tokenization Approaches + +**Three Main Approaches:** + +1. **Character-Level**: Each character becomes a token + - Vocabulary size: ~100-200 + - Sequence length: Very long + +2. **Word-Level**: Each word becomes a token + - Vocabulary size: ~50,000-100,000 + - Sequence length: Moderate + +3. **Subword-Level (BPE)**: Sequences of bytes/characters become tokens + - Vocabulary size: ~30,000-100,000 + - Sequence length: Efficient + +```mermaid +graph TB + subgraph "Character-Level" + C1["'Hello'"] --> C2["['H','e','l','l','o']"] + C2 --> C3["5 tokens"] + end + + subgraph "Word-Level" + W1["'Hello world'"] --> W2["['Hello', 'world']"] + W2 --> W3["2 tokens"] + end + + subgraph "Subword-Level (BPE)" + B1["'Hello world'"] --> B2["['Hello', ' world']"] + B2 --> B3["2 tokens"] + end + + style C3 fill:#ffcccc + style W3 fill:#ccffcc + style B3 fill:#ccffcc +``` + +--- + +## 2. UTF-8 Encoding + +### 2.1 Unicode Code Points + +**Unicode** defines a mapping from characters to integers (code points): + +```math +f: \mathcal{C} \rightarrow \{0, 1, \ldots, 0x10FFFF\} +``` + +where $\mathcal{C}$ is the set of all Unicode characters. + +**Example:** + +```math +f('H') = 72, \quad f('ε') = 949, \quad f('äø­') = 20013 +``` + +### 2.2 UTF-8 Encoding Function + +**UTF-8** encodes Unicode code points into variable-length byte sequences: + +```math +\text{UTF-8}: \{0, \ldots, 0x10FFFF\} \rightarrow \{0, \ldots, 255\}^* +``` + +**Encoding Rules:** + +For a code point $c$: + +```math +\text{UTF-8}(c) = \begin{cases} +[b_0] & \text{if } c < 128 \\ +[b_0, b_1] & \text{if } c < 2048 \\ +[b_0, b_1, b_2] & \text{if } c < 65536 \\ +[b_0, b_1, b_2, b_3] & \text{if } c < 0x10FFFF +\end{cases} +``` + +where $b_i \in \{0, \ldots, 255\}$ are bytes. + +### 2.3 UTF-8 Encoding Process + +```mermaid +graph TB + A["Unicode String
'Hello äø–ē•Œ'"] --> B[Extract Code Points] + B --> C["[72, 101, 108, 108, 111, 32, 19990, 30028]"] + C --> D[UTF-8 Encode] + D --> E["Bytes
[72, 101, 108, 108, 111, 32, 228, 184, 150, 231, 149, 140]"] + + style A fill:#e1f5ff + style E fill:#e1ffe1 +``` + +**Mathematical Formulation:** + +For a string $s = c_1 c_2 \ldots c_n$: + +```math +\text{bytes}(s) = \bigoplus_{i=1}^n \text{UTF-8}(f(c_i)) +``` + +where $\bigoplus$ denotes byte concatenation. + +**Example:** + +```math +\text{bytes}("Hi") = \text{UTF-8}(72) \oplus \text{UTF-8}(105) = [72] \oplus [105] = [72, 105] +``` + +--- + +## 3. Byte Pair Encoding Algorithm + +### 3.1 BPE Overview + +**Byte Pair Encoding (BPE)** is a data compression algorithm that iteratively merges the most frequent consecutive pairs. + +**Goal:** Create an efficient vocabulary by merging frequent byte/character pairs. + +### 3.2 BPE Training Algorithm + +**Initial State:** + +```math +V^{(0)} = \{0, 1, \ldots, 255\} \quad \text{(all bytes)} +``` + +```math +\text{tokens}^{(0)} = \text{bytes}(s) = [b_1, b_2, \ldots, b_n] +``` + +**Iterative Merging:** + +For iteration $k = 1, 2, \ldots, K$: + +**Step 1: Calculate Pair Frequencies** + +```math +\text{stats}^{(k)} = \text{CountPairs}(\text{tokens}^{(k-1)}) +``` + +```math +\text{stats}^{(k)}(i, j) = |\{p : \text{tokens}^{(k-1)}_p = i \land \text{tokens}^{(k-1)}_{p+1} = j\}| +``` + +**Step 2: Find Most Frequent Pair** + +```math +(i^*, j^*) = \arg\max_{(i,j)} \text{stats}^{(k)}(i, j) +``` + +**Step 3: Create New Token** + +```math +V^{(k)} = V^{(k-1)} \cup \{256 + k - 1\} +``` + +```math +\text{merges}^{(k)} = \text{merges}^{(k-1)} \cup \{(i^*, j^*) \mapsto 256 + k - 1\} +``` + +**Step 4: Apply Merge** + +```math +\text{tokens}^{(k)} = \text{Merge}(\text{tokens}^{(k-1)}, (i^*, j^*), 256 + k - 1) +``` + +where: + +```math +\text{Merge}(T, (a, b), \text{new\_id})_i = \begin{cases} +\text{new\_id} & \text{if } T_i = a \land T_{i+1} = b \\ +T_i & \text{otherwise} +\end{cases} +``` + +### 3.3 BPE Training Flowchart + +```mermaid +graph TB + Start[Start Training] --> Init["Initialize
V⁽⁰⁾ = {0...255}
tokens⁽⁰⁾ = bytes(text)"] + Init --> Iter{More Merges?} + + Iter -->|Yes| Stats["Calculate Pair Frequencies
statsā½įµā¾(i,j)"] + Stats --> Find["Find Most Frequent Pair
(i*, j*) = argmax stats"] + Find --> Create["Create New Token
id = 256 + k"] + Create --> Merge["Merge All Occurrences
tokensā½įµā¾ = Merge(tokensā½įµā»Ā¹ā¾, (i*,j*), id)"] + Merge --> Update["Update Vocabulary
Vā½įµā¾ = Vā½įµā»Ā¹ā¾ ∪ {id}"] + Update --> Iter + + Iter -->|No| End[Training Complete] + + style Start fill:#e1f5ff + style End fill:#e1ffe1 + style Stats fill:#fff4e1 + style Merge fill:#ffe1f5 +``` + +### 3.4 BPE Example + +**Example:** Training on text "aaab" + +**Iteration 0:** + +```math +V^{(0)} = \{0, 1, \ldots, 255\} +``` + +```math +\text{tokens}^{(0)} = [97, 97, 97, 98] \quad \text{(bytes for 'aaab')} +``` + +**Calculate frequencies:** + +```math +\text{stats}^{(0)} = \{(97, 97): 2, (97, 98): 1\} +``` + +**Iteration 1:** + +```math +(i^*, j^*) = (97, 97), \quad \text{new\_id} = 256 +``` + +```math +\text{tokens}^{(1)} = [256, 97, 98] +``` + +```math +V^{(1)} = V^{(0)} \cup \{256\} +``` + +**Iteration 2:** + +```math +\text{stats}^{(1)} = \{(256, 97): 1, (97, 98): 1\} +``` + +**Choose one:** \((256, 97)\), \(\text{new_id} = 257\) + +```math +\text{tokens}^{(2)} = [257, 98] +``` + +--- + +## 4. Vocabulary Construction + +### 4.1 Vocabulary Structure + +**Final Vocabulary:** + +```math +V = V_{\text{bytes}} \cup V_{\text{merges}} \cup V_{\text{special}} +``` + +where: + +- $V_{\text{bytes}} = \{0, 1, \ldots, 255\}$ (256 tokens) +- $V_{\text{merges}} = \{256, 257, \ldots, 256 + K - 1\}$ (K merged tokens) +- $V_{\text{special}} = \{\text{}, \text{}, \text{}, \text{}\}$ + +**Vocabulary Size:** + +```math +|V| = 256 + K + |V_{\text{special}}| +``` + +### 4.2 Merge Dictionary + +**Merge Dictionary:** + +```math +M: \mathbb{N} \times \mathbb{N} \rightarrow \mathbb{N} +``` + +```math +M(i, j) = \begin{cases} +\text{merged\_id} & \text{if } (i, j) \text{ was merged} \\ +\text{undefined} & \text{otherwise} +\end{cases} +``` + +**Vocabulary Mapping:** + +```math +\text{vocab}: \mathbb{N} \rightarrow \{0, \ldots, 255\}^* +``` + +```math +\text{vocab}(\text{id}) = \begin{cases} +[\text{id}] & \text{if } \text{id} < 256 \\ +\text{vocab}(i) \oplus \text{vocab}(j) & \text{if } M(i, j) = \text{id} +\end{cases} +``` + +### 4.3 Vocabulary Construction Diagram + +```mermaid +graph TB + subgraph "Initial Vocabulary" + A1["256 Byte Tokens
{0, 1, ..., 255}"] + end + + subgraph "BPE Merges" + B1["Merge 1: (101, 32) → 256"] + B2["Merge 2: (256, 108) → 257"] + B3["Merge K: ... → 256+K-1"] + end + + subgraph "Special Tokens" + C1[" → 50256"] + C2[" → 50257"] + C3[" → 50258"] + C4[" → 50259"] + end + + A1 --> D[Final Vocabulary] + B1 --> D + B2 --> D + B3 --> D + C1 --> D + C2 --> D + C3 --> D + C4 --> D + + D --> E["|V| = 256 + K + 4"] + + style A1 fill:#e1f5ff + style D fill:#e1ffe1 + style E fill:#fff4e1 +``` + +--- + +## 5. Encoding Process + +### 5.1 Encoding Algorithm + +**Input:** Text string $s$ + +**Step 1: Regex Splitting** + +```math +\text{chunks} = \text{RegexSplit}(s, \text{pattern}) +``` + +**Step 2: Convert to Bytes** + +For each chunk $c \in \text{chunks}$: + +```math +\text{bytes}(c) = \text{UTF-8}(c) = [b_1, b_2, \ldots, b_n] +``` + +**Step 3: Apply BPE Merges** + +Initialize: $\text{tokens} = \text{bytes}(c)$ + +While merges possible: + +```math +\text{Find earliest merge: } (i^*, j^*) = \arg\min_{(i,j) \in M} \text{merge\_index}(M(i, j)) +``` + +```math +\text{Apply merge: } \text{tokens} = \text{Merge}(\text{tokens}, (i^*, j^*), M(i^*, j^*)) +``` + +**Step 4: Combine Results** + +```math +\text{token\_ids} = \bigoplus_{c \in \text{chunks}} \text{BPE}(\text{bytes}(c)) +``` + +### 5.2 Encoding Function + +**Mathematical Definition:** + +```math +\text{encode}(s) = \bigoplus_{c \in \text{RegexSplit}(s)} \text{BPE}(\text{UTF-8}(c)) +``` + +where $\bigoplus$ denotes token sequence concatenation. + +### 5.3 Encoding Flowchart + +```mermaid +graph TB + A["Input Text
'Hello world'"] --> B[Regex Split] + B --> C["Chunks
['Hello', ' world']"] + + C --> D[For Each Chunk] + D --> E["UTF-8 Encode
bytes = [72, 101, 108, 108, 111]"] + E --> F["Initialize Tokens
tokens = bytes"] + + F --> G{Merge Possible?} + G -->|Yes| H["Find Earliest Merge
(i*, j*)"] + H --> I["Apply Merge
tokens = Merge(tokens, (i*,j*), id)"] + I --> G + + G -->|No| J[Add to Result] + J --> K{More Chunks?} + K -->|Yes| D + K -->|No| L["Final Token IDs
[15496, 1917]"] + + style A fill:#e1f5ff + style L fill:#e1ffe1 + style G fill:#ffe1f5 +``` + +### 5.4 Encoding Example + +**Input:** "Hello world" + +**Step 1: Regex Split** + +```math +\text{chunks} = ["Hello", " world"] +``` + +**Step 2: UTF-8 Encoding** + +```math +\text{UTF-8}("Hello") = [72, 101, 108, 108, 111] +``` + +```math +\text{UTF-8}(" world") = [32, 119, 111, 114, 108, 100] +``` + +**Step 3: BPE Merging** + +Assume merges: $M(72, 101) = 256, M(256, 108) = 257, M(257, 108) = 258, M(258, 111) = 259$ + +```math +[72, 101, 108, 108, 111] \xrightarrow{M(72,101)} [256, 108, 108, 111] +``` + +```math +\xrightarrow{M(256,108)} [257, 108, 111] \xrightarrow{M(257,108)} [258, 111] +``` + +```math +\xrightarrow{M(258,111)} [259] +``` + +**Final:** $[259, 1917]$ + +--- + +## 6. Decoding Process + +### 6.1 Decoding Algorithm + +**Input:** Token IDs $\mathbf{t} = [t_1, t_2, \ldots, t_n]$ + +**Step 1: Handle Special Tokens** + +```math +\text{Stop at EOS: } \mathbf{t}' = \mathbf{t}[:i] \text{ where } t_i = \text{} +``` + +**Step 2: Lookup Bytes** + +For each token $t_i$: + +```math +\text{bytes}_i = \text{vocab}(t_i) +``` + +**Step 3: Concatenate Bytes** + +```math +\text{all\_bytes} = \bigoplus_{i=1}^n \text{bytes}_i +``` + +**Step 4: UTF-8 Decode** + +```math +\text{text} = \text{UTF-8}^{-1}(\text{all\_bytes}) +``` + +### 6.2 Decoding Function + +**Mathematical Definition:** + +```math +\text{decode}(\mathbf{t}) = \text{UTF-8}^{-1}\left(\bigoplus_{i=1}^n \text{vocab}(t_i)\right) +``` + +### 6.3 Decoding Flowchart + +```mermaid +graph TB + A["Token IDs
[15496, 1917]"] --> B{Special Tokens?} + B -->|EOS Found| C[Stop at EOS] + B -->|No EOS| D[Process All Tokens] + C --> D + + D --> E[For Each Token ID] + E --> F["Lookup Bytes
vocab[t_i]"] + F --> G["Concatenate
all_bytes = bytes₁ āŠ• bytesā‚‚ āŠ• ..."] + + G --> H["UTF-8 Decode
text = decode(all_bytes)"] + H --> I["Output Text
'Hello world'"] + + style A fill:#e1f5ff + style I fill:#e1ffe1 + style F fill:#fff4e1 +``` + +### 6.4 Decoding Example + +**Input:** $[259, 1917]$ + +**Step 1: Lookup** + +```math +\text{vocab}(259) = \text{vocab}(258) \oplus \text{vocab}(111) +``` + +```math += (\text{vocab}(257) \oplus \text{vocab}(108)) \oplus [111] +``` + +```math += ((\text{vocab}(256) \oplus \text{vocab}(108)) \oplus [108]) \oplus [111] +``` + +```math += (((\text{vocab}(72) \oplus \text{vocab}(101)) \oplus [108]) \oplus [108]) \oplus [111] +``` + +```math += [[72] \oplus [101]] \oplus [108] \oplus [108] \oplus [111] +``` + +```math += [72, 101, 108, 108, 111] +``` + +```math +\text{vocab}(1917) = [32, 119, 111, 114, 108, 100] +``` + +**Step 2: Concatenate** + +```math +\text{all\_bytes} = [72, 101, 108, 108, 111] \oplus [32, 119, 111, 114, 108, 100] +``` + +```math += [72, 101, 108, 108, 111, 32, 119, 111, 114, 108, 100] +``` + +**Step 3: Decode** + +```math +\text{decode}([72, 101, 108, 108, 111, 32, 119, 111, 114, 108, 100]) = "Hello world" +``` + +--- + +## 7. Regex Pattern Splitting + +### 7.1 GPT-4 Style Pattern + +**Pattern Components:** + +```math +P = P_{\text{contractions}} \cup P_{\text{letters}} \cup P_{\text{numbers}} \cup P_{\text{punctuation}} \cup P_{\text{whitespace}} +``` + +**Pattern Definition:** + +\[ +P = \text{'(?i:[sdmt]|ll|ve|re)} \cup \text{[^\r\n\p{L}\p{N}]?+\p{L}+} \cup \text{\p{N}{1,3}} \cup \text{?[^\s\p{L}\p{N}]++} \cup \text{\r?\n} \cup \text{\s+} +\] + +**Components:** + +1. **Contractions:** \( P\_{\text{contractions}} = \text{'(?i:[sdmt]|ll|ve|re)} \) + - Matches: `'s`, `'t`, `'ll`, `'ve`, `'re` (case-insensitive) + +2. **Letters:** \( P\_{\text{letters}} = \text{[^\r\n\p{L}\p{N}]?+\p{L}+} \) + - Optional space + one or more letters + +3. **Numbers:** \( P\_{\text{numbers}} = \text{\p{N}{1,3}} \) + - **Limit:** 1-3 digits only (prevents long number tokens) + +4. **Punctuation:** \( P\_{\text{punctuation}} = \text{?[^\s\p{L}\p{N}]++} \) + - Optional space + punctuation + +5. **Whitespace:** \( P\_{\text{whitespace}} = \text{\r?\n} \cup \text{\s+} \) + - Newlines and multiple spaces + +### 7.2 Regex Splitting Function + +```math +\text{RegexSplit}(s, P) = \{m_1, m_2, \ldots, m_k : m_i \in \text{Match}(s, P)\} +``` + +where matches are found left-to-right, non-overlapping. + +### 7.3 Regex Splitting Diagram + +```mermaid +graph TB + A["Input Text
'Hello world 123'"] --> B[Apply Regex Pattern] + + B --> C1["Chunk 1: 'Hello'
P_letters"] + B --> C2["Chunk 2: ' world'
P_whitespace + P_letters"] + B --> C3["Chunk 3: ' 123'
P_whitespace + P_numbers"] + + C1 --> D["Chunks List
['Hello', ' world', ' 123']"] + C2 --> D + C3 --> D + + style A fill:#e1f5ff + style D fill:#e1ffe1 + style B fill:#fff4e1 +``` + +**Example:** + +```math +\text{RegexSplit}("Hello world 123", P) = ["Hello", " world", " 123"] +``` + +--- + +## 8. Special Tokens + +### 8.1 Special Token Set + +**Special Tokens:** + +```math +V_{\text{special}} = \{\text{}, \text{}, \text{}, \text{}\} +``` + +**Token IDs:** + +```math +\text{id}(\text{}) = 0, \quad \text{id}(\text{}) = 1, \quad \text{id}(\text{}) = 2, \quad \text{id}(\text{}) = 3 +``` + +### 8.2 Special Token Functions + +**Padding:** + +```math +\text{pad}(\mathbf{t}, \text{max\_length}) = \mathbf{t} \oplus [\text{}]^{\max(\text{max\_length} - |\mathbf{t}|, 0)} +``` + +**Unknown Token:** + +```math +\text{encode}(s) = \begin{cases} +[\text{id}(c)] & \text{if } c \in V \\ +[\text{}] & \text{if } c \notin V +\end{cases} +``` + +**EOS Handling:** + +```math +\text{decode}(\mathbf{t}) = \text{decode}(\mathbf{t}[:i]) \text{ where } t_i = \text{} +``` + +### 8.3 Special Token Flowchart + +```mermaid +graph TB + subgraph "Special Tokens" + A1[" → 0
Padding"] + A2[" → 1
Unknown"] + A3[" → 2
Beginning"] + A4[" → 3
End"] + end + + subgraph "Usage" + B1["Padding:
[1,2,3] → [1,2,3,0,0]"] + B2["Unknown:
unknown_char → 1"] + B3["EOS Stop:
[1,2,3] → stop at 3"] + end + + A1 --> B1 + A2 --> B2 + A4 --> B3 + + style A1 fill:#e1f5ff + style A2 fill:#ffe1f5 + style A3 fill:#e1ffe1 + style A4 fill:#fff4e1 +``` + +--- + +## 9. Complete Tokenization Pipeline + +### 9.1 Full Pipeline + +**Complete Tokenization Process:** + +```mermaid +graph TB + Start[Input Text] --> Split[Regex Split] + Split --> UTF8[UTF-8 Encode] + UTF8 --> BPE[BPE Merge] + BPE --> Special{Special Tokens?} + Special -->|Yes| Handle[Handle Special] + Special -->|No| Combine[Combine Tokens] + Handle --> Combine + Combine --> Output[Token IDs] + + Output --> Embed[Embedding Layer] + + style Start fill:#e1f5ff + style Output fill:#e1ffe1 + style Embed fill:#ffe1f5 +``` + +### 9.2 Mathematical Formulation + +**Complete Encoding:** + +```math +\text{Tokenize}(s) = \text{HandleSpecial}\left(\bigoplus_{c \in \text{RegexSplit}(s)} \text{BPE}(\text{UTF-8}(c))\right) +``` + +**Complete Decoding:** + +```math +\text{Detokenize}(\mathbf{t}) = \text{UTF-8}^{-1}\left(\bigoplus_{i=1}^n \text{vocab}(\text{RemoveSpecial}(t_i))\right) +``` + +### 9.3 Pipeline Example + +**Input:** "Hello world!" + +**Step 1: Regex Split** + +```math +\text{chunks} = ["Hello", " world", "!"] +``` + +**Step 2: UTF-8 Encode** + +```math +\text{bytes} = [[72,101,108,108,111], [32,119,111,114,108,100], [33]] +``` + +**Step 3: BPE Merge** + +```math +\text{tokens} = [[15496], [1917], [0]] +``` + +**Step 4: Combine** + +```math +\text{final} = [15496, 1917, 0] +``` + +--- + +## 10. Tokenization Challenges and Solutions + +### 10.1 Challenge: Long Tokens + +**Problem:** Some tokens become very long (e.g., "defaultstyle" as single token). + +**Mathematical Impact:** + +```math +\text{LongToken}(s) = \begin{cases} +1 \text{ token} & \text{if } |s| > 10 \text{ chars} \\ +\text{multiple tokens} & \text{otherwise} +\end{cases} +``` + +**Solution:** Better regex splitting prevents over-merging. + +### 10.2 Challenge: Number Tokenization + +**Problem:** Numbers tokenized arbitrarily (sometimes 1 token, sometimes 2-3). + +**Solution:** Limit number merging to 1-3 digits: + +```math +P_{\text{numbers}} = \text{\p{N}{1,3}} +``` + +**Impact:** + +```math +\text{TokenCount}(n) = \begin{cases} +1 & \text{if } n < 1000 \\ +2-3 & \text{if } n \geq 1000 +\end{cases} +``` + +### 10.3 Challenge: Python Code Efficiency + +**Problem:** Each space is separate token (GPT-2 issue). + +**Solution:** Merge multiple spaces: + +```math +P_{\text{whitespace}} = \text{\s+} \quad \text{(matches multiple spaces)} +``` + +**Efficiency Gain:** + +```math +\text{TokensBefore} = |\text{spaces}| \quad \text{(one token per space)} +``` + +```math +\text{TokensAfter} = \lceil |\text{spaces}| / 4 \rceil \quad \text{(grouped spaces)} +``` + +### 10.4 Challenge: Trailing Whitespace + +**Problem:** Trailing spaces cause poor tokenization. + +**Detection:** + +```math +\text{HasTrailingSpace}(s) = \begin{cases} +\text{True} & \text{if } s[-1] = ' ' \\ +\text{False} & \text{otherwise} +\end{cases} +``` + +**Warning:** + +```math +\text{Tokenize}(s) = \begin{cases} +\text{encode}(s) + \text{warning} & \text{if } \text{HasTrailingSpace}(s) \\ +\text{encode}(s) & \text{otherwise} +\end{cases} +``` + +### 10.5 Challenge: Multilingual Support + +**Problem:** Non-English languages tokenize inefficiently. + +**Solution:** UTF-8 byte-level encoding handles all languages: + +```math +\text{Tokenize}(s) = \text{BPE}(\text{UTF-8}(s)) \quad \forall s \in \Sigma^* +``` + +**Efficiency:** + +```math +\text{TokenRatio} = \frac{|\text{Tokenize}(s_{\text{non-english}})|}{|\text{Tokenize}(s_{\text{english}})|} \approx 1.5-2.0 +``` + +### 10.6 Challenge: Untrained Tokens + +**Problem:** Some tokens never appear in training data. + +**Solution:** Fallback handling: + +```math +\text{Decode}(t) = \begin{cases} +\text{vocab}(t) & \text{if } t \in V \\ +\text{} & \text{if } t \notin V \\ +\text{fallback\_bytes} & \text{if } t < 256 +\end{cases} +``` + +--- + +## Summary + +### Key Formulas + +**Encoding:** + +```math +\text{encode}(s) = \text{HandleSpecial}\left(\bigoplus_{c \in \text{RegexSplit}(s)} \text{BPE}(\text{UTF-8}(c))\right) +``` + +**Decoding:** + +```math +\text{decode}(\mathbf{t}) = \text{UTF-8}^{-1}\left(\bigoplus_{i=1}^n \text{vocab}(t_i)\right) +``` + +**BPE Merge:** + +```math +\text{tokens}^{(k)} = \text{Merge}(\text{tokens}^{(k-1)}, \arg\max_{(i,j)} \text{stats}^{(k)}(i,j), 256+k-1) +``` + +**Vocabulary Size:** + +```math +|V| = 256 + K + |V_{\text{special}}| +``` + +### Key Takeaways + +1. **UTF-8 Encoding**: Handles all Unicode characters consistently +2. **BPE Algorithm**: Creates efficient vocabulary through iterative merging +3. **Regex Splitting**: Prevents over-merging with pattern-based chunking +4. **Special Tokens**: Control flow and handle edge cases +5. **Byte-Level**: Works at byte level for universal language support diff --git a/docs/TOKENIZER_IMPROVEMENTS.md b/docs/TOKENIZER_IMPROVEMENTS.md new file mode 100644 index 0000000..f347e3a --- /dev/null +++ b/docs/TOKENIZER_IMPROVEMENTS.md @@ -0,0 +1,165 @@ +# Tokenizer Improvements + +## Overview + +Based on the tokenization challenges discussed in the transcript, we've implemented an improved BPE (Byte Pair Encoding) tokenizer that addresses common issues with tokenization in large language models. + +## Key Improvements + +### 1. UTF-8 Byte-Level Encoding +- **What**: Tokenizer works at the byte level (0-255) instead of character level +- **Why**: Handles all Unicode characters consistently, regardless of language +- **Benefit**: Better support for non-English languages and special characters + +### 2. GPT-4 Style Regex Pattern +- **What**: Improved regex pattern for splitting text into chunks +- **Improvements**: + - Case-insensitive matching for contractions (`'s`, `'t`, `'ll`, etc.) + - Better whitespace handling (groups multiple spaces for Python code) + - Limits number merging to 1-3 digits (prevents overly long number tokens) +- **Benefit**: More consistent tokenization, especially for code and numbers + +### 3. BPE Training Algorithm +- **What**: Full Byte Pair Encoding implementation +- **Features**: + - `get_stats()`: Counts consecutive token pairs + - `merge()`: Replaces frequent pairs with single tokens + - Iterative training process +- **Benefit**: Creates efficient vocabulary that compresses common sequences + +### 4. Special Token Handling +- **What**: Proper handling of special tokens (``, ``, ``, ``) +- **Features**: + - Special tokens are excluded from BPE merging + - EOS token stops decoding + - Configurable special token set +- **Benefit**: Clean separation between data tokens and control tokens + +### 5. Trailing Whitespace Detection +- **What**: Warns when text ends with trailing whitespace +- **Why**: Trailing spaces can cause poor tokenization (as seen in GPT-3.5 playground warnings) +- **Benefit**: Helps developers avoid tokenization issues + +### 6. Better Python Code Handling +- **What**: Improved whitespace merging for indentation +- **Why**: GPT-2 had issues with Python code because each space was a separate token +- **Benefit**: More efficient tokenization of Python code (fewer tokens per file) + +### 7. Number Tokenization Limits +- **What**: Limits number merging to 1-3 digits +- **Why**: Prevents creating tokens for very long number sequences +- **Benefit**: Better arithmetic performance (numbers are more consistently tokenized) + +## Usage + +### Basic Usage + +```python +from data.example import SimpleTokenizer + +# Create tokenizer (uses BPE by default) +tokenizer = SimpleTokenizer(use_bpe=True, vocab_size=50257) + +# Encode text +tokens = tokenizer.encode("Hello world!") +print(tokens) # [15496, 1917, 0] + +# Decode tokens +text = tokenizer.decode(tokens) +print(text) # "Hello world!" +``` + +### Training a Custom Tokenizer + +```python +from data.example import BPETokenizer + +# Create tokenizer +tokenizer = BPETokenizer(vocab_size=50257) + +# Train on your corpus +texts = [ + "Your training text here...", + "More training text...", +] + +tokenizer.train(texts, num_merges=50000, verbose=True) + +# Save trained tokenizer +tokenizer.save("merges.json", "vocab.json") +``` + +### Loading a Pre-trained Tokenizer + +```python +from data.example import BPETokenizer + +# Load saved tokenizer +tokenizer = BPETokenizer() +tokenizer.load("merges.json", "vocab.json") + +# Use it +tokens = tokenizer.encode("Hello world!") +``` + +## Addressing Common Issues + +### Issue: "Can't spell words well" +- **Cause**: Long tokens (like "defaultstyle" as single token) +- **Fix**: Better regex splitting prevents over-merging + +### Issue: "Bad at arithmetic" +- **Cause**: Arbitrary number tokenization (sometimes 1 token, sometimes 2-3) +- **Fix**: Limits number merging to 1-3 digits for consistency + +### Issue: "Python code inefficient" +- **Cause**: Each space is separate token (GPT-2 issue) +- **Fix**: Multiple spaces merge into single tokens + +### Issue: "Non-English languages worse" +- **Cause**: Tokenizer trained primarily on English +- **Fix**: UTF-8 byte-level encoding handles all languages consistently + +### Issue: "Trailing whitespace warning" +- **Cause**: Models see very few examples of trailing spaces +- **Fix**: Warning helps developers detect and fix the issue + +### Issue: "Solid gold Magikarp" (untrained tokens) +- **Cause**: Tokenizer creates tokens for strings not in training data +- **Fix**: Proper validation and fallback handling for unknown tokens + +## Backward Compatibility + +The `SimpleTokenizer` class maintains backward compatibility: +- If `use_bpe=False`, uses character-level tokenization (old behavior) +- If `use_bpe=True` (default), uses new BPE tokenizer +- All existing code continues to work without changes + +## Technical Details + +### BPE Algorithm +1. Start with byte-level vocabulary (256 tokens) +2. Count consecutive token pairs in training data +3. Find most frequent pair +4. Merge pair into new token +5. Repeat until target vocabulary size reached + +### Encoding Process +1. Split text using regex pattern +2. Convert each chunk to UTF-8 bytes +3. Apply BPE merges (greedy, left-to-right) +4. Return token IDs + +### Decoding Process +1. Look up token IDs in vocabulary +2. Convert bytes back to UTF-8 +3. Handle special tokens (EOS stops decoding) +4. Return decoded text + +## References + +Based on improvements discussed in: +- GPT-2 paper tokenization section +- GPT-4 tokenizer improvements +- Common tokenization challenges and solutions + diff --git a/docs/TRAINING_EXPLAINED.md b/docs/TRAINING_EXPLAINED.md new file mode 100644 index 0000000..7a35be8 --- /dev/null +++ b/docs/TRAINING_EXPLAINED.md @@ -0,0 +1,1329 @@ +# What is Training? Step-by-Step Explanation + +Complete step-by-step explanation of model training: what it is, why we need data, why more data is better, and how the model learns. + +## Table of Contents + +1. [What is Training?](#51-what-is-training) +2. [Why Do We Need Training?](#52-why-do-we-need-training) +3. [What Does the Model Learn?](#53-what-does-the-model-learn) +4. [Why Do We Need Data?](#54-why-do-we-need-data) +5. [Why More Data is Better](#55-why-more-data-is-better) +6. [How Training Works: Step-by-Step](#56-how-training-works-step-by-step) +7. [The Training Process](#57-the-training-process) +8. [Loss Function](#58-loss-function) +9. [Optimization](#59-optimization) +10. [Evaluation](#510-evaluation) +11. [Common Questions](#511-common-questions) +12. [Training Metrics and Artifacts](#512-training-metrics-and-artifacts) + +--- + +## 5.1 What is Training? + +### Simple Definition + +**Training** is the process of teaching a neural network to make predictions by showing it examples and adjusting its parameters to minimize errors. + +### The Learning Analogy + +**Think of training like teaching a child:** + +**Child Learning:** + +- You show examples: "This is a cat", "This is a dog" +- Child makes mistakes: Calls a cat "dog" +- You correct: "No, that's a cat" +- Child learns patterns from many examples +- Eventually, child recognizes cats and dogs correctly + +**Model Training:** + +- You show examples: "Hello" → next word "World" +- Model makes predictions: "Hello" → predicts "Hi" +- You compute error: Compare prediction to actual +- Model adjusts: Updates parameters to reduce error +- Process repeats: Shows many examples +- Eventually, model learns to predict correctly + +### What Happens During Training? + +**The model:** + +1. **Sees** input data (examples) +2. **Makes** predictions +3. **Compares** predictions to correct answers +4. **Calculates** how wrong it was (loss) +5. **Adjusts** parameters to be less wrong +6. **Repeats** millions of times + +**Result:** A model that can make accurate predictions! + +--- + +## 5.2 Why Do We Need Training? + +### The Problem + +**Untrained models are random:** + +**Initial State:** + +``` +Input: "Hello" +Model Prediction: Random guess +→ "apple" (30%) +→ "zebra" (25%) +→ "World" (5%) +→ Other random words... +``` + +**The model doesn't know anything yet!** + +### The Solution: Training + +**After Training:** + +``` +Input: "Hello" +Model Prediction: Learned pattern +→ "World" (85%) +→ "there" (8%) +→ "friend" (3%) +→ Other reasonable words... +``` + +**The model learned language patterns!** + +### Why Training is Essential + +**Without Training:** + +- Random predictions +- No understanding of language +- No useful output +- Model is useless + +**With Training:** + +- Learned patterns +- Understanding of language +- Useful predictions +- Model is valuable + +--- + +## 5.3 What Does the Model Learn? + +### 1. Language Patterns + +**Learns:** + +- Word relationships ("Hello" often followed by "World") +- Grammar rules (subject-verb agreement) +- Sentence structure (nouns, verbs, adjectives) +- Context understanding (same word means different things) + +**Example:** + +``` +"You" → "are" (learned: pronoun + verb agreement) +"The cat" → "sat" (learned: noun + verb) +"Machine learning" → "is" (learned: compound noun + verb) +``` + +### 2. Semantic Relationships + +**Learns:** + +- Similar words have similar meanings +- Related concepts cluster together +- Word embeddings capture meaning +- Context determines word usage + +**Example:** + +``` +"cat" and "dog" → Similar embeddings (both animals) +"king" - "man" + "woman" ā‰ˆ "queen" (learned relationships) +``` + +### 3. Sequential Patterns + +**Learns:** + +- Predict next token based on context +- Long-range dependencies +- Common phrases and idioms +- Writing style and tone + +**Example:** + +``` +"Once upon a time" → "there" (learned: story beginning) +"The quick brown fox" → "jumps" (learned: common phrase) +``` + +### 4. Statistical Patterns + +**Learns:** + +- How often words appear together +- Probability distributions over vocabulary +- Language statistics +- Common word sequences + +**Example:** + +``` +"The" → very common (appears frequently) +"Antidisestablishmentarianism" → rare (appears rarely) +``` + +--- + +## 5.4 Why Do We Need Data? + +### The Fundamental Need + +**Models learn from examples, not from rules:** + +**Rule-Based Approach (Old Way):** + +``` +Programmer writes rules: +IF word == "Hello" THEN next_word = "World" +IF word == "The" THEN next_word = "cat" +... +``` + +**Problems:** + +- Need to write millions of rules +- Can't capture all patterns +- Brittle and breaks easily +- Doesn't generalize + +**Data-Driven Approach (Modern Way):** + +``` +Model learns from examples: +"Hello World" (example 1) +"The cat sat" (example 2) +... +``` + +**Benefits:** + +- Automatically learns patterns +- Captures complex relationships +- Generalizes to new examples +- Handles ambiguity + +### Why Data is Essential + +**Data provides:** + +**1. Examples to Learn From** + +``` +Without data: Model has no examples +With data: Model sees millions of examples +``` + +**2. Ground Truth** + +``` +Without data: No correct answers +With data: Knows what correct predictions are +``` + +**3. Patterns to Discover** + +``` +Without data: Can't find patterns +With data: Discovers language patterns automatically +``` + +**4. Evaluation** + +``` +Without data: Can't measure performance +With data: Can test if model learned correctly +``` + +### What Data Provides + +**Training Data:** + +- Input-output pairs +- Examples to learn from +- Patterns to discover +- Ground truth for comparison + +**Example:** + +``` +Input: "Hello" +Output: "World" + +Input: "Machine learning" +Output: "is" + +Input: "The cat" +Output: "sat" +``` + +**Each example teaches the model something!** + +--- + +## 5.5 Why More Data is Better + +### The Relationship Between Data and Performance + +**General Rule:** + +``` +More Data → Better Performance +``` + +**But why?** + +### Reason 1: More Patterns + +**Little Data:** + +``` +100 examples: +- See "Hello World" once +- See "Hello there" once +- Model uncertain: Which is more common? +``` + +**More Data:** + +``` +1,000,000 examples: +- See "Hello World" 500,000 times +- See "Hello there" 200,000 times +- See "Hello friend" 300,000 times +- Model confident: "Hello World" is most common +``` + +**More examples = Better pattern recognition** + +### Reason 2: Broader Coverage + +**Little Data:** + +``` +Limited vocabulary: +- Only sees common words +- Misses rare words +- Poor generalization +``` + +**More Data:** + +``` +Comprehensive vocabulary: +- Sees common words frequently +- Sees rare words occasionally +- Good generalization +``` + +**More examples = Better coverage** + +### Reason 3: Better Generalization + +**Little Data:** + +``` +Sees: "The cat sat on the mat" +Learns: Exact pattern +Test: "The dog sat on the rug" +Fails: Never saw "dog" or "rug" +``` + +**More Data:** + +``` +Sees: Many variations +- "The cat sat on the mat" +- "The dog sat on the rug" +- "The bird sat on the branch" +Learns: General pattern +Test: "The dog sat on the rug" +Succeeds: Understands pattern +``` + +**More examples = Better generalization** + +### Reason 4: Reduces Overfitting + +**Little Data:** + +``` +Model memorizes examples: +- Perfect on training data +- Poor on new data +- Overfitting! +``` + +**More Data:** + +``` +Model learns patterns: +- Good on training data +- Good on new data +- Generalizes well! +``` + +**More examples = Less overfitting** + +### Reason 5: Statistical Confidence + +**Little Data:** + +``` +10 examples of "Hello World" +→ Statistically uncertain +→ High variance in predictions +``` + +**More Data:** + +``` +1,000,000 examples of "Hello World" +→ Statistically confident +→ Low variance in predictions +``` + +**More examples = More confident predictions** + +### The Data-Performance Curve + +``` +Performance + │ +100%│ ā—ā”€ā”€ā”€ā”€ā”€ (More data needed) + │ ā—ā”€ā”€ā”€ā”€ā”€ + │ ā—ā”€ā”€ā”€ā”€ā”€ + │ ā—ā”€ā”€ā”€ā”€ā”€ + │ ā—ā”€ā”€ā”€ā”€ā”€ + │ + 0%ā”œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ Data + 0 1K 10K 100K 1M 10M 100M +``` + +**Diminishing Returns:** + +- First 1M examples: Huge improvement +- Next 9M examples: Good improvement +- Next 90M examples: Smaller improvement +- Beyond: Very small improvements + +**But more data is almost always better!** + +### Real-World Examples + +**GPT-3:** + +- Trained on ~300 billion tokens +- Requires massive datasets +- Better performance with more data + +**Our Model:** + +- Can train on any amount of data +- More data = better performance +- Scales with dataset size + +--- + +## 5.6 How Training Works: Step-by-Step + +### High-Level Overview + +``` +1. Initialize model (random weights) +2. For each epoch: + a. For each batch: + - Forward pass (make predictions) + - Compute loss (measure error) + - Backward pass (compute gradients) + - Update weights (reduce error) +3. Evaluate model +4. Repeat until convergence +``` + +### Detailed Step-by-Step + +#### Step 1: Initialize Model + +**Start with random weights:** + +``` +Embedding weights: Random values +Attention weights: Random values +FFN weights: Random values +→ Model makes random predictions +``` + +**Example:** + +``` +Weight initialization: [-0.1, 0.05, 0.2, ...] (random) +Initial prediction: Random token (meaningless) +``` + +#### Step 2: Forward Pass + +**Process input through model:** + +``` +Input: "Hello" + ↓ +Embedding: [0.1, -0.2, 0.3, ...] + ↓ +Attention: [0.15, 0.08, 0.22, ...] + ↓ +FFN: [0.18, -0.12, 0.24, ...] + ↓ +Output: [logits for all tokens] + ↓ +Prediction: "apple" (highest logit) +``` + +#### Step 3: Compute Loss + +**Compare prediction to correct answer:** + +``` +Expected: "World" +Predicted: "apple" +Loss: High (wrong prediction) +``` + +**Loss Function (Cross-Entropy):** + +``` +Loss = -log(P(predicted = correct)) +``` + +**Example:** + +``` +Correct token: "World" (ID 87) +Predicted probability: 0.05 (5%) +Loss = -log(0.05) ā‰ˆ 2.996 (high loss) +``` + +#### Step 4: Backward Pass + +**Compute gradients:** + +``` +Loss: 2.996 + ↓ +Compute gradients: āˆ‚Loss/āˆ‚weights + ↓ +Gradients: [0.5, -0.3, 0.8, ...] + ↓ +Shows direction to reduce loss +``` + +**Meaning:** Gradients tell us how to adjust weights to reduce error + +#### Step 5: Update Weights + +**Adjust weights using optimizer:** + +``` +Current weights: [0.1, -0.2, 0.3, ...] +Gradients: [0.5, -0.3, 0.8, ...] +Learning rate: 0.0001 + +New weights = Old weights - Learning rate Ɨ Gradients + = [0.1, -0.2, 0.3, ...] - 0.0001 Ɨ [0.5, -0.3, 0.8, ...] + = [0.09995, -0.19997, 0.29992, ...] +``` + +**Result:** Weights slightly adjusted to reduce loss + +#### Step 6: Repeat + +**Process repeats for millions of examples:** + +``` +Batch 1: Update weights slightly +Batch 2: Update weights slightly +Batch 3: Update weights slightly +... +Batch 1,000,000: Update weights slightly + +Result: Cumulative improvements → Model learns! +``` + +--- + +## 5.7 The Training Process + +### Complete Training Loop + +**For each epoch:** + +``` +Epoch 1: + Batch 1: [Input: "Hello", Output: "World"] → Loss: 2.996 → Update + Batch 2: [Input: "Machine", Output: "learning"] → Loss: 3.2 → Update + Batch 3: [Input: "The cat", Output: "sat"] → Loss: 3.1 → Update + ... + Average Loss: 3.05 + +Epoch 2: + Batch 1: [Input: "Hello", Output: "World"] → Loss: 2.5 → Update + Batch 2: [Input: "Machine", Output: "learning"] → Loss: 2.8 → Update + Batch 3: [Input: "The cat", Output: "sat"] → Loss: 2.7 → Update + ... + Average Loss: 2.65 (improved!) + +Epoch 3: + ... + Average Loss: 2.3 (improved!) + +... + +Epoch 10: + ... + Average Loss: 1.2 (much better!) +``` + +### Key Training Concepts + +**Epoch:** + +- One complete pass through the training data +- All examples seen once + +**Batch:** + +- Small group of examples processed together +- Enables efficient training + +**Iteration:** + +- Processing one batch +- One weight update + +**Loss:** + +- Measure of prediction error +- Lower is better + +**Learning Rate:** + +- How much to adjust weights +- Controls training speed + +### Training Metrics + +**Loss Over Time:** + +``` +Loss + │ + 4.0ā”‚ā— + │ + 3.0│ ā— + │ + 2.0│ ā— + │ + 1.0│ ā— + │ + 0.0ā”œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ Epochs + 0 2 4 6 8 10 +``` + +**Decreasing loss = Model learning!** + +--- + +## 5.8 Loss Function + +### What is Loss? + +**Loss measures how wrong the model is:** + +**Low Loss:** + +``` +Prediction: "World" (95% confidence) +Correct: "World" +Loss: 0.05 (very low, almost perfect!) +``` + +**High Loss:** + +``` +Prediction: "apple" (10% confidence) +Correct: "World" +Loss: 2.3 (high, very wrong!) +``` + +### Cross-Entropy Loss + +**Formula:** + +```math +L = -\frac{1}{N} \sum_{i=1}^{N} \log P(y_i | x_i) +``` + +**Where:** + +- $N$ = number of tokens +- $y_i$ = correct token +- $p(y_i | x_i)$ = predicted probability of correct token + +**Example:** + +**Input:** "Hello" +**Correct:** "World" +**Predicted probabilities:** + +``` +"World": 0.05 (5%) +"there": 0.03 (3%) +"Hello": 0.02 (2%) +... +``` + +**Loss:** + +``` +L = -log(0.05) ā‰ˆ 2.996 +``` + +**Meaning:** Model is uncertain, high loss + +**After Training:** + +``` +"World": 0.85 (85%) +"there": 0.10 (10%) +"Hello": 0.03 (3%) +... + +Loss = -log(0.85) ā‰ˆ 0.162 +``` + +**Meaning:** Model is confident, low loss! + +### Why Cross-Entropy? + +**Properties:** + +1. **Penalizes confident wrong predictions:** High loss for wrong + confident +2. **Rewards confident correct predictions:** Low loss for correct + confident +3. **Smooth gradient:** Easy to optimize +4. **Probabilistic interpretation:** Works with probabilities + +--- + +## 5.9 Optimization + +### What is Optimization? + +**Optimization = Finding best weights** + +**Goal:** + +``` +Minimize Loss(weights) +``` + +**How:** + +``` +1. Compute gradients +2. Update weights in direction that reduces loss +3. Repeat until convergence +``` + +### AdamW Optimizer + +**Our model uses AdamW:** + +**Why AdamW?** + +- Adaptive learning rate per parameter +- Handles sparse gradients well +- Weight decay for regularization +- Works well for transformers + +**How it works:** + +**Step 1: Compute Gradients** + +``` +g_t = āˆ‚Loss/āˆ‚weights +``` + +**Step 2: Update Momentum** + +``` +m_t = β₁ Ɨ m_{t-1} + (1 - β₁) Ɨ g_t +``` + +**Step 3: Update Variance** + +``` +v_t = β₂ Ɨ v_{t-1} + (1 - β₂) Ɨ g_t² +``` + +**Step 4: Update Weights** + +``` +weights_t = weights_{t-1} - lr Ɨ (m_t / (√v_t + ε)) - Ī» Ɨ weights_{t-1} +``` + +**Where:** + +- β₁ = 0.9 (momentum decay) +- β₂ = 0.999 (variance decay) +- lr = learning rate +- Ī» = weight decay +- ε = small constant + +**Result:** Efficient weight updates! + +### Learning Rate Scheduling + +**Cosine Annealing:** + +**Start:** High learning rate (fast learning) +**Middle:** Decreasing learning rate +**End:** Low learning rate (fine-tuning) + +**Visualization:** + +``` +Learning Rate + │ + 0.001ā”‚ā—ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ + │ \ + │ \ + │ \ + │ \ + │ \ + 0.000│ ā—ā”€ā”€ā”€ā”€ā”€ + └────────────────────────── Steps + Training Progress +``` + +**Benefits:** + +- Fast initial learning +- Stable convergence +- Better final performance + +--- + +## 5.10 Evaluation + +### Why Evaluate? + +**Check if model learned:** + +``` +Training Loss: 0.5 (low) +→ Model learned training data well + +But is it good on new data? +``` + +### Evaluation Metrics + +**1. Loss (Perplexity)** + +``` +Lower is better +Measures prediction uncertainty +``` + +**2. Accuracy** + +``` +Percentage of correct predictions +Higher is better +``` + +**3. Perplexity** + +``` +Perplexity = exp(loss) +Lower is better +Measures "surprise" of model +``` + +**Example:** + +``` +Loss: 2.0 +Perplexity: exp(2.0) ā‰ˆ 7.39 + +Meaning: Model is "surprised" by about 7.39 choices on average +Lower perplexity = Better predictions +``` + +### Validation Set + +**Separate data for evaluation:** + +``` +Training Set: 80% (learn from this) +Validation Set: 20% (test on this) + +Train on training set +Evaluate on validation set +→ See if model generalizes! +``` + +**Why Separate?** + +- Test on unseen data +- Detect overfitting +- Measure real performance + +--- + +## 5.11 Common Questions + +### Q1: How long does training take? + +**Answer:** Depends on: + +- Dataset size +- Model size +- Hardware (GPU/CPU) +- Number of epochs + +**Example:** + +``` +Small model (1M params), 1M tokens: +- CPU: Days +- GPU: Hours + +Large model (100M params), 100M tokens: +- CPU: Weeks +- GPU: Days +``` + +### Q2: When should training stop? + +**Answer:** + +- When validation loss stops improving +- After fixed number of epochs +- When loss converges +- When overfitting detected + +**Early Stopping:** + +``` +If validation loss doesn't improve for N epochs: +→ Stop training +→ Prevent overfitting +``` + +### Q3: Why does loss sometimes increase? + +**Answer:** Normal! Can happen due to: + +- Learning rate too high +- Difficult batch +- Optimization noise +- Normal fluctuations + +**Long-term trend should decrease:** + +``` +Loss: 3.0 → 2.8 → 2.9 → 2.7 → 2.8 → 2.6 +↑ ↑ ↑ ↑ ↑ ↑ + Small increases OK, overall decreasing +``` + +### Q4: Can I train on different types of data? + +**Answer:** Yes! Model learns from whatever data you provide: + +``` +Books → Learns literary style +Code → Learns programming patterns +Scientific papers → Learns technical language +Mixed → Learns diverse patterns +``` + +**More diverse data = More versatile model** + +### Q5: What if I don't have much data? + +**Answer:** + +- Can still train with small datasets +- May need more epochs +- May need smaller model +- Consider data augmentation + +**However:** + +- More data almost always better +- Try to collect more if possible + +### Q6: How do I know if training is working? + +**Answer: Check:** + +- Loss decreasing over time āœ“ +- Validation loss improving āœ“ +- Predictions getting better āœ“ +- Model generating reasonable text āœ“ + +**Signs of problems:** + +- Loss not decreasing → Check learning rate +- Loss increasing → Check data or model +- Predictions random → Check training + +### Q7: What's the difference between training and inference? + +**Answer:** + +**Training:** + +- Model learns from data +- Updates weights +- Computes gradients +- Optimizes parameters + +**Inference:** + +- Model makes predictions +- Fixed weights (no updates) +- No gradients computed +- Just forward pass + +**Analogy:** + +- **Training:** Student studying (learning) +- **Inference:** Student taking exam (using knowledge) + +--- + +## 5.12 Training Metrics and Artifacts + +When you run training locally, the system automatically generates several files to help you monitor and understand the training process. These files are saved in your checkpoint directory (default: `./checkpoints` or `./checkpoints_test`). + +### Generated Files + +After training completes (or during training), you'll find these files: + +1. **`training_metrics.json`** - Complete training history in JSON format +2. **`training_curve.png`** - Visual plots of loss and learning rate over time +3. **`loss_by_epoch.png`** - Average loss per epoch visualization + +### training_metrics.json + +**Location:** `checkpoints/training_metrics.json` (or your configured save directory) + +**Contents:** + +This JSON file contains the complete training history with the following fields: + +```json +{ + "train_loss": [4.19, 3.70, 3.29, ...], // Training loss at each logging step + "val_loss": [null, null, null, ...], // Validation loss (null if not evaluated) + "learning_rate": [0.0001, 0.0001, ...], // Learning rate at each step + "epochs": [0, 0, 0, ...], // Epoch number for each step + "steps": [5, 10, 15, ...] // Global training step number +} +``` + +**What Each Field Means:** + +- **`train_loss`**: Array of training loss values. Lower is better. Shows how well the model fits the training data. +- **`val_loss`**: Array of validation loss values (or `null` if validation wasn't run). Lower is better. Shows generalization to unseen data. +- **`learning_rate`**: Array of learning rate values. Shows how the learning rate scheduler adjusted the learning rate over time. +- **`epochs`**: Array indicating which epoch each metric was recorded in. +- **`steps`**: Array of global step numbers. Each step represents one batch processed. + +**How to Use:** + +```python +import json + +# Load metrics +with open('checkpoints/training_metrics.json', 'r') as f: + metrics = json.load(f) + +# Get final training loss +final_loss = metrics['train_loss'][-1] +print(f"Final training loss: {final_loss:.4f}") + +# Find minimum validation loss +val_losses = [v for v in metrics['val_loss'] if v is not None] +if val_losses: + min_val_loss = min(val_losses) + print(f"Best validation loss: {min_val_loss:.4f}") + +# Calculate average loss per epoch +epoch_0_losses = [metrics['train_loss'][i] + for i, e in enumerate(metrics['epochs']) if e == 0] +avg_epoch_0_loss = sum(epoch_0_losses) / len(epoch_0_losses) +print(f"Average loss for epoch 0: {avg_epoch_0_loss:.4f}") +``` + +### training_curve.png + +**Location:** `checkpoints/training_curve.png` + +**What It Shows:** + +This plot contains two subplots: + +1. **Top Plot: Training and Validation Loss** + - X-axis: Training steps + - Y-axis: Loss value + - Blue line: Training loss over time + - Red line: Validation loss (if available) + - Shows how loss decreases during training + +2. **Bottom Plot: Learning Rate Schedule** + - X-axis: Training steps + - Y-axis: Learning rate (log scale) + - Green line: Learning rate over time + - Shows how the learning rate scheduler adjusted the learning rate + +**How to Interpret:** + +**Good Training:** + +![Training Curve Example](images/training_curve.png) + +*Example training curve showing smooth loss decrease and learning rate schedule. Your actual plot will be saved in your checkpoint directory.* + +**Signs of Problems:** + +- **Loss not decreasing**: Learning rate too low, or model too small +- **Loss increasing**: Learning rate too high, or data issues +- **Loss oscillating wildly**: Learning rate too high +- **Training loss much lower than validation loss**: Overfitting + +**Example from Your Training:** + +Based on your `training_metrics.json`, your training shows: +- Initial loss: ~4.19 (high, model is random) +- Final loss: ~0.92 (much lower, model learned!) +- Smooth decrease: Training progressed well +- Learning rate decayed from ~0.0001 to near zero: Proper cosine annealing schedule + +### loss_by_epoch.png + +**Location:** `checkpoints/loss_by_epoch.png` + +**What It Shows:** + +- X-axis: Epoch number +- Y-axis: Average loss for that epoch +- Single data point per epoch +- Shows overall training progress at epoch level + +**How to Interpret:** + +This plot gives you a high-level view of training progress: + +**Good Training:** + +![Loss by Epoch Example](images/loss_by_epoch.png) + +*Example loss by epoch plot showing steady decrease. Your actual plot will be saved in your checkpoint directory.* + +**What to Look For:** + +- **Decreasing trend**: Model is learning āœ“ +- **Plateau**: Model may have converged +- **Increasing**: Possible overfitting or learning rate issues + +### Interpreting Your Training Results + +Based on the metrics from your local training run: + +**Training Progress:** +- Started at loss ~4.19 (random initialization) +- Ended at loss ~0.92 (significant improvement!) +- Total steps: ~5,625 steps +- Loss decreased smoothly throughout training + +**Learning Rate Schedule:** +- Started at ~0.0001 (1e-4) +- Followed cosine annealing schedule +- Decayed smoothly to near zero +- Proper warmup and decay phases + +**What This Means:** +- āœ… Training was successful - loss decreased significantly +- āœ… Learning rate schedule worked correctly +- āœ… Model learned patterns from the training data +- āœ… No signs of overfitting (smooth decrease, no sudden spikes) + +### Using Metrics for Debugging + +**Problem: Loss Not Decreasing** + +```python +# Check learning rate +metrics = json.load(open('checkpoints/training_metrics.json')) +initial_lr = metrics['learning_rate'][0] +final_lr = metrics['learning_rate'][-1] +print(f"LR: {initial_lr} -> {final_lr}") + +# If LR is too low, increase in config +# If LR is too high, decrease in config +``` + +**Problem: Overfitting** + +```python +# Compare train vs validation loss +train_losses = metrics['train_loss'] +val_losses = [v for v in metrics['val_loss'] if v is not None] + +if val_losses: + final_train = train_losses[-1] + final_val = val_losses[-1] + gap = final_val - final_train + + if gap > 0.5: + print("Warning: Large gap suggests overfitting") + print("Consider: More data, regularization, or early stopping") +``` + +**Problem: Training Too Slow** + +```python +# Check loss decrease rate +losses = metrics['train_loss'] +initial_loss = losses[0] +final_loss = losses[-1] +steps = len(losses) + +decrease_rate = (initial_loss - final_loss) / steps +print(f"Loss decrease per step: {decrease_rate:.6f}") + +# If too slow, consider: +# - Increase learning rate +# - Increase batch size +# - Check data quality +``` + +### Best Practices + +1. **Monitor During Training**: Check `training_metrics.json` periodically to catch issues early +2. **Save Checkpoints**: The metrics file is updated continuously, so you can monitor progress even if training is interrupted +3. **Compare Runs**: Save metrics from different training runs to compare hyperparameters +4. **Visual Inspection**: Always look at the plots - they reveal patterns that numbers alone don't show +5. **Early Stopping**: Use validation loss from metrics to implement early stopping if needed + +### Example: Analyzing Your Training Run + +```python +import json +import matplotlib.pyplot as plt + +# Load your training metrics +with open('checkpoints_test/training_metrics.json', 'r') as f: + metrics = json.load(f) + +# Quick analysis +print("=== Training Summary ===") +print(f"Total steps: {len(metrics['steps'])}") +print(f"Initial loss: {metrics['train_loss'][0]:.4f}") +print(f"Final loss: {metrics['train_loss'][-1]:.4f}") +print(f"Loss reduction: {metrics['train_loss'][0] - metrics['train_loss'][-1]:.4f}") +print(f"Reduction percentage: {(1 - metrics['train_loss'][-1]/metrics['train_loss'][0])*100:.1f}%") + +# Check learning rate schedule +lr_values = [lr for lr in metrics['learning_rate'] if lr is not None] +if lr_values: + print(f"\nLearning Rate:") + print(f" Initial: {lr_values[0]:.6f}") + print(f" Final: {lr_values[-1]:.6f}") + print(f" Decay factor: {lr_values[-1]/lr_values[0]:.6f}") + +# Find best checkpoint (lowest loss) +best_step_idx = metrics['train_loss'].index(min(metrics['train_loss'])) +best_step = metrics['steps'][best_step_idx] +best_loss = metrics['train_loss'][best_step_idx] +print(f"\nBest checkpoint:") +print(f" Step: {best_step}") +print(f" Loss: {best_loss:.4f}") +``` + +--- + +## Summary + +### What is Training? + +**Training** is teaching the model to make accurate predictions by: + +1. Showing examples +2. Computing errors +3. Adjusting parameters +4. Repeating millions of times + +### Why We Need Data + +**Data provides:** + +- Examples to learn from +- Patterns to discover +- Ground truth to compare +- Evaluation to measure progress + +### Why More Data is Better + +**More data enables:** + +- Better pattern recognition +- Broader coverage +- Better generalization +- Reduced overfitting +- Statistical confidence + +### Training Process + +``` +1. Initialize model (random weights) +2. Forward pass (make predictions) +3. Compute loss (measure error) +4. Backward pass (compute gradients) +5. Update weights (reduce error) +6. Repeat for many epochs +``` + +### Key Takeaways + +āœ… **Training teaches models to make predictions** +āœ… **Models learn from data, not rules** +āœ… **More data = Better performance** +āœ… **Loss measures prediction error** +āœ… **Optimization updates weights to reduce loss** +āœ… **Evaluation checks if model learned correctly** + +--- + +_This document provides a comprehensive explanation of model training, why we need data, and why more data leads to better performance in transformer models._ diff --git a/docs/images/loss_by_epoch.png b/docs/images/loss_by_epoch.png new file mode 100644 index 0000000..1a6fe41 Binary files /dev/null and b/docs/images/loss_by_epoch.png differ diff --git a/docs/images/training_curve.png b/docs/images/training_curve.png new file mode 100644 index 0000000..c308492 Binary files /dev/null and b/docs/images/training_curve.png differ diff --git a/download_all_repos.py b/download_all_repos.py new file mode 100755 index 0000000..b857add --- /dev/null +++ b/download_all_repos.py @@ -0,0 +1,114 @@ +#!/usr/bin/env python3 +""" +Convenience script to download all repository categories at once. +Downloads: Neovim, Lua, Bash, Zsh, Python, and Ethical Hacking repositories. +""" +import sys +from pathlib import Path + +# Import the download function +sys.path.insert(0, str(Path(__file__).parent)) +from download_repos import download_repos + +def main(): + print("šŸš€ SheepOp - Downloading All Repository Categories") + print("=" * 60) + print("\nThis will download:") + print(" šŸ“¦ Neovim configurations and plugins") + print(" šŸ“¦ Lua programming repositories") + print(" šŸ“¦ Bash/shell script repositories") + print(" šŸ“¦ Zsh configuration and plugins") + print(" šŸ“¦ Python programming repositories") + print(" šŸ“¦ Ethical hacking and cybersecurity tools") + print("\n" + "=" * 60) + + # Default settings + categories = ['nvim', 'lua', 'bash', 'zsh', 'python', 'hacking'] + max_repos_per_category = 50 + min_stars = 100 + output_dir = "data/repos" + shallow = True # Default to shallow clones + max_size_gb = 1024.0 # Default 1 TB + + # Check for command line arguments + if len(sys.argv) > 1: + if '--help' in sys.argv or '-h' in sys.argv: + print("\nUsage:") + print(" python3 download_all_repos.py [options]") + print("\nOptions:") + print(" --max-repos N Maximum repos per category (default: 50)") + print(" --min-stars N Minimum stars (default: 100)") + print(" --output DIR Output directory (default: data/repos)") + print(" --max-size N Maximum total size in GB (default: 1024.0 = 1 TB)") + print(" --full-clone Do full clone instead of shallow") + print("\nExample:") + print(" python3 download_all_repos.py --max-repos 100 --min-stars 200 --max-size 1024.0") + return + + # Parse simple arguments + args = sys.argv[1:] + i = 0 + while i < len(args): + if args[i] == '--max-repos' and i + 1 < len(args): + max_repos_per_category = int(args[i + 1]) + i += 2 + elif args[i] == '--min-stars' and i + 1 < len(args): + min_stars = int(args[i + 1]) + i += 2 + elif args[i] == '--output' and i + 1 < len(args): + output_dir = args[i + 1] + i += 2 + elif args[i] == '--max-size' and i + 1 < len(args): + max_size_gb = float(args[i + 1]) + i += 2 + elif args[i] == '--full-clone': + shallow = False + i += 1 + else: + i += 1 + + print(f"\nšŸ“Š Settings:") + print(f" Categories: {', '.join(categories)}") + print(f" Max repos per category: {max_repos_per_category}") + print(f" Min stars: {min_stars}") + print(f" Output directory: {output_dir}") + print(f" Max size: {max_size_gb} GB ({max_size_gb / 1024.0:.2f} TB)") + print(f" Shallow clone: {shallow}") + print() + + # Confirm before starting + try: + response = input("Continue? [Y/n]: ").strip().lower() + if response and response != 'y': + print("Cancelled.") + return + except KeyboardInterrupt: + print("\nCancelled.") + return + + # Download all categories + success = download_repos( + output_dir=output_dir, + license='mit', # Default to MIT license + min_stars=min_stars, + max_repos=max_repos_per_category, + shallow=shallow, + categories=categories, + max_size_gb=max_size_gb, + ) + + if success: + print(f"\nšŸŽ‰ All downloads complete!") + print(f"\nšŸ“š You can now train with:") + print(f" python3 train.py --data data/ --config config.json --device cuda") + print(f"\n This will process:") + print(f" - Your existing 196 GB of text data") + print(f" - All downloaded code repositories") + else: + print("\nāŒ Some downloads failed. Check the output above for details.") + sys.exit(1) + + +if __name__ == '__main__': + main() + diff --git a/download_data.py b/download_data.py new file mode 100644 index 0000000..088503b --- /dev/null +++ b/download_data.py @@ -0,0 +1,264 @@ +""" +Data download utilities for training the language model +""" +import urllib.request +import gzip +import os +from pathlib import Path +from typing import Optional + + +def download_text_file(url: str, output_path: str, decompress: bool = False): + """ + Download a text file from a URL. + + Args: + url: URL to download from + output_path: Path to save the file + decompress: Whether to decompress gzip files + """ + print(f"Downloading from {url}...") + + if decompress: + # Download and decompress gzip file + with urllib.request.urlopen(url) as response: + with gzip.open(response, 'rt', encoding='utf-8') as f: + content = f.read() + + with open(output_path, 'w', encoding='utf-8') as f: + f.write(content) + else: + # Download regular file + urllib.request.urlretrieve(url, output_path) + + print(f"Downloaded to {output_path}") + + +def download_wiki_text(): + """Download a small Wikipedia text dataset.""" + url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt" + output_path = "data/wikitext_sample.txt" + + Path(output_path).parent.mkdir(parents=True, exist_ok=True) + + try: + download_text_file(url, output_path) + print(f"Successfully downloaded Wikipedia text sample to {output_path}") + return output_path + except Exception as e: + print(f"Error downloading: {e}") + return None + + +def download_shakespeare(): + """Download Shakespeare text dataset.""" + url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt" + output_path = "data/shakespeare.txt" + + Path(output_path).parent.mkdir(parents=True, exist_ok=True) + + try: + download_text_file(url, output_path) + print(f"Successfully downloaded Shakespeare text to {output_path}") + return output_path + except Exception as e: + print(f"Error downloading: {e}") + return None + + +def download_openwebtext_sample(): + """Download a sample from OpenWebText corpus.""" + # Using a smaller sample URL + url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt" + output_path = "data/openwebtext_sample.txt" + + Path(output_path).parent.mkdir(parents=True, exist_ok=True) + + try: + download_text_file(url, output_path) + print(f"Successfully downloaded sample text to {output_path}") + return output_path + except Exception as e: + print(f"Error downloading: {e}") + return None + + +def create_sample_data(output_path: str = "data/sample_data.txt", num_samples: int = 100): + """ + Create a sample data.txt file with generated text. + + Args: + output_path: Path to save the file + num_samples: Number of text samples to generate + """ + Path(output_path).parent.mkdir(parents=True, exist_ok=True) + + sample_texts = [ + "The quick brown fox jumps over the lazy dog.", + "Machine learning is transforming artificial intelligence.", + "Natural language processing enables computers to understand text.", + "Deep learning models can learn complex patterns from data.", + "Transformers have revolutionized the field of NLP.", + "Attention mechanisms allow models to focus on relevant information.", + "Language models can generate coherent text.", + "Neural networks are inspired by the human brain.", + "Training a model requires large amounts of data.", + "Gradient descent is used to optimize neural networks.", + "Python is a popular programming language for machine learning.", + "PyTorch is a flexible deep learning framework.", + "The transformer architecture uses self-attention mechanisms.", + "Tokenization converts text into numerical representations.", + "Embeddings capture semantic meaning of words.", + "The model learns to predict the next word in a sequence.", + "Backpropagation computes gradients for training.", + "Regularization techniques prevent overfitting.", + "Cross-validation helps evaluate model performance.", + "Hyperparameter tuning improves model accuracy.", + "The training process iterates over multiple epochs.", + "Batch processing speeds up training.", + "GPU acceleration makes training faster.", + "Checkpoints save model state during training.", + "Evaluation metrics measure model quality.", + "Perplexity measures how well a model predicts text.", + "BLEU score evaluates translation quality.", + "F1 score combines precision and recall.", + "Accuracy measures correct predictions.", + "Loss function quantifies prediction errors.", + "Optimizers update model parameters.", + "Learning rate controls training speed.", + "Dropout prevents overfitting.", + "Layer normalization stabilizes training.", + "Residual connections help train deep networks.", + "Multi-head attention captures different relationships.", + "Positional encoding adds sequence information.", + "Causal masking ensures autoregressive generation.", + "Sampling strategies control text generation.", + "Temperature scaling adjusts randomness.", + "Top-k sampling limits vocabulary choices.", + "Nucleus sampling uses cumulative probability.", + "Beam search finds high-probability sequences.", + "Greedy decoding selects highest probability tokens.", + "The model architecture determines capabilities.", + "Data quality affects model performance.", + "Preprocessing cleans and formats data.", + "Data augmentation increases training examples.", + "Transfer learning uses pretrained models.", + "Fine-tuning adapts models to specific tasks.", + "Zero-shot learning requires no training examples.", + "Few-shot learning uses few examples.", + "In-context learning adapts during inference.", + "Prompt engineering improves model outputs.", + "Chain-of-thought reasoning breaks down problems.", + "Self-consistency improves reliability.", + "Ensemble methods combine multiple models.", + "Model compression reduces size.", + "Quantization reduces precision.", + "Pruning removes unnecessary connections.", + "Distillation transfers knowledge between models.", + "The field of AI continues to evolve rapidly.", + "Research pushes boundaries of what's possible.", + "Open source enables collaboration.", + "Reproducibility ensures scientific validity.", + "Ethics guides responsible AI development.", + "Bias detection identifies unfairness.", + "Fairness metrics measure equity.", + "Transparency enables understanding.", + "Interpretability reveals model reasoning.", + "Adversarial examples test robustness.", + "Security protects against attacks.", + "Privacy preserves user data.", + "Federated learning protects privacy.", + "Differential privacy adds noise.", + "Homomorphic encryption enables computation.", + "Blockchain provides decentralization.", + "Cryptography ensures security.", + "The future of AI is exciting.", + "Technology empowers human potential.", + "Innovation drives progress.", + "Collaboration accelerates discovery.", + "Education spreads knowledge.", + "Understanding deepens appreciation.", + "Curiosity fuels exploration.", + "Experimentation leads to breakthroughs.", + "Persistence overcomes challenges.", + "Creativity inspires solutions.", + "The journey of learning never ends.", + "Every dataset tells a story.", + "Patterns emerge from complexity.", + "Simplicity reveals elegance.", + "Understanding requires patience.", + "Mastery comes from practice.", + "Progress happens incrementally.", + "Success builds on failures.", + "Wisdom comes from experience.", + ] + + with open(output_path, 'w', encoding='utf-8') as f: + for i in range(num_samples): + # Cycle through sample texts + text = sample_texts[i % len(sample_texts)] + f.write(text + '\n') + + print(f"Created sample data file: {output_path} with {num_samples} samples") + return output_path + + +def scrape_wikipedia_article(title: str, output_path: str = "data/wikipedia_article.txt"): + """ + Download a Wikipedia article (requires wikipedia library). + + Args: + title: Wikipedia article title + output_path: Path to save the file + """ + try: + import wikipedia + + print(f"Downloading Wikipedia article: {title}") + page = wikipedia.page(title) + + Path(output_path).parent.mkdir(parents=True, exist_ok=True) + + with open(output_path, 'w', encoding='utf-8') as f: + f.write(page.content) + + print(f"Downloaded to {output_path}") + return output_path + except ImportError: + print("Wikipedia library not installed. Install with: pip install wikipedia") + return None + except Exception as e: + print(f"Error downloading Wikipedia article: {e}") + return None + + +if __name__ == '__main__': + import argparse + + parser = argparse.ArgumentParser(description='Download or create training data') + parser.add_argument('--type', type=str, choices=['shakespeare', 'sample', 'wiki', 'wikipedia'], + default='sample', help='Type of data to get') + parser.add_argument('--output', type=str, help='Output file path') + parser.add_argument('--samples', type=int, default=100, help='Number of samples for generated data') + parser.add_argument('--title', type=str, help='Wikipedia article title') + + args = parser.parse_args() + + if args.type == 'shakespeare': + output = download_shakespeare() + elif args.type == 'sample': + output_path = args.output or 'data/sample_data.txt' + output = create_sample_data(output_path, args.samples) + elif args.type == 'wiki': + output = download_wiki_text() + elif args.type == 'wikipedia': + if not args.title: + print("Error: --title required for Wikipedia download") + else: + output_path = args.output or f'data/wikipedia_{args.title.replace(" ", "_")}.txt' + output = scrape_wikipedia_article(args.title, output_path) + + if output: + print(f"\nData ready at: {output}") + print(f"You can now train with: python train.py --data {output}") + diff --git a/download_large_data.py b/download_large_data.py new file mode 100755 index 0000000..c7c08f5 --- /dev/null +++ b/download_large_data.py @@ -0,0 +1,428 @@ +#!/usr/bin/env python3 +""" +Download large datasets for training the SheepOp LLM. +Supports Amazon Reviews, WikiText, OpenWebText, BookCorpus, and more. +""" +import argparse +import sys +from pathlib import Path +from typing import Optional + + +def download_amazon_reviews(output: str = "data/amazon_reviews.txt", limit: int = 500000, category: str = "Video_Games_v1_00"): + """ + Download Amazon Product Reviews dataset. + + Args: + output: Output file path + limit: Maximum number of reviews to download + category: Product category (Video_Games_v1_00, Books_v1_00, etc.) + """ + try: + from datasets import load_dataset + except ImportError: + print("Error: 'datasets' library not installed.") + print("Install with: pip install datasets") + return False + + Path(output).parent.mkdir(parents=True, exist_ok=True) + + print(f"šŸ“„ Downloading Amazon Product Reviews (category: {category}, limit: {limit})...") + print(" This may take several minutes depending on your connection...") + + try: + # Try different dataset names/approaches + # Method 1: Try mc4 (Common Crawl) which includes Amazon-like content + print(" Attempting to download from alternative source...") + + # Use amazon_polarity dataset (smaller but works) + try: + print(" Trying amazon_polarity dataset...") + dataset = load_dataset("amazon_polarity", split=f"train[:{limit}]") + + with open(output, "w", encoding="utf-8") as f: + count = 0 + for item in dataset: + review = item.get("content", "").strip() + if not review: + review = item.get("text", "").strip() + if review and len(review) > 20: + f.write(review + "\n") + count += 1 + if count % 50000 == 0: + print(f" āœ“ Downloaded {count:,} reviews...") + + print(f"āœ… Successfully saved {count:,} reviews to {output}") + return True + + except Exception as e1: + print(f" amazon_polarity failed: {e1}") + + # Method 2: Use IMDB reviews (similar structure) + try: + print(" Trying IMDB reviews as alternative...") + dataset = load_dataset("imdb", split=f"train[:{limit}]") + + with open(output, "w", encoding="utf-8") as f: + count = 0 + for item in dataset: + review = item.get("text", "").strip() + if review and len(review) > 20: + f.write(review + "\n") + count += 1 + if count % 50000 == 0: + print(f" āœ“ Downloaded {count:,} reviews...") + + print(f"āœ… Successfully saved {count:,} reviews to {output}") + print(" Note: Using IMDB reviews instead of Amazon reviews") + return True + + except Exception as e2: + print(f" IMDB also failed: {e2}") + raise Exception("Both Amazon and IMDB datasets failed. Try using --alternative flag with a different dataset.") + + except Exception as e: + print(f"āŒ Error downloading reviews: {e}") + print("\nšŸ’” Alternative options:") + print(" 1. Use WikiText instead: python3 download_large_data.py wiki") + print(" 2. Use OpenWebText: python3 download_large_data.py openwebtext --limit 100000") + print(" 3. Try downloading from HuggingFace Hub manually") + return False + + +def download_wikitext(output: str = "data/wikitext.txt", version: str = "103"): + """ + Download WikiText dataset (Wikipedia text). + + Args: + output: Output file path + version: WikiText version ('2' or '103') + """ + try: + from datasets import load_dataset + except ImportError: + print("Error: 'datasets' library not installed.") + print("Install with: pip install datasets") + return False + + Path(output).parent.mkdir(parents=True, exist_ok=True) + + print(f"šŸ“„ Downloading WikiText-{version}...") + print(" This may take several minutes...") + + try: + dataset = load_dataset("wikitext", f"wikitext-{version}-v1", split="train") + + with open(output, "w", encoding="utf-8") as f: + count = 0 + for item in dataset: + text = item.get("text", "").strip() + # Filter out headers and empty lines + if text and len(text) > 20 and not text.startswith("="): + # Split into sentences + sentences = text.split('.') + for s in sentences: + s = s.strip() + if len(s) > 20: + f.write(s + ".\n") + count += 1 + if count % 10000 == 0: + print(f" āœ“ Processed {count:,} sentences...") + + print(f"āœ… Successfully saved {count:,} sentences to {output}") + return True + + except Exception as e: + print(f"āŒ Error downloading WikiText: {e}") + return False + + +def download_openwebtext(output: str = "data/openwebtext.txt", limit: int = 100000): + """ + Download OpenWebText dataset (web text corpus). + + Args: + output: Output file path + limit: Maximum number of samples to download + """ + try: + from datasets import load_dataset + except ImportError: + print("Error: 'datasets' library not installed.") + print("Install with: pip install datasets") + return False + + Path(output).parent.mkdir(parents=True, exist_ok=True) + + print(f"šŸ“„ Downloading OpenWebText (limit: {limit:,})...") + print(" This may take a while - OpenWebText is very large...") + + try: + dataset = load_dataset("openwebtext", split=f"train[:{limit}]") + + with open(output, "w", encoding="utf-8") as f: + count = 0 + for item in dataset: + text = item.get("text", "").strip() + if text: + # Split into sentences + sentences = text.split('.') + for s in sentences: + s = s.strip() + if len(s) > 20: + f.write(s + ".\n") + count += 1 + if count % 10000 == 0: + print(f" āœ“ Processed {count:,} sentences...") + + print(f"āœ… Successfully saved {count:,} sentences to {output}") + return True + + except Exception as e: + print(f"āŒ Error downloading OpenWebText: {e}") + return False + + +def download_bookcorpus(output: str = "data/bookcorpus.txt", limit: int = 100000): + """ + Download BookCorpus dataset (books). + + Args: + output: Output file path + limit: Maximum number of books to download + """ + try: + from datasets import load_dataset + except ImportError: + print("Error: 'datasets' library not installed.") + print("Install with: pip install datasets") + return False + + Path(output).parent.mkdir(parents=True, exist_ok=True) + + print(f"šŸ“„ Downloading BookCorpus (limit: {limit:,} books)...") + print(" This may take a while...") + + try: + dataset = load_dataset("bookcorpus", split=f"train[:{limit}]") + + with open(output, "w", encoding="utf-8") as f: + count = 0 + for item in dataset: + text = item.get("text", "").strip() + if text: + # Split into sentences + sentences = text.split('.') + for s in sentences: + s = s.strip() + if len(s) > 20: + f.write(s + ".\n") + count += 1 + if count % 10000 == 0: + print(f" āœ“ Processed {count:,} sentences...") + + print(f"āœ… Successfully saved {count:,} sentences to {output}") + return True + + except Exception as e: + print(f"āŒ Error downloading BookCorpus: {e}") + return False + + +def download_wikitext_direct(output: str = "data/wikitext_direct.txt"): + """ + Download WikiText directly from URL (no HuggingFace required). + """ + import urllib.request + import zipfile + import tempfile + import os + + Path(output).parent.mkdir(parents=True, exist_ok=True) + + url = "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip" + + print("šŸ“„ Downloading WikiText-103 directly from URL...") + print(" This may take several minutes...") + + try: + # Download to temp file + with tempfile.NamedTemporaryFile(delete=False, suffix='.zip') as tmp_file: + tmp_path = tmp_file.name + print(f" Downloading to temporary file...") + urllib.request.urlretrieve(url, tmp_path) + + # Extract and process + print(" Extracting and processing...") + with zipfile.ZipFile(tmp_path, 'r') as zip_ref: + # Extract wiki.train.tokens + with zip_ref.open('wikitext-103/wiki.train.tokens') as f: + with open(output, 'w', encoding='utf-8') as out_file: + count = 0 + for line in f: + line = line.decode('utf-8').strip() + if line and len(line) > 20 and not line.startswith('='): + sentences = line.split('.') + for s in sentences: + s = s.strip() + if len(s) > 20: + out_file.write(s + ".\n") + count += 1 + if count % 10000 == 0: + print(f" āœ“ Processed {count:,} sentences...") + + # Clean up + os.unlink(tmp_path) + + print(f"āœ… Successfully saved {count:,} sentences to {output}") + return True + + except Exception as e: + print(f"āŒ Error downloading WikiText: {e}") + return False + + +def main(): + parser = argparse.ArgumentParser( + description='Download large datasets for training SheepOp LLM', + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Download 500k Amazon reviews + python3 download_large_data.py amazon --limit 500000 + + # Download WikiText-103 + python3 download_large_data.py wiki + + # Download OpenWebText sample + python3 download_large_data.py openwebtext --limit 100000 + + # Download to custom location + python3 download_large_data.py amazon --output data/my_reviews.txt + """ + ) + + parser.add_argument( + 'dataset', + choices=['amazon', 'wiki', 'wikitext', 'openwebtext', 'bookcorpus'], + help='Dataset to download' + ) + + parser.add_argument( + '--output', + type=str, + help='Output file path (default: data/.txt)' + ) + + parser.add_argument( + '--limit', + type=int, + default=500000, + help='Maximum number of samples to download (default: 500000)' + ) + + parser.add_argument( + '--category', + type=str, + default='Video_Games_v1_00', + help='Amazon reviews category (for amazon dataset only, may not work - uses alternative)' + ) + + parser.add_argument( + '--use-imdb', + action='store_true', + help='Use IMDB reviews instead of Amazon (more reliable)' + ) + + parser.add_argument( + '--version', + type=str, + default='103', + choices=['2', '103'], + help='WikiText version: 2 (small) or 103 (large)' + ) + + args = parser.parse_args() + + # Set default output path if not provided + if not args.output: + if args.dataset == 'amazon': + args.output = f"data/amazon_reviews.txt" + elif args.dataset in ['wiki', 'wikitext']: + args.output = f"data/wikitext_{args.version}.txt" + elif args.dataset == 'openwebtext': + args.output = "data/openwebtext.txt" + elif args.dataset == 'bookcorpus': + args.output = "data/bookcorpus.txt" + + print(f"\nšŸš€ SheepOp Dataset Downloader") + print(f" Dataset: {args.dataset}") + print(f" Output: {args.output}") + print(f" Limit: {args.limit:,} samples\n") + + # Download based on dataset type + success = False + if args.dataset == 'amazon': + if args.use_imdb: + # Use IMDB directly + try: + from datasets import load_dataset + print("šŸ“„ Downloading IMDB Reviews...") + dataset = load_dataset("imdb", split=f"train[:{args.limit}]") + Path(args.output).parent.mkdir(parents=True, exist_ok=True) + with open(args.output, "w", encoding="utf-8") as f: + count = 0 + for item in dataset: + review = item.get("text", "").strip() + if review and len(review) > 20: + f.write(review + "\n") + count += 1 + if count % 50000 == 0: + print(f" āœ“ Downloaded {count:,} reviews...") + print(f"āœ… Successfully saved {count:,} reviews to {args.output}") + success = True + except Exception as e: + print(f"āŒ Error: {e}") + success = False + else: + success = download_amazon_reviews(args.output, args.limit, args.category) + elif args.dataset in ['wiki', 'wikitext']: + if args.version == '103': + # Try direct download first (no HuggingFace dependency) + print(" Attempting direct download (no HuggingFace required)...") + success = download_wikitext_direct(args.output) + if not success: + print(" Falling back to HuggingFace download...") + success = download_wikitext(args.output, args.version) + else: + success = download_wikitext(args.output, args.version) + elif args.dataset == 'openwebtext': + success = download_openwebtext(args.output, args.limit) + elif args.dataset == 'bookcorpus': + success = download_bookcorpus(args.output, args.limit) + + if success: + print(f"\nāœ… Download complete!") + print(f" File: {args.output}") + + # Show file info + try: + import os + size_mb = os.path.getsize(args.output) / (1024 * 1024) + with open(args.output, 'r', encoding='utf-8') as f: + lines = sum(1 for _ in f) + print(f" Size: {size_mb:.2f} MB") + print(f" Lines: {lines:,}") + except: + pass + + print(f"\nšŸ“š You can now train with:") + print(f" python3 train.py --data {args.output} --config config.json --device cuda") + else: + print(f"\nāŒ Download failed. Please check the error messages above.") + sys.exit(1) + + +if __name__ == '__main__': + main() + diff --git a/download_repos.py b/download_repos.py new file mode 100755 index 0000000..2f58d5e --- /dev/null +++ b/download_repos.py @@ -0,0 +1,681 @@ +#!/usr/bin/env python3 +""" +Download GitHub repositories with open licenses for code training. +Uses GitHub API to find and clone repositories automatically. +Includes support for Neovim, Lua, Bash, and ethical hacking repos. +""" +import argparse +import subprocess +import sys +import os +from pathlib import Path +from typing import List, Optional, Dict +import json +import urllib.request +import urllib.parse +import time +from tqdm import tqdm + + +def get_directory_size(directory: Path) -> int: + """Get total size of directory in bytes.""" + total = 0 + try: + for entry in directory.rglob('*'): + if entry.is_file(): + try: + total += entry.stat().st_size + except (OSError, PermissionError): + pass + except Exception: + pass + return total + +def format_size(size_bytes: int) -> str: + """Format bytes to human-readable size.""" + for unit in ['B', 'KB', 'MB', 'GB', 'TB']: + if size_bytes < 1024.0: + return f"{size_bytes:.2f} {unit}" + size_bytes /= 1024.0 + return f"{size_bytes:.2f} PB" + + +# Open source licenses (permissive and commonly used) +OPEN_LICENSES = [ + 'mit', + 'apache-2.0', + 'bsd-3-clause', + 'bsd-2-clause', + 'isc', + 'unlicense', + 'mpl-2.0', + 'lgpl-2.1', + 'lgpl-3.0', + 'gpl-2.0', + 'gpl-3.0', +] + +# Popular programming languages +POPULAR_LANGUAGES = [ + 'python', + 'javascript', + 'typescript', + 'java', + 'cpp', + 'c', + 'go', + 'rust', + 'ruby', + 'php', + 'swift', + 'kotlin', + 'scala', + 'r', + 'sql', + 'lua', + 'shell', # For bash/shell scripts +] + +# Predefined repository categories +REPO_CATEGORIES = { + 'nvim': { + 'query': 'neovim OR nvim-config OR neovim-config', + 'language': None, + 'description': 'Neovim configuration and plugins' + }, + 'lua': { + 'query': None, + 'language': 'lua', + 'description': 'Lua programming language repositories' + }, + 'bash': { + 'query': None, + 'language': 'shell', + 'description': 'Bash/shell script repositories' + }, + 'zsh': { + 'query': 'zsh-config OR oh-my-zsh OR zsh-plugin', + 'language': None, + 'description': 'Zsh configuration and plugins' + }, + 'python': { + 'query': None, + 'language': 'python', + 'description': 'Python programming repositories' + }, + 'hacking': { + 'query': 'ethical-hacking OR cybersecurity OR penetration-testing OR security-tools OR red-team', + 'language': None, + 'description': 'Ethical hacking and cybersecurity tools' + }, + 'security': { + 'query': 'security-tools OR cybersecurity OR penetration-testing OR red-team OR blue-team', + 'language': None, + 'description': 'Security and cybersecurity repositories' + }, + 'all-open': { + 'query': None, + 'language': None, + 'description': 'All repositories with open licenses (any language)' + }, +} + + +def search_github_repos( + language: Optional[str] = None, + license: Optional[str] = None, + query: Optional[str] = None, + min_stars: int = 100, + max_repos: int = 100, + sort: str = 'stars', + order: str = 'desc' +) -> List[dict]: + """ + Search GitHub for repositories matching criteria. + + Args: + language: Programming language (e.g., 'python', 'javascript') + license: License type (e.g., 'mit', 'apache-2.0') + query: Custom search query + min_stars: Minimum number of stars + max_repos: Maximum number of repos to return + sort: Sort by ('stars', 'updated', 'created') + order: Order ('desc' or 'asc') + + Returns: + List of repository dictionaries + """ + # Build query + query_parts = [] + + if query: + # Custom query (for categories like nvim, hacking) + query_parts.append(query) + else: + # Standard language-based query + if language: + query_parts.append(f"language:{language}") + + if license: + query_parts.append(f"license:{license}") + query_parts.append(f"stars:>={min_stars}") + + search_query = " ".join(query_parts) + + # GitHub API endpoint + base_url = "https://api.github.com/search/repositories" + params = { + 'q': search_query, + 'sort': sort, + 'order': order, + 'per_page': min(100, max_repos), # GitHub max is 100 per page + } + + url = f"{base_url}?{urllib.parse.urlencode(params)}" + + print(f"šŸ” Searching GitHub for repositories...") + print(f" Query: {search_query}") + print(f" Max repos: {max_repos}") + + try: + # Make request + req = urllib.request.Request(url) + req.add_header('Accept', 'application/vnd.github.v3+json') + req.add_header('User-Agent', 'SheepOp-Repo-Downloader') + + # Add GitHub token if available + github_token = os.environ.get('GITHUB_TOKEN') + if github_token: + req.add_header('Authorization', f'token {github_token}') + + with urllib.request.urlopen(req) as response: + data = json.loads(response.read().decode()) + + repos = data.get('items', [])[:max_repos] + print(f"āœ… Found {len(repos)} repositories") + return repos + + except urllib.error.HTTPError as e: + if e.code == 403: + print("āŒ Rate limit exceeded. Please wait a few minutes or use a GitHub token.") + print(" To use a token, set GITHUB_TOKEN environment variable:") + print(" export GITHUB_TOKEN=your_token_here") + else: + print(f"āŒ Error searching GitHub: {e}") + if e.code == 422: + print(" Tip: Try adjusting your search query or reducing max-repos") + return [] + except Exception as e: + print(f"āŒ Error: {e}") + return [] + + +def clone_repo(repo_url: str, output_dir: Path, depth: Optional[int] = None) -> bool: + """ + Clone a repository. + + Args: + repo_url: Repository URL (https://github.com/user/repo.git) + output_dir: Directory to clone into + depth: Shallow clone depth (None = full clone) + + Returns: + True if successful + """ + repo_name = repo_url.split('/')[-1].replace('.git', '') + target_dir = output_dir / repo_name + + # Skip if already exists + if target_dir.exists(): + return True # Silent skip (progress bar will show it) + + try: + cmd = ['git', 'clone', '--quiet'] # Quiet mode for cleaner output + if depth: + cmd.extend(['--depth', str(depth)]) + cmd.append(repo_url) + cmd.append(str(target_dir)) + + result = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=300 # 5 minute timeout + ) + + return result.returncode == 0 + + except subprocess.TimeoutExpired: + return False + except Exception as e: + return False + + +def download_category( + category: str, + output_dir: Path, + license: Optional[str] = None, + min_stars: int = 100, + max_repos: int = 50, + shallow: bool = True, + max_size_bytes: Optional[int] = None, +) -> tuple: + """ + Download repositories for a specific category. + + Returns: + (cloned_count, failed_count) + """ + if category not in REPO_CATEGORIES: + print(f"āŒ Unknown category: {category}") + return 0, 0 + + cat_info = REPO_CATEGORIES[category] + print(f"\nšŸ“¦ Downloading {category} repositories...") + print(f" {cat_info['description']}") + + # For 'all-open' category, don't filter by license unless explicitly specified + search_license = None if category == 'all-open' and not license else (license or 'mit') + + repos = search_github_repos( + language=cat_info['language'], + license=search_license, + query=cat_info['query'], + min_stars=min_stars, + max_repos=max_repos, + ) + + if not repos: + print(f" No repositories found for {category}") + return 0, 0 + + print(f" Cloning {len(repos)} repositories...") + + cloned = 0 + failed = 0 + + # Progress bar for cloning + pbar = tqdm( + total=len(repos), + desc=f"Cloning {category}", + unit="repo", + ncols=100, + mininterval=0.1, + maxinterval=1.0, + file=sys.stderr, # Write to stderr to avoid buffering issues + dynamic_ncols=True, # Auto-adjust to terminal width + disable=False, # Explicitly enable + ) + + # Cache size to avoid recalculating every iteration + cached_size = get_directory_size(output_dir) if max_size_bytes else 0 + size_check_counter = 0 + + for i, repo in enumerate(repos, 1): + # Check size limit every 5 repos (to avoid blocking progress bar) + if max_size_bytes: + size_check_counter += 1 + if size_check_counter >= 5: + cached_size = get_directory_size(output_dir) + size_check_counter = 0 + if cached_size >= max_size_bytes: + pbar.close() + print(f"\nāš ļø Size limit reached: {format_size(cached_size)} >= {format_size(max_size_bytes)}") + print(f" Stopping downloads for {category}.") + break + + repo_url = repo['clone_url'] + repo_name = repo['full_name'] + stars = repo['stargazers_count'] + repo_lang = repo.get('language', 'N/A') + + # Update progress bar before clone + pbar.set_postfix({ + 'Current': repo_name.split('/')[-1][:20], + 'Stars': f"{stars:,}", + 'Lang': repo_lang[:8], + 'Cloned': cloned, + 'Failed': failed, + 'Size': format_size(cached_size) if max_size_bytes else 'N/A' + }) + + success = clone_repo( + repo_url, + output_dir, + depth=1 if shallow else None + ) + + if success: + cloned += 1 + else: + failed += 1 + + # Update progress bar after clone (advance by 1) + pbar.update(1) + pbar.refresh() # Force immediate refresh + sys.stderr.flush() # Force flush stderr to ensure progress bar displays + + # Rate limiting: small delay between clones + time.sleep(0.5) + + pbar.close() + + return cloned, failed + + +def download_repos( + output_dir: str = "data/repos", + language: Optional[str] = None, + license: Optional[str] = None, + min_stars: int = 100, + max_repos: int = 50, + shallow: bool = True, + languages: Optional[List[str]] = None, + categories: Optional[List[str]] = None, + max_size_gb: Optional[float] = None, +) -> bool: + """ + Download repositories from GitHub. + + Args: + output_dir: Directory to clone repositories into + language: Single language to filter by + license: License type to filter by + min_stars: Minimum stars + max_repos: Maximum repos to download per category/language + shallow: Use shallow clone (faster, less history) + languages: List of languages to download + categories: List of categories to download (nvim, lua, bash, zsh, python, hacking, security, all-open) + max_size_gb: Maximum total size in GB (stops downloading when reached) + + Returns: + True if successful + """ + output_path = Path(output_dir) + output_path.mkdir(parents=True, exist_ok=True) + + # Convert GB to bytes + max_size_bytes = int(max_size_gb * 1024**3) if max_size_gb else None + + if max_size_bytes: + current_size = get_directory_size(output_path) + print(f"šŸ“Š Current directory size: {format_size(current_size)}") + if current_size >= max_size_bytes: + print(f"āš ļø Already at size limit: {format_size(current_size)} >= {format_size(max_size_bytes)}") + return False + print(f"šŸ“Š Size limit: {format_size(max_size_bytes)}") + + total_cloned = 0 + total_failed = 0 + + # Download by categories + if categories: + print(f"\nšŸ“¦ Processing {len(categories)} categories...") + + # Overall progress bar for categories + cat_pbar = tqdm( + categories, + desc="Categories", + unit="category", + ncols=100, + position=0, + leave=True, + mininterval=0.1, + maxinterval=1.0, + file=sys.stderr, # Write to stderr to avoid buffering issues + dynamic_ncols=True, # Auto-adjust to terminal width + disable=False, # Explicitly enable + ) + + for category in cat_pbar: + # Check size limit before processing category + if max_size_bytes: + current_size = get_directory_size(output_path) + if current_size >= max_size_bytes: + cat_pbar.close() + print(f"\nāš ļø Size limit reached: {format_size(current_size)} >= {format_size(max_size_bytes)}") + print(f" Stopping all downloads.") + break + + cat_pbar.set_description(f"Category: {category}") + current_size = get_directory_size(output_path) if max_size_bytes else 0 + cat_pbar.set_postfix({ + 'Total Cloned': total_cloned, + 'Total Failed': total_failed, + 'Size': format_size(current_size) if max_size_bytes else 'N/A' + }) + cat_pbar.refresh() # Force refresh + + cloned, failed = download_category( + category=category, + output_dir=output_path, + license=license, + min_stars=min_stars, + max_repos=max_repos, + shallow=shallow, + max_size_bytes=max_size_bytes, + ) + total_cloned += cloned + total_failed += failed + + cat_pbar.close() + + # Download by languages + languages_to_process = languages or ([language] if language else []) + + for lang in languages_to_process: + # Check size limit + if max_size_bytes: + current_size = get_directory_size(output_path) + if current_size >= max_size_bytes: + print(f"\nāš ļø Size limit reached: {format_size(current_size)} >= {format_size(max_size_bytes)}") + break + + print(f"\nšŸ“¦ Processing {lang} repositories...") + + repos = search_github_repos( + language=lang, + license=license or 'mit', + min_stars=min_stars, + max_repos=max_repos, + ) + + if not repos: + print(f" No repositories found for {lang}") + continue + + print(f" Cloning {len(repos)} repositories...") + + # Progress bar for language-based cloning + pbar = tqdm( + total=len(repos), + desc=f"Cloning {lang}", + unit="repo", + ncols=100, + mininterval=0.1, + maxinterval=1.0, + file=sys.stderr, # Write to stderr to avoid buffering issues + dynamic_ncols=True, # Auto-adjust to terminal width + disable=False, # Explicitly enable + ) + + # Cache size to avoid recalculating every iteration + cached_size = get_directory_size(output_path) if max_size_bytes else 0 + size_check_counter = 0 + + for i, repo in enumerate(repos, 1): + # Check size limit every 5 repos + if max_size_bytes: + size_check_counter += 1 + if size_check_counter >= 5: + cached_size = get_directory_size(output_path) + size_check_counter = 0 + if cached_size >= max_size_bytes: + pbar.close() + print(f"\nāš ļø Size limit reached: {format_size(cached_size)} >= {format_size(max_size_bytes)}") + break + + repo_url = repo['clone_url'] + repo_name = repo['full_name'] + stars = repo['stargazers_count'] + + # Update progress bar before clone + pbar.set_postfix({ + 'Current': repo_name.split('/')[-1][:20], + 'Stars': f"{stars:,}", + 'Cloned': total_cloned, + 'Failed': total_failed, + 'Size': format_size(cached_size) if max_size_bytes else 'N/A' + }) + + success = clone_repo( + repo_url, + output_path, + depth=1 if shallow else None + ) + + if success: + total_cloned += 1 + else: + total_failed += 1 + + # Update progress bar after clone (advance by 1) + pbar.update(1) + pbar.refresh() # Force immediate refresh + sys.stderr.flush() # Force flush stderr to ensure progress bar displays + + # Rate limiting + time.sleep(0.5) + + pbar.close() + + final_size = get_directory_size(output_path) if max_size_bytes else 0 + print(f"\nāœ… Download complete!") + print(f" Cloned: {total_cloned}") + print(f" Failed: {total_failed}") + if max_size_bytes: + print(f" Total size: {format_size(final_size)} / {format_size(max_size_bytes)}") + print(f" Location: {output_path.absolute()}") + + return total_cloned > 0 + + +def main(): + parser = argparse.ArgumentParser( + description='Download GitHub repositories with open licenses for code training', + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Download Neovim configs + python3 download_repos.py --categories nvim --max-repos 100 + + # Download Lua repos + python3 download_repos.py --categories lua --max-repos 50 + + # Download Bash scripts + python3 download_repos.py --categories bash --max-repos 50 + + # Download ethical hacking repos + python3 download_repos.py --categories hacking --max-repos 100 + + # Download all your categories + python3 download_repos.py --categories nvim lua bash zsh python hacking --max-repos 50 + + # Download with 1 TB size limit + python3 download_repos.py --categories all-open --max-repos 1000 --max-size 1024.0 + + # Download with specific license + python3 download_repos.py --categories nvim --license apache-2.0 --max-repos 50 + """ + ) + parser.add_argument( + '--output', + type=str, + default='data/repos', + help='Output directory (default: data/repos)' + ) + parser.add_argument( + '--language', + type=str, + choices=POPULAR_LANGUAGES, + help='Programming language to filter by' + ) + parser.add_argument( + '--languages', + type=str, + nargs='+', + choices=POPULAR_LANGUAGES, + help='Multiple languages to download' + ) + parser.add_argument( + '--categories', + type=str, + nargs='+', + choices=list(REPO_CATEGORIES.keys()), + help='Categories to download: nvim, lua, bash, zsh, python, hacking, security, all-open' + ) + parser.add_argument( + '--license', + type=str, + choices=OPEN_LICENSES, + default='mit', + help='License type (default: mit)' + ) + parser.add_argument( + '--min-stars', + type=int, + default=100, + help='Minimum stars (default: 100)' + ) + parser.add_argument( + '--max-repos', + type=int, + default=50, + help='Maximum repos per category/language (default: 50)' + ) + parser.add_argument( + '--full-clone', + action='store_true', + help='Do full clone instead of shallow (slower but includes full history)' + ) + parser.add_argument( + '--max-size', + type=float, + help='Maximum total size in GB (stops downloading when reached, e.g., 1024.0 for 1 TB)' + ) + + args = parser.parse_args() + + # Default to categories if nothing specified + if not args.categories and not args.language and not args.languages: + print("ā„¹ļø No categories or languages specified. Use --categories or --language") + print(" Available categories:", ", ".join(REPO_CATEGORIES.keys())) + print(" Example: --categories nvim lua bash hacking") + return + + print("šŸš€ SheepOp Repository Downloader") + print("=" * 50) + + success = download_repos( + output_dir=args.output, + language=args.language, + license=args.license, + min_stars=args.min_stars, + max_repos=args.max_repos, + shallow=not args.full_clone, + languages=args.languages, + categories=args.categories, + max_size_gb=args.max_size, + ) + + if success: + print(f"\nšŸ“š You can now train with:") + print(f" python3 train.py --data {args.output} --config config.json --device cuda") + else: + print("\nāŒ No repositories were downloaded.") + sys.exit(1) + + +if __name__ == '__main__': + main() + diff --git a/example.py b/example.py new file mode 100644 index 0000000..5b03280 --- /dev/null +++ b/example.py @@ -0,0 +1,90 @@ +""" +Example script demonstrating basic usage +""" +import torch +import sys +import importlib.util +from pathlib import Path + +# Ensure current directory is in path +project_root = Path(__file__).parent.absolute() +sys.path.insert(0, str(project_root)) + +# Explicitly import from local data module to avoid conflicts with stdlib 'data' module +data_module_path = project_root / "data" / "__init__.py" +spec = importlib.util.spec_from_file_location("sheepop_data", data_module_path) +sheepop_data = importlib.util.module_from_spec(spec) +spec.loader.exec_module(sheepop_data) +SimpleTokenizer = sheepop_data.SimpleTokenizer + +from models import TransformerModel + + +def example_model_creation(): + """Example of creating a model.""" + print("Creating model...") + + # Create tokenizer + tokenizer = SimpleTokenizer() + print(f"Vocabulary size: {tokenizer.vocab_size}") + + # Create model + model = TransformerModel( + vocab_size=tokenizer.vocab_size, + d_model=256, + num_layers=4, + num_heads=4, + d_ff=1024, + max_seq_len=128, + ) + + print(f"Model created with {model.get_num_params():,} parameters") + + # Test forward pass + input_ids = torch.randint(0, tokenizer.vocab_size, (2, 32)) + logits, _ = model(input_ids) + print(f"Input shape: {input_ids.shape}") + print(f"Output shape: {logits.shape}") + + return model, tokenizer + + +def example_generation(model, tokenizer): + """Example of text generation.""" + print("\nGenerating text...") + + prompt = "Hello, world" + print(f"Prompt: {prompt}") + + # Encode prompt + input_ids = tokenizer.encode(prompt) + input_ids = torch.tensor([input_ids]) + + # Generate + generated = model.generate( + input_ids=input_ids, + max_length=50, + temperature=0.8, + top_k=50, + top_p=0.95, + do_sample=True, + ) + + # Decode + generated_text = tokenizer.decode(generated[0].tolist()) + print(f"Generated: {generated_text}") + + +if __name__ == '__main__': + # Set random seed + torch.manual_seed(42) + + # Create model + model, tokenizer = example_model_creation() + + # Test generation + example_generation(model, tokenizer) + + print("\nExample completed successfully!") + + diff --git a/example_optimized.py b/example_optimized.py new file mode 100644 index 0000000..f100e4b --- /dev/null +++ b/example_optimized.py @@ -0,0 +1,199 @@ +""" +Example usage of optimized inference and retrieval mechanisms +Demonstrates KV caching, retrieval caching, and prefetching for RAG systems +""" +import torch +import sys +import importlib.util +from pathlib import Path + +# Ensure current directory is in path +project_root = Path(__file__).parent.absolute() +sys.path.insert(0, str(project_root)) + +# Explicitly import from local data module to avoid conflicts with stdlib 'data' module +data_module_path = project_root / "data" / "__init__.py" +spec = importlib.util.spec_from_file_location("sheepop_data", data_module_path) +sheepop_data = importlib.util.module_from_spec(spec) +spec.loader.exec_module(sheepop_data) +SimpleTokenizer = sheepop_data.SimpleTokenizer +create_dataloader = sheepop_data.create_dataloader + +from models import TransformerModel, OptimizedInference, RetrievalCache +from models.prefetching import PrefetchDataLoader, LookaheadRetriever + + +def example_optimized_inference(): + """Example: Using optimized inference with KV caching.""" + print("=" * 60) + print("Example: Optimized Inference with KV Caching") + print("=" * 60) + + # Create model (example configuration) + model = TransformerModel( + vocab_size=128, + d_model=512, + num_layers=6, + num_heads=8, + ) + + # Get optimized inference utility + optimizer = model.get_optimized_inference() + + # Example prompt + tokenizer = SimpleTokenizer() + prompt = "The future of AI" + input_ids = torch.tensor([tokenizer.encode(prompt)]) + + # Generate with KV caching (faster for autoregressive generation) + generated = optimizer.generate_with_cache( + input_ids=input_ids, + max_length=50, + temperature=0.8, + top_k=50, + top_p=0.95, + ) + + print(f"Generated: {tokenizer.decode(generated[0].tolist())}") + print() + + +def example_retrieval_caching(): + """Example: Using retrieval cache for similar queries.""" + print("=" * 60) + print("Example: Retrieval Caching") + print("=" * 60) + + # Create retrieval cache + cache = RetrievalCache(max_size=1000, similarity_threshold=0.9) + + # Example: Simulate retrieval function + def retrieve_documents(query: str): + """Mock retrieval function.""" + return [ + {"doc_id": "1", "text": f"Document about {query}", "score": 0.95}, + {"doc_id": "2", "text": f"Another document about {query}", "score": 0.92}, + ] + + # Create query embeddings (simplified) + query1 = "What is machine learning?" + query1_embedding = torch.randn(128) # Example embedding + + query2 = "What is deep learning?" # Similar query + query2_embedding = torch.randn(128) # Example embedding (would be similar in practice) + + # Store first query + import hashlib + query1_hash = hashlib.md5(query1.encode()).hexdigest() + results1 = retrieve_documents(query1) + cache.set(query1_hash, query1_embedding, results1) + + # Retrieve from cache (should find similar query) + query2_hash = hashlib.md5(query2.encode()).hexdigest() + cached_results = cache.get(query2_hash, query2_embedding) + + if cached_results: + print(f"Found cached results for query: {query2}") + print(f"Retrieved {len(cached_results)} documents") + else: + print("Cache miss, performing retrieval...") + results = retrieve_documents(query2) + cache.set(query2_hash, query2_embedding, results) + + print() + + +def example_prefetching(): + """Example: Using prefetching for data loading.""" + print("=" * 60) + print("Example: Prefetching DataLoader") + print("=" * 60) + + # Create sample data + texts = ["This is a sample text."] * 100 + tokenizer = SimpleTokenizer() + + # Create standard dataloader + dataloader = create_dataloader( + texts=texts, + tokenizer=tokenizer, + batch_size=32, + max_length=512, + ) + + # Wrap with prefetching + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + prefetch_loader = PrefetchDataLoader( + dataloader=dataloader, + prefetch_factor=2, + device=device, + ) + + print(f"Created prefetch loader with {len(prefetch_loader)} batches") + print("Prefetching batches in background thread...") + print() + + +def example_batch_generation(): + """Example: Batch generation for multiple prompts.""" + print("=" * 60) + print("Example: Batch Generation") + print("=" * 60) + + # Create model + model = TransformerModel( + vocab_size=128, + d_model=512, + num_layers=6, + num_heads=8, + ) + + # Get optimized inference utility + optimizer = model.get_optimized_inference() + + # Multiple prompts + tokenizer = SimpleTokenizer() + prompts = [ + "The future of AI", + "Machine learning applications", + "Deep learning advances", + ] + + input_ids_list = [torch.tensor([tokenizer.encode(p)]) for p in prompts] + + # Generate for all prompts in batches + results = optimizer.batch_generate( + input_ids_list=input_ids_list, + max_length=30, + temperature=0.8, + batch_size=2, + ) + + print(f"Generated {len(results)} responses:") + for i, (prompt, result) in enumerate(zip(prompts, results)): + # result is already a tensor [batch_size, seq_len], get first item if batch_size > 1 + if result.dim() > 1 and result.shape[0] > 1: + generated_ids = result[0].tolist() + else: + generated_ids = result.squeeze(0).tolist() if result.dim() > 1 else result.tolist() + generated_text = tokenizer.decode(generated_ids) + print(f"{i+1}. Prompt: {prompt}") + print(f" Generated: {generated_text[:50]}...") + print() + + +if __name__ == '__main__': + print("\n" + "=" * 60) + print("Optimized RAG System Examples") + print("=" * 60 + "\n") + + # Run examples + example_optimized_inference() + example_retrieval_caching() + example_prefetching() + example_batch_generation() + + print("=" * 60) + print("All examples completed!") + print("=" * 60) + diff --git a/extensions.txt b/extensions.txt new file mode 100644 index 0000000..65eeb4f --- /dev/null +++ b/extensions.txt @@ -0,0 +1,369 @@ +_10083 +_10083-h +1996 +1997 +1998 +1999 +2000 +2001 +2002 +2003 +2004 +20041201 +2005 +2006 +2007 +2008 +2009 +2010 +2011 +2012 +2013 +2014 +2015 +2016 +2017 +2018 +2019 +2020 +2020-06-27 +2020-12-19 +2020-12-28 +2021 +2021-01-26 +2021-01-27 +2022 +2023 +2024 +2025 +aaaa +aaab +aaac +aaad +aaae +aaaf +aaag +aaah +aaai +aaaj +aaak +aaal +aaam +aaan +aaao +aaap +aaaq +aaar +aaas +aaat +aaau +aaav +aaaw +aaax +aaay +aaaz +aaba +aabb +aabc +aabd +aabe +aabf +aabg +aabh +aabi +aabj +aabk +aabl +aabm +aabn +aabo +aabp +aabq +aabr +aabs +aabt +aabu +aabv +aabw +aabx +aaby +aabz +aaca +aacb +aacc +aacd +aace +aacf +aacg +aach +aaci +aacj +aack +aacl +aacm +aacn +aaco +aacp +aacq +aacr +aacs +aact +aacu +aacv +aacw +aacx +aacy +aacz +aada +aadb +aadc +aadd +aade +aadf +aadg +aadh +aadi +aadj +aadk +aadl +aadm +aadn +aado +aadp +aadq +aadr +aads +aadt +aadu +aadv +aadw +aadx +aady +aadz +aaea +aaeb +aaec +aaed +aaee +aaef +aaeg +aaeh +aaei +aaej +aaek +aael +aaem +aaen +aaeo +aaep +aaeq +aaer +aaes +aaet +aaeu +aaev +aaew +aaex +aaey +aaez +aafa +aafb +aafc +aafd +aafe +aaff +aafg +aafh +aafi +aafj +aafk +aafl +aafm +aafn +aafo +aafp +aafq +aafr +aafs +aaft +aafu +aafv +aafw +aafx +aafy +aafz +aaga +aagb +aagc +aagd +aage +aagf +aagg +aagh +aagi +aagj +aagk +aagl +aagm +aagn +aago +aagp +aagq +aagr +aags +aagt +aagu +aagv +aagw +aagx +aagy +aagz +aaha +aahb +aahc +aahd +aahe +aahf +aahg +aahh +aahi +aahj +aahk +aahl +aahm +aahn +aaho +aahp +aahq +aahr +aahs +aaht +aahu +aahv +aahw +aahx +aahy +aahz +aaia +aaib +aaic +aaid +aaie +aaif +aaig +aaih +aaii +aaij +aaik +aail +aaim +aain +aaio +aaip +aaiq +aair +aais +aait +aaiu +aaiv +aaiw +aaix +aaiy +aaiz +aaja +aajb +aajc +aajd +aaje +aajf +aajg +aajh +aaji +aajj +aajk +aajl +aajm +aajn +aajo +aajp +aajq +aajr +aajs +aajt +aaju +aajv +aajw +aajx +aajy +aajz +aaka +aakb +aakc +aakd +aake +aakf +aakg +aakh +aaki +aakj +aakk +aakl +aakm +aakn +aako +aakp +ALL +AUS +brl +bz2 +cache +css +db +dcs +doc +DS_Store +eps +gif +gz +htm +ico +_images +iso +jigdo +jpg +JPG +jsonl +lit +ly +md5 +message +mid +mp3 +_old +PAR +PAR2 +pdf +png +prc +py +pyc +rar +rtf +selected_editor +sfv +sh +sib +static +svg +tar +template +tex +txt +txt~ +TXT +xml +zcompdump +zip +zip~ +zshrc +zst \ No newline at end of file diff --git a/extract_from_database.py b/extract_from_database.py new file mode 100644 index 0000000..1d3cb54 --- /dev/null +++ b/extract_from_database.py @@ -0,0 +1,304 @@ +""" +Database extraction utility for training data +Extracts text from various database types and formats for LLM training +""" +import sqlite3 +import argparse +from pathlib import Path +from typing import List, Optional, Iterator +import json + + +def extract_from_sqlite( + db_path: str, + table: str, + text_column: str, + limit: Optional[int] = None, + where_clause: Optional[str] = None, +) -> Iterator[str]: + """ + Extract text from SQLite database. + + Args: + db_path: Path to SQLite database file + table: Table name to extract from + text_column: Column name containing text data + limit: Maximum number of rows to extract (None = all) + where_clause: Optional WHERE clause (e.g., "WHERE length(text) > 100") + + Yields: + Text strings from the database + """ + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + + query = f"SELECT {text_column} FROM {table}" + if where_clause: + query += f" {where_clause}" + if limit: + query += f" LIMIT {limit}" + + cursor.execute(query) + + for row in cursor: + text = row[0] + if text and isinstance(text, str) and len(text.strip()) > 0: + # Clean and split text into sentences/lines + cleaned_text = text.strip() + yield cleaned_text + + conn.close() + + +def extract_from_sql( + connection_string: str, + query: str, + text_column: int = 0, + batch_size: int = 1000, +) -> Iterator[str]: + """ + Extract text using a raw SQL query. + Works with any database that supports the connection string format. + + Args: + connection_string: Database connection string + query: SQL query to execute + text_column: Column index containing text (0-based) + batch_size: Number of rows to fetch at once + + Yields: + Text strings from the database + """ + try: + import psycopg2 # PostgreSQL + conn = psycopg2.connect(connection_string) + except ImportError: + try: + import pymysql # MySQL + conn = pymysql.connect(connection_string) + except ImportError: + raise ImportError("Install psycopg2 for PostgreSQL or pymysql for MySQL") + + cursor = conn.cursor() + cursor.execute(query) + + while True: + rows = cursor.fetchmany(batch_size) + if not rows: + break + + for row in rows: + text = row[text_column] + if text and isinstance(text, str) and len(text.strip()) > 0: + yield text.strip() + + conn.close() + + +def extract_from_json_file( + json_path: str, + text_field: str, + limit: Optional[int] = None, +) -> Iterator[str]: + """ + Extract text from JSON file (e.g., JSONL format). + + Args: + json_path: Path to JSON file + text_field: Field name containing text (use dot notation for nested: "data.text") + limit: Maximum number of records to extract + + Yields: + Text strings from the JSON file + """ + with open(json_path, 'r', encoding='utf-8') as f: + count = 0 + for line in f: + if limit and count >= limit: + break + + try: + data = json.loads(line) + + # Handle nested fields with dot notation + fields = text_field.split('.') + value = data + for field in fields: + value = value.get(field) + if value is None: + break + + if value and isinstance(value, str) and len(value.strip()) > 0: + yield value.strip() + count += 1 + except json.JSONDecodeError: + continue + + +def clean_and_split_text(text: str, min_length: int = 10) -> List[str]: + """ + Clean text and split into sentences/lines. + + Args: + text: Raw text string + min_length: Minimum length for a text sample + + Returns: + List of cleaned text samples + """ + import re + + # Remove extra whitespace + text = re.sub(r'\s+', ' ', text) + + # Split by sentences (periods, exclamation, question marks) + sentences = re.split(r'[.!?]+\s+', text) + + # Also split by newlines + lines = [] + for sentence in sentences: + lines.extend(sentence.split('\n')) + + # Clean and filter + cleaned = [] + for line in lines: + line = line.strip() + if len(line) >= min_length: + cleaned.append(line) + + return cleaned + + +def save_to_training_file( + texts: Iterator[str], + output_path: str, + min_length: int = 10, + max_samples: Optional[int] = None, + clean_text: bool = True, +): + """ + Save extracted texts to training file. + + Args: + texts: Iterator of text strings + output_path: Path to save training data + min_length: Minimum length for text samples + max_samples: Maximum number of samples to save + clean_text: Whether to clean and split text + """ + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + count = 0 + total_texts = 0 + + with open(output_path, 'w', encoding='utf-8') as f: + for text in texts: + if max_samples and count >= max_samples: + break + + if clean_text: + # Clean and split into sentences + cleaned_texts = clean_and_split_text(text, min_length) + for cleaned in cleaned_texts: + if max_samples and count >= max_samples: + break + f.write(cleaned + '\n') + count += 1 + else: + # Write as-is + if len(text.strip()) >= min_length: + f.write(text.strip() + '\n') + count += 1 + + total_texts += 1 + + # Progress update every 1000 texts + if total_texts % 1000 == 0: + print(f"Processed {total_texts} texts, saved {count} samples...") + + print(f"\nāœ… Extraction complete!") + print(f" Total texts processed: {total_texts}") + print(f" Samples saved: {count}") + print(f" Output file: {output_path}") + print(f" File size: {output_path.stat().st_size / (1024*1024):.2f} MB") + + +def main(): + parser = argparse.ArgumentParser(description='Extract text from database for training') + parser.add_argument('--type', type=str, choices=['sqlite', 'sql', 'json'], + required=True, help='Database type') + parser.add_argument('--output', type=str, default='data/database_extracted.txt', + help='Output file path') + parser.add_argument('--limit', type=int, help='Maximum number of samples to extract') + parser.add_argument('--min-length', type=int, default=10, + help='Minimum text length') + + # SQLite options + parser.add_argument('--db-path', type=str, help='SQLite database path') + parser.add_argument('--table', type=str, help='Table name') + parser.add_argument('--column', type=str, help='Text column name') + parser.add_argument('--where', type=str, help='WHERE clause (e.g., "WHERE length(text) > 100")') + + # SQL query options + parser.add_argument('--connection', type=str, help='Database connection string') + parser.add_argument('--query', type=str, help='SQL query') + parser.add_argument('--text-column', type=int, default=0, help='Text column index (0-based)') + + # JSON options + parser.add_argument('--json-path', type=str, help='JSON/JSONL file path') + parser.add_argument('--text-field', type=str, help='JSON field name containing text') + + parser.add_argument('--no-clean', action='store_true', help='Do not clean/split text') + + args = parser.parse_args() + + # Extract based on type + if args.type == 'sqlite': + if not all([args.db_path, args.table, args.column]): + print("Error: --db-path, --table, and --column required for SQLite") + return + + texts = extract_from_sqlite( + db_path=args.db_path, + table=args.table, + text_column=args.column, + limit=args.limit, + where_clause=args.where, + ) + + elif args.type == 'sql': + if not all([args.connection, args.query]): + print("Error: --connection and --query required for SQL") + return + + texts = extract_from_sql( + connection_string=args.connection, + query=args.query, + text_column=args.text_column, + ) + + elif args.type == 'json': + if not all([args.json_path, args.text_field]): + print("Error: --json-path and --text-field required for JSON") + return + + texts = extract_from_json_file( + json_path=args.json_path, + text_field=args.text_field, + limit=args.limit, + ) + + # Save to training file + save_to_training_file( + texts=texts, + output_path=args.output, + min_length=args.min_length, + max_samples=args.limit, + clean_text=not args.no_clean, + ) + + +if __name__ == '__main__': + main() + diff --git a/inference.py b/inference.py new file mode 100644 index 0000000..ac82cd2 --- /dev/null +++ b/inference.py @@ -0,0 +1,444 @@ +""" +Inference script for generating text +Optimized for production RAG systems with KV caching and efficient inference +""" +import torch +import argparse +from pathlib import Path +import sys +import importlib.util +import time + +# Ensure current directory is in path +project_root = Path(__file__).parent.absolute() +sys.path.insert(0, str(project_root)) + +# Explicitly import from local data module to avoid conflicts with stdlib 'data' module +data_module_path = project_root / "data" / "__init__.py" +spec = importlib.util.spec_from_file_location("sheepop_data", data_module_path) +sheepop_data = importlib.util.module_from_spec(spec) +spec.loader.exec_module(sheepop_data) +SimpleTokenizer = sheepop_data.SimpleTokenizer + +from models import TransformerModel +from models.optimized_attention import OptimizedInference +from inference_metrics import InferenceMetrics + + +def load_model(checkpoint_path: str, device: str = 'cuda', tokenizer=None): + """Load model from checkpoint.""" + checkpoint = torch.load(checkpoint_path, map_location=device) + + # Get model config from checkpoint or use defaults + model_config = checkpoint.get('model_config', {}) + + # If no config in checkpoint, try to infer from model state dict or use defaults + if not model_config: + print("Warning: No model_config found in checkpoint. Using defaults.") + # Try to infer vocab_size from tokenizer if provided + if tokenizer is not None: + vocab_size = tokenizer.vocab_size + else: + # Default vocab size - should match your tokenizer + vocab_size = 128 # Default for SimpleTokenizer + + model_config = { + 'vocab_size': vocab_size, + 'd_model': 512, + 'num_layers': 6, + 'num_heads': 8, + 'd_ff': 2048, + 'max_seq_len': 512, + 'dropout': 0.1, + 'activation': 'gelu', + } + print(f"Using default config with vocab_size={vocab_size}") + + model = TransformerModel(**model_config) + + model.load_state_dict(checkpoint['model_state_dict']) + model.to(device) + model.eval() + + return model + + +def generate_text( + model: TransformerModel, + tokenizer: SimpleTokenizer, + prompt: str, + max_length: int = 100, + temperature: float = 1.0, + top_k: int = 50, + top_p: float = 0.95, + device: str = 'cuda', + optimized: bool = False, +): + """ + Generate text from a prompt. + + Returns: + tuple: (generated_text, generated_ids, input_ids, generation_time) + """ + # Encode prompt + input_ids = tokenizer.encode(prompt) + input_ids = torch.tensor([input_ids], device=device) + + # Measure generation time + start_time = time.time() + + if optimized: + optimizer = model.get_optimized_inference() + generated = optimizer.generate_with_cache( + input_ids=input_ids, + max_length=max_length, + temperature=temperature, + top_k=top_k, + top_p=top_p, + ) + else: + generated = model.generate( + input_ids=input_ids, + max_length=max_length, + temperature=temperature, + top_k=top_k, + top_p=top_p, + do_sample=True, + ) + + generation_time = time.time() - start_time + + # Decode + generated_ids = generated[0].cpu().tolist() + generated_text = tokenizer.decode(generated_ids) + + return generated_text, generated_ids, input_ids, generation_time + + +def get_memory_usage(device: torch.device) -> float: + """Get current memory usage in MB.""" + if device.type == 'cuda': + return torch.cuda.memory_allocated(device) / (1024 ** 2) # MB + elif device.type == 'mps': + # MPS doesn't have direct memory query, return None + return None + else: + return None + + +def benchmark_inference( + model: TransformerModel, + tokenizer: SimpleTokenizer, + prompt: str, + max_length: int, + temperature: float, + top_k: int, + top_p: float, + device: torch.device, + metrics: InferenceMetrics, + run_name: str, +): + """Run benchmark for both optimized and non-optimized inference.""" + + def remove_trailing_padding(token_ids, pad_token_id): + """Remove trailing padding tokens.""" + while token_ids and token_ids[-1] == pad_token_id: + token_ids.pop() + return token_ids + + print("\n" + "=" * 70) + print(f"BENCHMARK RUN: {run_name}") + print("=" * 70) + + results = {} + + # Run non-optimized first + print("\nšŸ”“ Running NON-OPTIMIZED inference...") + if device.type == 'cuda': + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats(device) + + memory_before = get_memory_usage(device) + + generated_text, generated_ids, input_ids, gen_time = generate_text( + model=model, + tokenizer=tokenizer, + prompt=prompt, + max_length=max_length, + temperature=temperature, + top_k=top_k, + top_p=top_p, + device=str(device), + optimized=False, + ) + + memory_after = get_memory_usage(device) + memory_used = memory_after - memory_before if memory_after and memory_before else None + + generated_ids = remove_trailing_padding(generated_ids, tokenizer.pad_token_id) + prompt_length = len(input_ids[0]) + generated_length = len(generated_ids) - prompt_length + + if generated_length > 0: + tokens_per_sec = generated_length / gen_time + time_per_token = (gen_time / generated_length) * 1000 # ms + else: + tokens_per_sec = 0 + time_per_token = 0 + + results['non_optimized'] = { + 'text': generated_text, + 'prompt_length': prompt_length, + 'generated_length': generated_length, + 'total_time': gen_time, + 'tokens_per_sec': tokens_per_sec, + 'time_per_token': time_per_token, + 'memory_mb': memory_used, + } + + print(f" ā±ļø Total Time: {gen_time:.3f} s") + print(f" šŸ“Š Tokens/Second: {tokens_per_sec:.2f}") + print(f" ⚔ Time/Token: {time_per_token:.3f} ms") + if memory_used: + print(f" šŸ’¾ Memory Used: {memory_used:.1f} MB") + print(f" šŸ“ Generated: {generated_text[:100]}...") + + # Log metrics + metrics.log_run( + run_name=f"{run_name}_non_optimized", + optimized=False, + prompt_length=prompt_length, + generated_length=generated_length, + total_time=gen_time, + tokens_per_second=tokens_per_sec, + time_per_token=time_per_token, + memory_used_mb=memory_used, + device=str(device), + ) + + # Run optimized + print("\n🟢 Running OPTIMIZED inference...") + if device.type == 'cuda': + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats(device) + + memory_before = get_memory_usage(device) + + generated_text, generated_ids, input_ids, gen_time = generate_text( + model=model, + tokenizer=tokenizer, + prompt=prompt, + max_length=max_length, + temperature=temperature, + top_k=top_k, + top_p=top_p, + device=str(device), + optimized=True, + ) + + memory_after = get_memory_usage(device) + memory_used = memory_after - memory_before if memory_after and memory_before else None + + generated_ids = remove_trailing_padding(generated_ids, tokenizer.pad_token_id) + prompt_length = len(input_ids[0]) + generated_length = len(generated_ids) - prompt_length + + if generated_length > 0: + tokens_per_sec = generated_length / gen_time + time_per_token = (gen_time / generated_length) * 1000 # ms + else: + tokens_per_sec = 0 + time_per_token = 0 + + results['optimized'] = { + 'text': generated_text, + 'prompt_length': prompt_length, + 'generated_length': generated_length, + 'total_time': gen_time, + 'tokens_per_sec': tokens_per_sec, + 'time_per_token': time_per_token, + 'memory_mb': memory_used, + } + + print(f" ā±ļø Total Time: {gen_time:.3f} s") + print(f" šŸ“Š Tokens/Second: {tokens_per_sec:.2f}") + print(f" ⚔ Time/Token: {time_per_token:.3f} ms") + if memory_used: + print(f" šŸ’¾ Memory Used: {memory_used:.1f} MB") + print(f" šŸ“ Generated: {generated_text[:100]}...") + + # Log metrics + metrics.log_run( + run_name=f"{run_name}_optimized", + optimized=True, + prompt_length=prompt_length, + generated_length=generated_length, + total_time=gen_time, + tokens_per_second=tokens_per_sec, + time_per_token=time_per_token, + memory_used_mb=memory_used, + device=str(device), + ) + + # Calculate speedup + if results['non_optimized']['tokens_per_sec'] > 0: + speedup = results['optimized']['tokens_per_sec'] / results['non_optimized']['tokens_per_sec'] + print(f"\nšŸš€ SPEEDUP: {speedup:.2f}x faster with optimizations") + + if results['non_optimized']['memory_mb'] and results['optimized']['memory_mb']: + memory_reduction = (1 - results['optimized']['memory_mb'] / results['non_optimized']['memory_mb']) * 100 + print(f"šŸ’¾ MEMORY REDUCTION: {memory_reduction:.1f}%") + + print("=" * 70) + + return results + + +def main(): + parser = argparse.ArgumentParser(description='Generate text with SheepOp LLM') + parser.add_argument('--checkpoint', type=str, required=True, help='Path to model checkpoint') + parser.add_argument('--prompt', type=str, required=True, help='Prompt text') + parser.add_argument('--max-length', type=int, default=100, help='Maximum generation length') + parser.add_argument('--temperature', type=float, default=1.0, help='Sampling temperature') + parser.add_argument('--top-k', type=int, default=50, help='Top-k sampling') + parser.add_argument('--top-p', type=float, default=0.95, help='Top-p (nucleus) sampling') + parser.add_argument('--device', type=str, default='cuda', help='Device to use') + parser.add_argument('--optimized', action='store_true', help='Use optimized inference with KV caching') + parser.add_argument('--benchmark', action='store_true', help='Run benchmark comparing optimized vs non-optimized inference (for research)') + parser.add_argument('--benchmark-dir', type=str, default='./inference_benchmarks', help='Directory to save benchmark results') + + args = parser.parse_args() + + # Setup device + if args.device == 'cuda' and torch.cuda.is_available(): + device = torch.device('cuda') + elif args.device == 'mps' and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + device = torch.device('mps') + else: + device = torch.device('cpu') + print(f"Using device: {device}") + + # Create tokenizer first (needed for vocab_size) + tokenizer = SimpleTokenizer() + + # Load model + print("Loading model...") + model = load_model(args.checkpoint, device=device, tokenizer=tokenizer) + print("Model loaded!") + + # Check if benchmarking mode + if args.benchmark: + print("\nšŸ”¬ BENCHMARK MODE: Comparing optimized vs non-optimized inference") + print("=" * 70) + + # Initialize metrics + metrics = InferenceMetrics(save_dir=args.benchmark_dir) + + # Run benchmark + run_name = f"run_{int(time.time())}" + results = benchmark_inference( + model=model, + tokenizer=tokenizer, + prompt=args.prompt, + max_length=args.max_length, + temperature=args.temperature, + top_k=args.top_k, + top_p=args.top_p, + device=device, + metrics=metrics, + run_name=run_name, + ) + + # Generate plots and summary + print("\nšŸ“Š Generating comparison plots and data...") + metrics.plot_comparison() + metrics.plot_performance_over_time() + metrics.export_to_csv() + metrics.print_summary() + + print(f"\nāœ… Benchmark complete! Results saved to: {args.benchmark_dir}") + print(f" - JSON metrics: {args.benchmark_dir}/inference_metrics.json") + print(f" - CSV export: {args.benchmark_dir}/inference_metrics.csv") + print(f" - Comparison plot: {args.benchmark_dir}/optimization_comparison.png") + print(f" - Performance plot: {args.benchmark_dir}/performance_over_time.png") + + return + + # Normal inference mode + use_optimized = args.optimized if hasattr(args, 'optimized') else False + + if use_optimized: + print("Using optimized inference with KV caching...") + optimizer = model.get_optimized_inference() + else: + optimizer = None + + # Encode prompt + input_ids = tokenizer.encode(args.prompt) + input_ids = torch.tensor([input_ids], device=device) + + # Generate text + print(f"Prompt: {args.prompt}") + print("Generating...") + + # Filter out padding tokens from the end of generated sequence + def remove_trailing_padding(token_ids, pad_token_id): + """Remove trailing padding tokens.""" + while token_ids and token_ids[-1] == pad_token_id: + token_ids.pop() + return token_ids + + if optimizer is not None: + # Use optimized generation with KV cache + generated = optimizer.generate_with_cache( + input_ids=input_ids, + max_length=args.max_length, + temperature=args.temperature, + top_k=args.top_k, + top_p=args.top_p, + ) + generated_ids = generated[0].cpu().tolist() + # Remove trailing padding + generated_ids = remove_trailing_padding(generated_ids, tokenizer.pad_token_id) + print(f"Generated {len(generated_ids)} tokens (input had {len(input_ids[0])} tokens, after removing padding)") + else: + # Use standard generation + generated = model.generate( + input_ids=input_ids, + max_length=args.max_length, + temperature=args.temperature, + top_k=args.top_k, + top_p=args.top_p, + do_sample=True, + ) + generated_ids = generated[0].cpu().tolist() + # Remove trailing padding + generated_ids = remove_trailing_padding(generated_ids, tokenizer.pad_token_id) + print(f"Generated {len(generated_ids)} tokens (input had {len(input_ids[0])} tokens, after removing padding)") + + # Debug: Show some token statistics + vocab_size = tokenizer.vocab_size + valid_tokens = sum(1 for tid in generated_ids if tid in tokenizer.inv_vocab) + unk_tokens = sum(1 for tid in generated_ids if tid not in tokenizer.inv_vocab) + pad_tokens = sum(1 for tid in generated_ids if tid == tokenizer.pad_token_id) + + print(f"Token statistics:") + print(f" Valid tokens: {valid_tokens}/{len(generated_ids)}") + print(f" Unknown tokens: {unk_tokens}") + print(f" Pad tokens: {pad_tokens}") + print(f" Vocab size: {vocab_size}") + print(f" Token ID range: {min(generated_ids) if generated_ids else 'N/A'} - {max(generated_ids) if generated_ids else 'N/A'}") + + # Show first 20 token IDs for debugging + print(f" First 20 token IDs: {generated_ids[:20]}") + + generated_text = tokenizer.decode(generated_ids) + + print(f"\nGenerated: {generated_text}") + print(f"Generated length: {len(generated_text)} characters") + + +if __name__ == '__main__': + main() + + diff --git a/inference_metrics.py b/inference_metrics.py new file mode 100644 index 0000000..c753b4e --- /dev/null +++ b/inference_metrics.py @@ -0,0 +1,376 @@ +""" +Inference metrics tracking and benchmarking utilities +For research purposes: comparing optimized vs non-optimized inference +""" +import json +import matplotlib.pyplot as plt +from pathlib import Path +from typing import Dict, List, Optional +import numpy as np +import time +import torch + + +class InferenceMetrics: + """ + Track and plot inference metrics for benchmarking optimizations. + """ + + def __init__(self, save_dir: str = './inference_benchmarks'): + """ + Args: + save_dir: Directory to save metrics and plots + """ + self.save_dir = Path(save_dir) + self.save_dir.mkdir(parents=True, exist_ok=True) + + self.metrics_file = self.save_dir / 'inference_metrics.json' + + # Load existing metrics if available + if self.metrics_file.exists(): + with open(self.metrics_file, 'r') as f: + self.metrics = json.load(f) + else: + self.metrics = { + 'runs': [], + } + + def log_run( + self, + run_name: str, + optimized: bool, + prompt_length: int, + generated_length: int, + total_time: float, + tokens_per_second: float, + time_per_token: float, + memory_used_mb: Optional[float] = None, + gpu_utilization: Optional[float] = None, + device: str = 'cuda', + ): + """ + Log a single inference run. + + Args: + run_name: Name/ID of the run + optimized: Whether optimized inference was used + prompt_length: Length of input prompt in tokens + generated_length: Length of generated text in tokens + total_time: Total generation time in seconds + tokens_per_second: Tokens generated per second + time_per_token: Time per token in milliseconds + memory_used_mb: Memory used in MB (optional) + gpu_utilization: GPU utilization percentage (optional) + device: Device used ('cuda', 'cpu', 'mps') + """ + run_data = { + 'run_name': run_name, + 'timestamp': time.time(), + 'optimized': optimized, + 'prompt_length': prompt_length, + 'generated_length': generated_length, + 'total_time': total_time, + 'tokens_per_second': tokens_per_second, + 'time_per_token': time_per_token, + 'memory_used_mb': memory_used_mb, + 'gpu_utilization': gpu_utilization, + 'device': device, + } + + self.metrics['runs'].append(run_data) + self.save() + + def save(self): + """Save metrics to JSON file.""" + with open(self.metrics_file, 'w') as f: + json.dump(self.metrics, f, indent=2) + + def get_comparison_data(self) -> Dict: + """ + Get comparison data for optimized vs non-optimized runs. + + Returns: + Dictionary with comparison statistics + """ + runs = self.metrics['runs'] + + optimized_runs = [r for r in runs if r['optimized']] + non_optimized_runs = [r for r in runs if not r['optimized']] + + comparison = { + 'optimized': { + 'count': len(optimized_runs), + 'avg_tokens_per_sec': np.mean([r['tokens_per_second'] for r in optimized_runs]) if optimized_runs else 0, + 'avg_time_per_token': np.mean([r['time_per_token'] for r in optimized_runs]) if optimized_runs else 0, + 'avg_total_time': np.mean([r['total_time'] for r in optimized_runs]) if optimized_runs else 0, + 'avg_memory_mb': np.mean([r['memory_used_mb'] for r in optimized_runs if r['memory_used_mb']]) if optimized_runs else None, + 'avg_gpu_util': np.mean([r['gpu_utilization'] for r in optimized_runs if r['gpu_utilization']]) if optimized_runs else None, + }, + 'non_optimized': { + 'count': len(non_optimized_runs), + 'avg_tokens_per_sec': np.mean([r['tokens_per_second'] for r in non_optimized_runs]) if non_optimized_runs else 0, + 'avg_time_per_token': np.mean([r['time_per_token'] for r in non_optimized_runs]) if non_optimized_runs else 0, + 'avg_total_time': np.mean([r['total_time'] for r in non_optimized_runs]) if non_optimized_runs else 0, + 'avg_memory_mb': np.mean([r['memory_used_mb'] for r in non_optimized_runs if r['memory_used_mb']]) if non_optimized_runs else None, + 'avg_gpu_util': np.mean([r['gpu_utilization'] for r in non_optimized_runs if r['gpu_utilization']]) if non_optimized_runs else None, + }, + } + + # Calculate speedup + if comparison['non_optimized']['avg_tokens_per_sec'] > 0: + speedup = comparison['optimized']['avg_tokens_per_sec'] / comparison['non_optimized']['avg_tokens_per_sec'] + comparison['speedup'] = speedup + else: + comparison['speedup'] = None + + # Calculate memory reduction + if comparison['optimized']['avg_memory_mb'] and comparison['non_optimized']['avg_memory_mb']: + memory_reduction = (1 - comparison['optimized']['avg_memory_mb'] / comparison['non_optimized']['avg_memory_mb']) * 100 + comparison['memory_reduction_percent'] = memory_reduction + else: + comparison['memory_reduction_percent'] = None + + return comparison + + def plot_comparison(self, save_path: Optional[str] = None): + """ + Plot comparison charts for optimized vs non-optimized inference. + + Args: + save_path: Path to save plot (default: save_dir/optimization_comparison.png) + """ + if save_path is None: + save_path = self.save_dir / 'optimization_comparison.png' + + comparison = self.get_comparison_data() + + if comparison['optimized']['count'] == 0 or comparison['non_optimized']['count'] == 0: + print("āš ļø Need both optimized and non-optimized runs for comparison") + return + + fig, axes = plt.subplots(2, 2, figsize=(15, 12)) + + # Plot 1: Tokens per Second + ax1 = axes[0, 0] + categories = ['Optimized', 'Non-Optimized'] + tokens_per_sec = [ + comparison['optimized']['avg_tokens_per_sec'], + comparison['non_optimized']['avg_tokens_per_sec'] + ] + colors = ['#2ecc71', '#e74c3c'] + bars = ax1.bar(categories, tokens_per_sec, color=colors, alpha=0.7, edgecolor='black', linewidth=1.5) + ax1.set_ylabel('Tokens per Second', fontsize=12) + ax1.set_title('Generation Speed: Tokens per Second', fontsize=14, fontweight='bold') + ax1.grid(True, alpha=0.3, axis='y') + + # Add value labels on bars + for bar, value in zip(bars, tokens_per_sec): + height = bar.get_height() + ax1.text(bar.get_x() + bar.get_width()/2., height, + f'{value:.1f}', + ha='center', va='bottom', fontsize=11, fontweight='bold') + + # Add speedup annotation + if comparison['speedup']: + speedup_text = f"Speedup: {comparison['speedup']:.2f}x" + ax1.text(0.5, 0.95, speedup_text, transform=ax1.transAxes, + ha='center', va='top', fontsize=12, fontweight='bold', + bbox=dict(boxstyle='round', facecolor='yellow', alpha=0.5)) + + # Plot 2: Time per Token + ax2 = axes[0, 1] + time_per_token = [ + comparison['optimized']['avg_time_per_token'], + comparison['non_optimized']['avg_time_per_token'] + ] + bars = ax2.bar(categories, time_per_token, color=colors, alpha=0.7, edgecolor='black', linewidth=1.5) + ax2.set_ylabel('Time per Token (ms)', fontsize=12) + ax2.set_title('Latency: Time per Token', fontsize=14, fontweight='bold') + ax2.grid(True, alpha=0.3, axis='y') + + for bar, value in zip(bars, time_per_token): + height = bar.get_height() + ax2.text(bar.get_x() + bar.get_width()/2., height, + f'{value:.2f} ms', + ha='center', va='bottom', fontsize=11, fontweight='bold') + + # Plot 3: Total Generation Time + ax3 = axes[1, 0] + total_time = [ + comparison['optimized']['avg_total_time'], + comparison['non_optimized']['avg_total_time'] + ] + bars = ax3.bar(categories, total_time, color=colors, alpha=0.7, edgecolor='black', linewidth=1.5) + ax3.set_ylabel('Total Time (seconds)', fontsize=12) + ax3.set_title('Total Generation Time', fontsize=14, fontweight='bold') + ax3.grid(True, alpha=0.3, axis='y') + + for bar, value in zip(bars, total_time): + height = bar.get_height() + ax3.text(bar.get_x() + bar.get_width()/2., height, + f'{value:.3f} s', + ha='center', va='bottom', fontsize=11, fontweight='bold') + + # Plot 4: Memory Usage (if available) + ax4 = axes[1, 1] + if comparison['optimized']['avg_memory_mb'] and comparison['non_optimized']['avg_memory_mb']: + memory_usage = [ + comparison['optimized']['avg_memory_mb'], + comparison['non_optimized']['avg_memory_mb'] + ] + bars = ax4.bar(categories, memory_usage, color=colors, alpha=0.7, edgecolor='black', linewidth=1.5) + ax4.set_ylabel('Memory Usage (MB)', fontsize=12) + ax4.set_title('Memory Usage', fontsize=14, fontweight='bold') + ax4.grid(True, alpha=0.3, axis='y') + + for bar, value in zip(bars, memory_usage): + height = bar.get_height() + ax4.text(bar.get_x() + bar.get_width()/2., height, + f'{value:.1f} MB', + ha='center', va='bottom', fontsize=11, fontweight='bold') + + # Add memory reduction annotation + if comparison['memory_reduction_percent']: + reduction_text = f"Reduction: {comparison['memory_reduction_percent']:.1f}%" + ax4.text(0.5, 0.95, reduction_text, transform=ax4.transAxes, + ha='center', va='top', fontsize=12, fontweight='bold', + bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.5)) + else: + ax4.text(0.5, 0.5, 'Memory data\nnot available', + ha='center', va='center', fontsize=12, + transform=ax4.transAxes) + ax4.set_title('Memory Usage', fontsize=14, fontweight='bold') + + plt.suptitle('Inference Optimization Comparison', fontsize=16, fontweight='bold', y=0.995) + plt.tight_layout() + plt.savefig(save_path, dpi=150, bbox_inches='tight') + print(f"šŸ“Š Comparison plot saved to: {save_path}") + plt.close() + + def plot_performance_over_time(self, save_path: Optional[str] = None): + """ + Plot performance metrics over time for research purposes. + + Args: + save_path: Path to save plot (default: save_dir/performance_over_time.png) + """ + if save_path is None: + save_path = self.save_dir / 'performance_over_time.png' + + runs = self.metrics['runs'] + if len(runs) < 2: + print("āš ļø Need at least 2 runs for time series plot") + return + + # Sort by timestamp + sorted_runs = sorted(runs, key=lambda x: x['timestamp']) + + optimized_times = [] + optimized_tokens_per_sec = [] + non_optimized_times = [] + non_optimized_tokens_per_sec = [] + + for run in sorted_runs: + if run['optimized']: + optimized_times.append(run['timestamp']) + optimized_tokens_per_sec.append(run['tokens_per_second']) + else: + non_optimized_times.append(run['timestamp']) + non_optimized_tokens_per_sec.append(run['tokens_per_second']) + + fig, ax = plt.subplots(figsize=(12, 6)) + + if optimized_times: + ax.plot(optimized_times, optimized_tokens_per_sec, 'o-', + label='Optimized', color='#2ecc71', linewidth=2, markersize=8) + + if non_optimized_times: + ax.plot(non_optimized_times, non_optimized_tokens_per_sec, 's-', + label='Non-Optimized', color='#e74c3c', linewidth=2, markersize=8) + + ax.set_xlabel('Time', fontsize=12) + ax.set_ylabel('Tokens per Second', fontsize=12) + ax.set_title('Performance Over Time', fontsize=14, fontweight='bold') + ax.legend(fontsize=11) + ax.grid(True, alpha=0.3) + + plt.tight_layout() + plt.savefig(save_path, dpi=150, bbox_inches='tight') + print(f"šŸ“Š Performance over time plot saved to: {save_path}") + plt.close() + + def export_to_csv(self, save_path: Optional[str] = None): + """ + Export metrics to CSV file for analysis. + + Args: + save_path: Path to save CSV (default: save_dir/inference_metrics.csv) + """ + if save_path is None: + save_path = self.save_dir / 'inference_metrics.csv' + + import csv + + runs = self.metrics['runs'] + if not runs: + print("āš ļø No runs to export") + return + + with open(save_path, 'w', newline='') as f: + writer = csv.writer(f) + # Header + writer.writerow([ + 'run_name', 'timestamp', 'optimized', 'prompt_length', 'generated_length', + 'total_time', 'tokens_per_second', 'time_per_token', 'memory_used_mb', + 'gpu_utilization', 'device' + ]) + + # Data rows + for run in runs: + writer.writerow([ + run['run_name'], + run['timestamp'], + run['optimized'], + run['prompt_length'], + run['generated_length'], + run['total_time'], + run['tokens_per_second'], + run['time_per_token'], + run.get('memory_used_mb', ''), + run.get('gpu_utilization', ''), + run['device'], + ]) + + print(f"šŸ“Š Metrics exported to CSV: {save_path}") + + def print_summary(self): + """Print comparison summary.""" + comparison = self.get_comparison_data() + + print("\n" + "=" * 70) + print("INFERENCE OPTIMIZATION BENCHMARK SUMMARY") + print("=" * 70) + + print(f"\nOptimized Runs: {comparison['optimized']['count']}") + if comparison['optimized']['count'] > 0: + print(f" Average Tokens/Second: {comparison['optimized']['avg_tokens_per_sec']:.2f}") + print(f" Average Time/Token: {comparison['optimized']['avg_time_per_token']:.3f} ms") + print(f" Average Total Time: {comparison['optimized']['avg_total_time']:.3f} s") + if comparison['optimized']['avg_memory_mb']: + print(f" Average Memory: {comparison['optimized']['avg_memory_mb']:.1f} MB") + + print(f"\nNon-Optimized Runs: {comparison['non_optimized']['count']}") + if comparison['non_optimized']['count'] > 0: + print(f" Average Tokens/Second: {comparison['non_optimized']['avg_tokens_per_sec']:.2f}") + print(f" Average Time/Token: {comparison['non_optimized']['avg_time_per_token']:.3f} ms") + print(f" Average Total Time: {comparison['non_optimized']['avg_total_time']:.3f} s") + if comparison['non_optimized']['avg_memory_mb']: + print(f" Average Memory: {comparison['non_optimized']['avg_memory_mb']:.1f} MB") + + if comparison['speedup']: + print(f"\nšŸš€ SPEEDUP: {comparison['speedup']:.2f}x faster with optimizations") + + if comparison['memory_reduction_percent']: + print(f"šŸ’¾ MEMORY REDUCTION: {comparison['memory_reduction_percent']:.1f}%") + + print("=" * 70) diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..288ca0c --- /dev/null +++ b/models/__init__.py @@ -0,0 +1,35 @@ +""" +SheepOp LLM - A modern language model implementation +Optimized for production RAG systems +""" +from .transformer import TransformerModel +from .attention import MultiHeadAttention, PositionalEncoding +from .blocks import TransformerBlock, FeedForward +from .optimized_attention import ( + OptimizedMultiHeadAttention, + RetrievalCache, + OptimizedInference, + KVCache, +) +from .prefetching import ( + PrefetchDataLoader, + LookaheadRetriever, + BatchPrefetcher, +) + +__all__ = [ + 'TransformerModel', + 'MultiHeadAttention', + 'PositionalEncoding', + 'TransformerBlock', + 'FeedForward', + 'OptimizedMultiHeadAttention', + 'RetrievalCache', + 'OptimizedInference', + 'KVCache', + 'PrefetchDataLoader', + 'LookaheadRetriever', + 'BatchPrefetcher', +] + + diff --git a/models/attention.py b/models/attention.py new file mode 100644 index 0000000..f1bf945 --- /dev/null +++ b/models/attention.py @@ -0,0 +1,220 @@ +""" +Multi-Head Attention mechanism from "Attention Is All You Need" +Includes optimizations for long context and hallucination reduction +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +from typing import Optional, Tuple + + +class MultiHeadAttention(nn.Module): + """ + Multi-Head Attention mechanism with optional causal masking. + + Features: + - Scaled dot-product attention + - Optional causal masking for autoregressive generation + - Efficient attention computation + """ + + def __init__( + self, + d_model: int, + num_heads: int, + dropout: float = 0.1, + bias: bool = False, + causal: bool = False, + ): + """ + Args: + d_model: Model dimension + num_heads: Number of attention heads + dropout: Dropout probability + bias: Whether to use bias in linear layers + causal: Whether to use causal masking + """ + super().__init__() + assert d_model % num_heads == 0, "d_model must be divisible by num_heads" + + self.d_model = d_model + self.num_heads = num_heads + self.d_k = d_model // num_heads + self.causal = causal + + # Linear projections for Q, K, V + self.q_proj = nn.Linear(d_model, d_model, bias=bias) + self.k_proj = nn.Linear(d_model, d_model, bias=bias) + self.v_proj = nn.Linear(d_model, d_model, bias=bias) + self.out_proj = nn.Linear(d_model, d_model, bias=bias) + + self.dropout = nn.Dropout(dropout) + self.scale = 1.0 / math.sqrt(self.d_k) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass of multi-head attention. + + Args: + query: Query tensor [batch_size, seq_len, d_model] + key: Key tensor [batch_size, seq_len, d_model] + value: Value tensor [batch_size, seq_len, d_model] + mask: Optional attention mask [batch_size, seq_len, seq_len] + + Returns: + output: Attention output [batch_size, seq_len, d_model] + attention_weights: Attention weights [batch_size, num_heads, seq_len, seq_len] + """ + batch_size, seq_len, _ = query.shape + + # Project Q, K, V + Q = self.q_proj(query) # [batch_size, seq_len, d_model] + K = self.k_proj(key) + V = self.v_proj(value) + + # Reshape for multi-head attention + Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2) # [batch_size, num_heads, seq_len, d_k] + K = K.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2) + V = V.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2) + + # Compute attention scores + scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale # [batch_size, num_heads, seq_len, seq_len] + + # Apply causal mask if needed + if self.causal: + causal_mask = torch.triu( + torch.ones(seq_len, seq_len, device=query.device, dtype=torch.bool), + diagonal=1 + ) + scores.masked_fill_(causal_mask, float('-inf')) + + # Apply external mask if provided + if mask is not None: + scores.masked_fill_(mask.unsqueeze(1) == 0, float('-inf')) + + # Compute attention weights + attention_weights = F.softmax(scores, dim=-1) + attention_weights = self.dropout(attention_weights) + + # Apply attention to values + output = torch.matmul(attention_weights, V) # [batch_size, num_heads, seq_len, d_k] + + # Concatenate heads + output = output.transpose(1, 2).contiguous() # [batch_size, seq_len, num_heads, d_k] + output = output.view(batch_size, seq_len, self.d_model) # [batch_size, seq_len, d_model] + + # Final projection + output = self.out_proj(output) + + return output, attention_weights + + +class PositionalEncoding(nn.Module): + """ + Positional encoding for transformer models. + Uses sinusoidal positional encoding as described in "Attention Is All You Need". + """ + + def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1): + """ + Args: + d_model: Model dimension + max_len: Maximum sequence length + dropout: Dropout probability + """ + super().__init__() + self.dropout = nn.Dropout(p=dropout) + + # Create positional encoding matrix + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model) + ) + + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) # [1, max_len, d_model] + + # Register as buffer (not a parameter) + self.register_buffer('pe', pe) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Add positional encoding to input. + + Args: + x: Input tensor [batch_size, seq_len, d_model] + + Returns: + Output with positional encoding added + """ + seq_len = x.shape[1] + x = x + self.pe[:, :seq_len, :] + return self.dropout(x) + + +class RotaryPositionalEncoding(nn.Module): + """ + Rotary Position Embedding (RoPE) - More efficient for long sequences. + Better for long-horizon execution tasks. + """ + + def __init__(self, d_model: int, max_len: int = 8192): + """ + Args: + d_model: Model dimension (must be even) + max_len: Maximum sequence length + """ + super().__init__() + assert d_model % 2 == 0, "d_model must be even for RoPE" + + self.d_model = d_model + self.max_len = max_len + + # Precompute frequency matrix + inv_freq = 1.0 / (10000 ** (torch.arange(0, d_model, 2).float() / d_model)) + self.register_buffer('inv_freq', inv_freq) + + def forward(self, x: torch.Tensor, offset: int = 0) -> torch.Tensor: + """ + Apply rotary positional encoding. + + Args: + x: Input tensor [batch_size, seq_len, d_model] + offset: Position offset for relative positions + + Returns: + Rotated input tensor + """ + seq_len = x.shape[1] + device = x.device + + # Generate position indices + t = torch.arange(offset, offset + seq_len, device=device).type_as(self.inv_freq) + freqs = torch.outer(t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + + # Apply rotation + cos = emb.cos() + sin = emb.sin() + + # Split input into two halves + x1, x2 = x.chunk(2, dim=-1) + + # Apply rotation + rotated = torch.cat([ + x1 * cos - x2 * sin, + x1 * sin + x2 * cos + ], dim=-1) + + return rotated + + diff --git a/models/blocks.py b/models/blocks.py new file mode 100644 index 0000000..4c113eb --- /dev/null +++ b/models/blocks.py @@ -0,0 +1,153 @@ +""" +Transformer building blocks: Feed-forward networks and transformer blocks +""" +import torch +import torch.nn as nn +from typing import Optional +from .attention import MultiHeadAttention +from .optimized_attention import OptimizedMultiHeadAttention + + +class FeedForward(nn.Module): + """ + Position-wise Feed-Forward Network. + Implements two linear transformations with activation in between. + """ + + def __init__( + self, + d_model: int, + d_ff: int, + dropout: float = 0.1, + activation: str = 'gelu', + bias: bool = False, + ): + """ + Args: + d_model: Model dimension + d_ff: Feed-forward dimension (typically 4 * d_model) + dropout: Dropout probability + activation: Activation function ('gelu' or 'relu') + bias: Whether to use bias in linear layers + """ + super().__init__() + self.linear1 = nn.Linear(d_model, d_ff, bias=bias) + self.linear2 = nn.Linear(d_ff, d_model, bias=bias) + self.dropout = nn.Dropout(dropout) + + if activation == 'gelu': + self.activation = nn.GELU() + elif activation == 'relu': + self.activation = nn.ReLU() + else: + raise ValueError(f"Unsupported activation: {activation}") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: Input tensor [batch_size, seq_len, d_model] + + Returns: + Output tensor [batch_size, seq_len, d_model] + """ + x = self.linear1(x) + x = self.activation(x) + x = self.dropout(x) + x = self.linear2(x) + return x + + +class TransformerBlock(nn.Module): + """ + Transformer block with self-attention and feed-forward network. + Includes residual connections and layer normalization. + """ + + def __init__( + self, + d_model: int, + num_heads: int, + d_ff: int, + dropout: float = 0.1, + activation: str = 'gelu', + layer_norm_eps: float = 1e-5, + bias: bool = False, + causal: bool = False, + use_optimized_attention: bool = False, + ): + """ + Args: + d_model: Model dimension + num_heads: Number of attention heads + d_ff: Feed-forward dimension + dropout: Dropout probability + activation: Activation function + layer_norm_eps: Epsilon for layer normalization + bias: Whether to use bias in linear layers + causal: Whether to use causal masking + use_optimized_attention: Whether to use optimized attention with KV caching + """ + super().__init__() + + # Self-attention with pre-norm architecture + if use_optimized_attention: + self.self_attn = OptimizedMultiHeadAttention( + d_model=d_model, + num_heads=num_heads, + dropout=dropout, + bias=bias, + causal=causal, + use_flash_attention=True, + ) + else: + self.self_attn = MultiHeadAttention( + d_model=d_model, + num_heads=num_heads, + dropout=dropout, + bias=bias, + causal=causal, + ) + self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps) + + # Feed-forward network + self.feed_forward = FeedForward( + d_model=d_model, + d_ff=d_ff, + dropout=dropout, + activation=activation, + bias=bias, + ) + self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps) + + self.dropout = nn.Dropout(dropout) + + def forward( + self, + x: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Forward pass through transformer block. + + Args: + x: Input tensor [batch_size, seq_len, d_model] + mask: Optional attention mask + + Returns: + Output tensor [batch_size, seq_len, d_model] + """ + # Pre-norm self-attention with residual connection + residual = x + x = self.norm1(x) + attn_out, _ = self.self_attn(x, x, x, mask=mask) + x = residual + self.dropout(attn_out) + + # Pre-norm feed-forward with residual connection + residual = x + x = self.norm2(x) + ff_out = self.feed_forward(x) + x = residual + self.dropout(ff_out) + + return x + + diff --git a/models/optimized_attention.py b/models/optimized_attention.py new file mode 100644 index 0000000..d9d8225 --- /dev/null +++ b/models/optimized_attention.py @@ -0,0 +1,413 @@ +""" +Optimized attention mechanisms for production RAG systems +Implements KV caching, optimized attention computation, and retrieval optimizations +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +from typing import Optional, Tuple, Dict, List +from dataclasses import dataclass + + +@dataclass +class KVCache: + """Key-Value cache for efficient autoregressive generation.""" + keys: torch.Tensor # [batch_size, num_heads, seq_len, d_k] + values: torch.Tensor # [batch_size, num_heads, seq_len, d_k] + + def append(self, new_keys: torch.Tensor, new_values: torch.Tensor): + """Append new keys and values to the cache.""" + self.keys = torch.cat([self.keys, new_keys], dim=2) + self.values = torch.cat([self.values, new_values], dim=2) + + def clear(self): + """Clear the cache.""" + self.keys = None + self.values = None + + +class OptimizedMultiHeadAttention(nn.Module): + """ + Optimized Multi-Head Attention with KV caching and efficient computation. + + Features: + - KV cache for autoregressive generation + - Optimized attention computation + - Support for incremental decoding + """ + + def __init__( + self, + d_model: int, + num_heads: int, + dropout: float = 0.1, + bias: bool = False, + causal: bool = False, + use_flash_attention: bool = False, + ): + """ + Args: + d_model: Model dimension + num_heads: Number of attention heads + dropout: Dropout probability + bias: Whether to use bias in linear layers + causal: Whether to use causal masking + use_flash_attention: Whether to use optimized flash attention (if available) + """ + super().__init__() + assert d_model % num_heads == 0, "d_model must be divisible by num_heads" + + self.d_model = d_model + self.num_heads = num_heads + self.d_k = d_model // num_heads + self.causal = causal + self.use_flash_attention = use_flash_attention + + # Linear projections for Q, K, V + self.q_proj = nn.Linear(d_model, d_model, bias=bias) + self.k_proj = nn.Linear(d_model, d_model, bias=bias) + self.v_proj = nn.Linear(d_model, d_model, bias=bias) + self.out_proj = nn.Linear(d_model, d_model, bias=bias) + + self.dropout = nn.Dropout(dropout) + self.scale = 1.0 / math.sqrt(self.d_k) + + # KV cache for inference + self.kv_cache: Optional[KVCache] = None + + def forward( + self, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + value: Optional[torch.Tensor] = None, + mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + cache_position: Optional[int] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Forward pass with optional KV caching. + + Args: + query: Query tensor [batch_size, seq_len, d_model] + key: Key tensor [batch_size, seq_len, d_model] (if None, uses query) + value: Value tensor [batch_size, seq_len, d_model] (if None, uses query) + mask: Optional attention mask [batch_size, seq_len, seq_len] + use_cache: Whether to use KV cache + cache_position: Position in cache for incremental decoding + + Returns: + output: Attention output [batch_size, seq_len, d_model] + attention_weights: Attention weights [batch_size, num_heads, seq_len, seq_len] + """ + if key is None: + key = query + if value is None: + value = query + + batch_size, seq_len, _ = query.shape + + # Project Q, K, V + Q = self.q_proj(query) # [batch_size, seq_len, d_model] + K = self.k_proj(key) + V = self.v_proj(value) + + # Reshape for multi-head attention + Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2) # [batch_size, num_heads, seq_len, d_k] + K = K.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2) + V = V.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2) + + # Use KV cache if available and enabled + if use_cache and self.kv_cache is not None: + # Append new keys and values to cache + self.kv_cache.append(K, V) + K = self.kv_cache.keys + V = self.kv_cache.values + kv_seq_len = K.shape[2] + else: + kv_seq_len = seq_len + + # Compute attention scores with optimized computation + if self.use_flash_attention and hasattr(F, 'scaled_dot_product_attention'): + # Use PyTorch's optimized scaled dot product attention + output = F.scaled_dot_product_attention( + Q, K, V, + attn_mask=mask, + dropout_p=self.dropout.p if self.training else 0.0, + is_causal=self.causal, + ) + attention_weights = None # Flash attention doesn't return weights + else: + # Standard attention computation + scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale # [batch_size, num_heads, seq_len, kv_seq_len] + + # Apply causal mask if needed + if self.causal: + causal_mask = torch.triu( + torch.ones(seq_len, kv_seq_len, device=query.device, dtype=torch.bool), + diagonal=1 + ) + scores.masked_fill_(causal_mask, float('-inf')) + + # Apply external mask if provided + if mask is not None: + if mask.dim() == 2: + mask = mask.unsqueeze(1).unsqueeze(1) # [batch_size, 1, seq_len, kv_seq_len] + scores.masked_fill_(mask == 0, float('-inf')) + + # Compute attention weights + attention_weights = F.softmax(scores, dim=-1) + attention_weights = self.dropout(attention_weights) + + # Apply attention to values + output = torch.matmul(attention_weights, V) # [batch_size, num_heads, seq_len, d_k] + + # Concatenate heads + output = output.transpose(1, 2).contiguous() # [batch_size, seq_len, num_heads, d_k] + output = output.view(batch_size, seq_len, self.d_model) # [batch_size, seq_len, d_model] + + # Final projection + output = self.out_proj(output) + + return output, attention_weights + + def init_kv_cache(self, batch_size: int, max_length: int, device: torch.device): + """Initialize KV cache for inference.""" + self.kv_cache = KVCache( + keys=torch.empty(batch_size, self.num_heads, 0, self.d_k, device=device), + values=torch.empty(batch_size, self.num_heads, 0, self.d_k, device=device), + ) + + def clear_cache(self): + """Clear the KV cache.""" + self.kv_cache = None + + +class RetrievalCache: + """ + Approximate cache for retrieval results. + Reduces expensive vector database lookups by caching similar queries. + """ + + def __init__(self, max_size: int = 1000, similarity_threshold: float = 0.9): + """ + Args: + max_size: Maximum number of cached entries + similarity_threshold: Minimum similarity to consider a cache hit + """ + self.max_size = max_size + self.similarity_threshold = similarity_threshold + self.cache: Dict[str, List[Dict]] = {} # query_hash -> retrieved_docs + self.query_embeddings: Dict[str, torch.Tensor] = {} # query_hash -> embedding + + def get(self, query_hash: str, query_embedding: torch.Tensor) -> Optional[List[Dict]]: + """ + Retrieve cached results if similar query exists. + + Args: + query_hash: Hash of the query + query_embedding: Embedding of the query + + Returns: + Cached results if found, None otherwise + """ + # Check exact match first + if query_hash in self.cache: + return self.cache[query_hash] + + # Check for similar queries + best_match = None + best_similarity = 0.0 + + for cached_hash, cached_embedding in self.query_embeddings.items(): + # Compute cosine similarity + similarity = F.cosine_similarity( + query_embedding.unsqueeze(0), + cached_embedding.unsqueeze(0) + ).item() + + if similarity > best_similarity: + best_similarity = similarity + best_match = cached_hash + + if best_similarity >= self.similarity_threshold and best_match: + return self.cache[best_match] + + return None + + def set(self, query_hash: str, query_embedding: torch.Tensor, results: List[Dict]): + """ + Store query and results in cache. + + Args: + query_hash: Hash of the query + query_embedding: Embedding of the query + results: Retrieved documents/results + """ + # Remove oldest entry if cache is full + if len(self.cache) >= self.max_size: + oldest_key = next(iter(self.cache)) + del self.cache[oldest_key] + del self.query_embeddings[oldest_key] + + self.cache[query_hash] = results + self.query_embeddings[query_hash] = query_embedding + + def clear(self): + """Clear the cache.""" + self.cache.clear() + self.query_embeddings.clear() + + +class OptimizedInference: + """ + Optimized inference utilities for production RAG systems. + Includes prefetching, batching, and parallel processing. + """ + + def __init__(self, model: nn.Module, device: torch.device): + """ + Args: + model: Model to use for inference + device: Device to run inference on + """ + self.model = model + self.device = device + self.model.eval() + + @torch.no_grad() + def generate_with_cache( + self, + input_ids: torch.Tensor, + max_length: int = 100, + temperature: float = 1.0, + top_k: Optional[int] = None, + top_p: float = 1.0, + do_sample: bool = True, + ) -> torch.Tensor: + """ + Generate with KV cache for efficient autoregressive generation. + + Args: + input_ids: Starting token indices [batch_size, seq_len] + max_length: Maximum generation length + temperature: Sampling temperature + top_k: Top-k sampling parameter + top_p: Nucleus sampling parameter + do_sample: Whether to sample or use greedy decoding + + Returns: + Generated token sequences + """ + batch_size = input_ids.shape[0] + device = input_ids.device + + # Initialize KV cache in all attention layers + for module in self.model.modules(): + if isinstance(module, OptimizedMultiHeadAttention): + module.init_kv_cache(batch_size, max_length, device) + + generated = input_ids.clone() + + for _ in range(max_length - input_ids.shape[1]): + # Forward pass + logits, _ = self.model(generated) + next_token_logits = logits[:, -1, :] / temperature + + # Apply top-k filtering + if top_k is not None and top_k > 0: + indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None] + next_token_logits[indices_to_remove] = float('-inf') + + # Apply top-p (nucleus) filtering + if top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True) + cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) + + sorted_indices_to_remove = cumulative_probs > top_p + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + next_token_logits[indices_to_remove] = float('-inf') + + # Sample or take argmax + if do_sample: + probs = torch.softmax(next_token_logits, dim=-1) + next_token = torch.multinomial(probs, num_samples=1) + else: + next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True) + + # Early stopping: stop if EOS or padding token is generated (for batch_size=1) + if batch_size == 1: + eos_token_id = 3 # Default EOS token ID + if next_token.item() == eos_token_id: + break + + # Early stopping: stop if padding token is generated (prevent generating padding) + pad_token_id = 0 # Default padding token ID + if next_token.item() == pad_token_id: + break + + # Append to generated sequence + generated = torch.cat([generated, next_token], dim=1) + + # Clear KV cache + for module in self.model.modules(): + if isinstance(module, OptimizedMultiHeadAttention): + module.clear_cache() + + return generated + + @torch.no_grad() + def batch_generate( + self, + input_ids_list: List[torch.Tensor], + max_length: int = 100, + temperature: float = 1.0, + top_k: Optional[int] = None, + top_p: float = 1.0, + batch_size: int = 8, + ) -> List[torch.Tensor]: + """ + Generate for multiple prompts in batches for efficiency. + + Args: + input_ids_list: List of starting token sequences + max_length: Maximum generation length + temperature: Sampling temperature + top_k: Top-k sampling parameter + top_p: Nucleus sampling parameter + batch_size: Batch size for processing + + Returns: + List of generated sequences + """ + results = [] + + for i in range(0, len(input_ids_list), batch_size): + batch = input_ids_list[i:i + batch_size] + + # Pad to same length + max_len = max(seq.shape[1] for seq in batch) + padded_batch = [] + for seq in batch: + padding = torch.zeros(seq.shape[0], max_len - seq.shape[1], + dtype=seq.dtype, device=seq.device) + padded_batch.append(torch.cat([seq, padding], dim=1)) + + batch_tensor = torch.cat(padded_batch, dim=0) + + # Generate for batch + generated = self.generate_with_cache( + batch_tensor, + max_length=max_length, + temperature=temperature, + top_k=top_k, + top_p=top_p, + ) + + results.extend([gen for gen in generated]) + + return results + diff --git a/models/prefetching.py b/models/prefetching.py new file mode 100644 index 0000000..a6cb53e --- /dev/null +++ b/models/prefetching.py @@ -0,0 +1,267 @@ +""" +Prefetching mechanism for parallel data loading and processing +Optimizes RAG systems by prefetching retrieval results +""" +import torch +from torch.utils.data import DataLoader +from typing import List, Dict, Optional, Callable, Any +from threading import Thread +from queue import Queue +import time + + +class PrefetchDataLoader: + """ + DataLoader with prefetching for parallel data loading. + Reduces GPU idle time by prefetching batches in background threads. + """ + + def __init__( + self, + dataloader: DataLoader, + prefetch_factor: int = 2, + device: torch.device = None, + ): + """ + Args: + dataloader: Base DataLoader to wrap + prefetch_factor: Number of batches to prefetch + device: Device to prefetch batches to + """ + self.dataloader = dataloader + self.prefetch_factor = prefetch_factor + self.device = device + + self.queue = Queue(maxsize=prefetch_factor) + self.thread = None + self._stop_thread = False + + def _prefetch_worker(self): + """Worker thread that prefetches batches.""" + for batch in self.dataloader: + if self._stop_thread: + break + + # Move to device if specified + if self.device is not None: + batch = {k: v.to(self.device, non_blocking=True) + for k, v in batch.items()} + + self.queue.put(batch) + + self.queue.put(None) # Signal end of data + + def __iter__(self): + """Start prefetching thread and return iterator.""" + self._stop_thread = False + self.thread = Thread(target=self._prefetch_worker, daemon=True) + self.thread.start() + return self + + def __next__(self): + """Get next prefetched batch.""" + batch = self.queue.get() + if batch is None: + raise StopIteration + return batch + + def __len__(self): + """Return length of underlying dataloader.""" + return len(self.dataloader) + + def stop(self): + """Stop prefetching thread.""" + self._stop_thread = True + if self.thread is not None: + self.thread.join() + + +class LookaheadRetriever: + """ + Lookahead retrieval mechanism for RAG systems. + Prefetches retrieval results for anticipated queries. + """ + + def __init__( + self, + retrieval_fn: Callable[[str], List[Dict]], + lookahead_window: int = 3, + prefetch_queue_size: int = 10, + ): + """ + Args: + retrieval_fn: Function that takes a query and returns retrieved documents + lookahead_window: Number of queries to look ahead + prefetch_queue_size: Maximum size of prefetch queue + """ + self.retrieval_fn = retrieval_fn + self.lookahead_window = lookahead_window + self.prefetch_queue_size = prefetch_queue_size + + self.prefetch_queue: Queue = Queue(maxsize=prefetch_queue_size) + self.prefetch_thread: Optional[Thread] = None + self._stop_thread = False + + def _prefetch_worker(self, query_queue: Queue): + """Worker thread that prefetches retrieval results.""" + while not self._stop_thread: + try: + query = query_queue.get(timeout=1.0) + if query is None: + break + + # Perform retrieval + results = self.retrieval_fn(query) + + # Add to prefetch queue + try: + self.prefetch_queue.put((query, results), timeout=0.1) + except: + pass # Queue full, skip + + except: + continue + + def start_prefetching(self, query_stream: List[str]): + """Start prefetching retrieval results for query stream.""" + query_queue = Queue() + + # Add queries to queue + for query in query_stream: + query_queue.put(query) + query_queue.put(None) # Signal end + + self._stop_thread = False + self.prefetch_thread = Thread(target=self._prefetch_worker, args=(query_queue,), daemon=True) + self.prefetch_thread.start() + + def get(self, query: str, timeout: float = 1.0) -> Optional[List[Dict]]: + """ + Get retrieval results, checking prefetch queue first. + + Args: + query: Query string + timeout: Timeout for checking prefetch queue + + Returns: + Retrieved documents or None if not found + """ + # Check prefetch queue + while not self.prefetch_queue.empty(): + try: + cached_query, results = self.prefetch_queue.get(timeout=timeout) + if cached_query == query: + return results + # Put back if not matching + self.prefetch_queue.put((cached_query, results)) + except: + break + + # Fallback to direct retrieval + return self.retrieval_fn(query) + + def stop(self): + """Stop prefetching thread.""" + self._stop_thread = True + if self.prefetch_thread is not None: + self.prefetch_thread.join() + + +class BatchPrefetcher: + """ + Batched prefetching for multiple queries. + Groups queries into batches for efficient retrieval. + """ + + def __init__( + self, + batch_retrieval_fn: Callable[[List[str]], List[List[Dict]]], + batch_size: int = 8, + prefetch_factor: int = 2, + ): + """ + Args: + batch_retrieval_fn: Function that takes list of queries and returns list of results + batch_size: Size of batches for retrieval + prefetch_factor: Number of batches to prefetch + """ + self.batch_retrieval_fn = batch_retrieval_fn + self.batch_size = batch_size + self.prefetch_factor = prefetch_factor + + self.prefetch_queue: Queue = Queue(maxsize=prefetch_factor) + self.prefetch_thread: Optional[Thread] = None + self._stop_thread = False + + def _prefetch_worker(self, query_queue: Queue): + """Worker thread that prefetches batches of retrieval results.""" + batch = [] + + while not self._stop_thread: + try: + query = query_queue.get(timeout=1.0) + if query is None: + # Process remaining batch + if batch: + results = self.batch_retrieval_fn(batch) + for q, r in zip(batch, results): + self.prefetch_queue.put((q, r)) + break + + batch.append(query) + + # Process batch when full + if len(batch) >= self.batch_size: + results = self.batch_retrieval_fn(batch) + for q, r in zip(batch, results): + try: + self.prefetch_queue.put((q, r), timeout=0.1) + except: + pass # Queue full + batch = [] + + except: + continue + + def start_prefetching(self, query_stream: List[str]): + """Start prefetching retrieval results for query stream.""" + query_queue = Queue() + + for query in query_stream: + query_queue.put(query) + query_queue.put(None) # Signal end + + self._stop_thread = False + self.prefetch_thread = Thread(target=self._prefetch_worker, args=(query_queue,), daemon=True) + self.prefetch_thread.start() + + def get(self, query: str, timeout: float = 1.0) -> Optional[List[Dict]]: + """ + Get retrieval results from prefetch queue. + + Args: + query: Query string + timeout: Timeout for checking prefetch queue + + Returns: + Retrieved documents or None if not found + """ + # Check prefetch queue + while not self.prefetch_queue.empty(): + try: + cached_query, results = self.prefetch_queue.get(timeout=timeout) + if cached_query == query: + return results + # Put back if not matching + self.prefetch_queue.put((cached_query, results)) + except: + break + + return None + + def stop(self): + """Stop prefetching thread.""" + self._stop_thread = True + if self.prefetch_thread is not None: + self.prefetch_thread.join() + diff --git a/models/transformer.py b/models/transformer.py new file mode 100644 index 0000000..0cd3690 --- /dev/null +++ b/models/transformer.py @@ -0,0 +1,268 @@ +""" +Complete Transformer model for language modeling +Incorporates best practices from multiple research papers +Optimized for production RAG systems with KV caching and efficient inference +""" +import torch +import torch.nn as nn +from typing import Optional, Tuple +from .blocks import TransformerBlock +from .attention import PositionalEncoding +from .optimized_attention import OptimizedMultiHeadAttention, RetrievalCache, OptimizedInference + + +class TransformerModel(nn.Module): + """ + Full Transformer Language Model. + + Features: + - Multi-head self-attention + - Positional encoding + - Layer normalization + - Residual connections + - Causal masking for autoregressive generation + """ + + def __init__( + self, + vocab_size: int, + d_model: int = 512, + num_layers: int = 6, + num_heads: int = 8, + d_ff: int = 2048, + max_seq_len: int = 512, + dropout: float = 0.1, + activation: str = 'gelu', + layer_norm_eps: float = 1e-5, + bias: bool = False, + tie_weights: bool = True, + ): + """ + Args: + vocab_size: Vocabulary size + d_model: Model dimension + num_layers: Number of transformer layers + num_heads: Number of attention heads + d_ff: Feed-forward dimension + max_seq_len: Maximum sequence length + dropout: Dropout probability + activation: Activation function ('gelu' or 'relu') + layer_norm_eps: Epsilon for layer normalization + bias: Whether to use bias in linear layers + tie_weights: Whether to tie input and output embeddings + """ + super().__init__() + + self.vocab_size = vocab_size + self.d_model = d_model + self.num_layers = num_layers + self.num_heads = num_heads + self.max_seq_len = max_seq_len + + # Token embeddings + self.token_embedding = nn.Embedding(vocab_size, d_model) + + # Positional encoding + self.pos_encoding = PositionalEncoding( + d_model=d_model, + max_len=max_seq_len, + dropout=dropout, + ) + + # Transformer blocks (use optimized attention if available) + # Note: Set use_optimized_attention=True for production inference + self.layers = nn.ModuleList([ + TransformerBlock( + d_model=d_model, + num_heads=num_heads, + d_ff=d_ff, + dropout=dropout, + activation=activation, + layer_norm_eps=layer_norm_eps, + bias=bias, + causal=True, # Causal masking for autoregressive generation + use_optimized_attention=False, # Set to True for inference optimizations + ) + for _ in range(num_layers) + ]) + + # Final layer norm + self.final_norm = nn.LayerNorm(d_model, eps=layer_norm_eps) + + # Output projection + self.output_proj = nn.Linear(d_model, vocab_size, bias=bias) + + # Optionally tie weights + if tie_weights: + self.output_proj.weight = self.token_embedding.weight + + self.dropout = nn.Dropout(dropout) + + # Retrieval cache for RAG systems + self.retrieval_cache = RetrievalCache(max_size=1000, similarity_threshold=0.9) + + # Initialize weights + self._init_weights() + + def _init_weights(self): + """Initialize weights following best practices.""" + # Initialize embeddings + nn.init.normal_(self.token_embedding.weight, mean=0.0, std=0.02) + + # Initialize output projection + if self.output_proj.weight is not self.token_embedding.weight: + nn.init.normal_(self.output_proj.weight, mean=0.0, std=0.02) + + # Initialize linear layers + for module in self.modules(): + if isinstance(module, nn.Linear): + nn.init.normal_(module.weight, mean=0.0, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + elif isinstance(module, nn.LayerNorm): + nn.init.ones_(module.weight) + nn.init.zeros_(module.bias) + + def forward( + self, + input_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Forward pass through the transformer model. + + Args: + input_ids: Token indices [batch_size, seq_len] + attention_mask: Optional attention mask [batch_size, seq_len] + + Returns: + logits: Output logits [batch_size, seq_len, vocab_size] + attention_weights: Optional attention weights + """ + batch_size, seq_len = input_ids.shape + + # Token embeddings + x = self.token_embedding(input_ids) # [batch_size, seq_len, d_model] + + # Add positional encoding + x = self.pos_encoding(x) + x = self.dropout(x) + + # Create attention mask if not provided + if attention_mask is None: + attention_mask = torch.ones( + batch_size, seq_len, device=input_ids.device, dtype=torch.bool + ) + + # Expand mask for attention + # [batch_size, seq_len] -> [batch_size, seq_len, seq_len] + if attention_mask.dim() == 2: + attention_mask = attention_mask.unsqueeze(1).expand(batch_size, seq_len, seq_len) + + # Apply attention mask (invert: 1 for valid, 0 for masked) + attention_mask = attention_mask.float() + + # Pass through transformer blocks + for layer in self.layers: + x = layer(x, mask=attention_mask) + + # Final layer norm + x = self.final_norm(x) + + # Output projection + logits = self.output_proj(x) # [batch_size, seq_len, vocab_size] + + return logits, None + + def generate( + self, + input_ids: torch.Tensor, + max_length: int = 100, + temperature: float = 1.0, + top_k: Optional[int] = None, + top_p: float = 1.0, + do_sample: bool = True, + pad_token_id: Optional[int] = None, + ) -> torch.Tensor: + """ + Autoregressive generation. + + Args: + input_ids: Starting token indices [batch_size, seq_len] + max_length: Maximum generation length + temperature: Sampling temperature + top_k: Top-k sampling parameter + top_p: Nucleus sampling parameter + do_sample: Whether to sample or use greedy decoding + pad_token_id: Padding token ID + + Returns: + Generated token sequences + """ + self.eval() + device = input_ids.device + batch_size = input_ids.shape[0] + + generated = input_ids.clone() + + with torch.no_grad(): + for _ in range(max_length - input_ids.shape[1]): + # Forward pass + logits, _ = self.forward(generated) + next_token_logits = logits[:, -1, :] / temperature + + # Apply top-k filtering + if top_k is not None and top_k > 0: + indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None] + next_token_logits[indices_to_remove] = float('-inf') + + # Apply top-p (nucleus) filtering + if top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True) + cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above threshold + sorted_indices_to_remove = cumulative_probs > top_p + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + next_token_logits[indices_to_remove] = float('-inf') + + # Sample or take argmax + if do_sample: + probs = torch.softmax(next_token_logits, dim=-1) + next_token = torch.multinomial(probs, num_samples=1) + else: + next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True) + + # Early stopping: stop if EOS or padding token is generated (for batch_size=1) + if batch_size == 1: + eos_token_id = getattr(self, 'eos_token_id', None) or 3 # Default EOS token + if next_token.item() == eos_token_id: + break + + # Early stopping: stop if padding token is generated (prevent generating padding) + if pad_token_id is not None and next_token.item() == pad_token_id: + break + + # Append to generated sequence + generated = torch.cat([generated, next_token], dim=1) + + return generated + + def get_optimized_inference(self) -> OptimizedInference: + """ + Get optimized inference utility with KV caching and batching. + + Returns: + OptimizedInference instance + """ + return OptimizedInference(self, next(self.parameters()).device) + + def get_num_params(self) -> int: + """Return the number of trainable parameters.""" + return sum(p.numel() for p in self.parameters() if p.requires_grad) + + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..5080c1c --- /dev/null +++ b/requirements.txt @@ -0,0 +1,28 @@ +# IMPORTANT: On modern Debian/Ubuntu systems (Python 3.12+), you MUST use a virtual environment +# before installing these packages. Run: python3 -m venv venv && source venv/bin/activate +# Or use the automated setup script: ./setup.sh + +torch>=2.0.0 +transformers>=4.30.0 +numpy>=1.24.0 +tqdm>=4.65.0 +tensorboard>=2.13.0 +matplotlib>=3.7.0 + +# Optional dependencies for data processing +# Install these if you want to process PDFs or images: +# For PDF processing (choose one - pdfplumber is recommended for better quality): +pdfplumber>=0.9.0 # Recommended: better text extraction quality +# PyPDF2>=3.0.0 # Alternative PDF library (lighter weight but less accurate) + +# For image OCR (requires Tesseract OCR engine installed on system): +# pytesseract>=0.3.10 # For OCR +# Pillow>=10.0.0 # Required for image processing with pytesseract +# +# To install Tesseract OCR engine: +# Ubuntu/Debian: sudo apt-get install tesseract-ocr +# macOS: brew install tesseract +# Windows: Download from https://github.com/UB-Mannheim/tesseract/wiki + +# For downloading datasets from Hugging Face (used by download_large_data.py): +datasets>=2.14.0 # Optional: for downloading WikiText, OpenWebText, BookCorpus, etc. diff --git a/setup_storage.py b/setup_storage.py new file mode 100755 index 0000000..be46f88 --- /dev/null +++ b/setup_storage.py @@ -0,0 +1,291 @@ +#!/usr/bin/env python3 +""" +Move large files to external storage and create symbolic links. +Helps manage large datasets and checkpoints on systems with limited space. +""" +import os +import shutil +import subprocess +from pathlib import Path +import argparse + + +def create_storage_structure(storage_root: str): + """Create directory structure in storage location.""" + storage_path = Path(storage_root) + storage_path.mkdir(parents=True, exist_ok=True) + + # Create subdirectories + (storage_path / "data").mkdir(exist_ok=True) + (storage_path / "checkpoints").mkdir(exist_ok=True) + (storage_path / "checkpoints_test").mkdir(exist_ok=True) + + print(f"āœ… Created storage structure at: {storage_path}") + return storage_path + + +def move_and_link(source_dir: Path, target_dir: Path, link_name: str, dry_run: bool = False): + """ + Move directory contents to storage and create symbolic link. + + Args: + source_dir: Source directory in project + target_dir: Target directory in storage + link_name: Name for the symbolic link (same as source_dir name) + dry_run: If True, only show what would be done + """ + source_dir = Path(source_dir) + target_dir = Path(target_dir) + + if not source_dir.exists(): + print(f"āš ļø Source directory doesn't exist: {source_dir}") + return False + + if dry_run: + print(f"\n[DRY RUN] Would move contents from {source_dir} to {target_dir}") + print(f" Would replace {source_dir} with symlink -> {target_dir}") + return True + + # Move files (skip Python module files) + moved_count = 0 + temp_backup = source_dir.parent / f".{source_dir.name}_backup" + + # First, backup Python files + python_files = [] + for item in source_dir.iterdir(): + if item.suffix == '.py' or item.name.startswith('__'): + python_files.append(item) + + # Move non-Python files to storage + for item in source_dir.iterdir(): + # Skip Python files and hidden files + if item.name.startswith('__') or item.suffix == '.py' or item.name.startswith('.'): + continue + + target_item = target_dir / item.name + if target_item.exists(): + print(f"āš ļø Skipping {item.name} (already exists in storage)") + continue + + print(f"šŸ“¦ Moving {item.name}...") + try: + if item.is_dir(): + shutil.copytree(item, target_item) + shutil.rmtree(item) + else: + shutil.copy2(item, target_item) + item.unlink() + moved_count += 1 + except Exception as e: + print(f"āŒ Error moving {item.name}: {e}") + return False + + # Copy Python files to storage (keep structure) + for item in python_files: + target_item = target_dir / item.name + if not target_item.exists(): + shutil.copy2(item, target_item) + + # Replace source directory with symlink + # Step 1: Remove original directory + try: + shutil.rmtree(source_dir) + except Exception as e: + print(f"āš ļø Could not remove {source_dir}: {e}") + return False + + # Step 2: Create symlink + try: + source_dir.symlink_to(target_dir) + print(f"āœ… Created symbolic link: {source_dir} -> {target_dir}") + return True + except Exception as e: + print(f"āŒ Error creating link: {e}") + return False + + +def main(): + parser = argparse.ArgumentParser( + description='Move large files to external storage and create symbolic links', + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Dry run (show what would be done) + python3 setup_storage.py --storage /mnt/storage/sheepOp --dry-run + + # Move data and checkpoints to storage + python3 setup_storage.py --storage /mnt/storage/sheepOp + + # Only move data, not checkpoints + python3 setup_storage.py --storage /mnt/storage/sheepOp --skip-checkpoints + """ + ) + + parser.add_argument( + '--storage', + type=str, + default='/mnt/storage/sheepOp', + help='Storage root directory (default: /mnt/storage/sheepOp)' + ) + + parser.add_argument( + '--project-root', + type=str, + default='.', + help='Project root directory (default: current directory)' + ) + + parser.add_argument( + '--dry-run', + action='store_true', + help='Show what would be done without actually doing it' + ) + + parser.add_argument( + '--skip-checkpoints', + action='store_true', + help='Skip moving checkpoints (only move data)' + ) + + parser.add_argument( + '--skip-data', + action='store_true', + help='Skip moving data (only move checkpoints)' + ) + + args = parser.parse_args() + + project_root = Path(args.project_root).resolve() + storage_root = Path(args.storage).resolve() + + print(f"šŸš€ Setting up storage links") + print(f" Project root: {project_root}") + print(f" Storage root: {storage_root}") + print(f" Dry run: {args.dry_run}\n") + + # Create storage structure (always create, even in dry-run, to check permissions) + try: + create_storage_structure(args.storage) + except Exception as e: + if args.dry_run: + print(f"āš ļø Could not create storage structure (will be created during actual run): {e}") + else: + print(f"āŒ Error creating storage structure: {e}") + return 1 + + storage_data = storage_root / "data" + storage_checkpoints = storage_root / "checkpoints" + storage_checkpoints_test = storage_root / "checkpoints_test" + + project_data = project_root / "data" + project_checkpoints = project_root / "checkpoints" + project_checkpoints_test = project_root / "checkpoints_test" + + success = True + + # Move data + if not args.skip_data: + print(f"\nšŸ“ Processing data directory...") + if project_data.exists(): + if project_data.is_symlink(): + print(f" ā„¹ļø data/ is already a symlink: {project_data.readlink()}") + else: + print(" Moving data files to storage (keeping __init__.py)...") + # Copy __init__.py to storage first + init_file = project_data / "__init__.py" + if init_file.exists(): + storage_init = storage_data / "__init__.py" + if not storage_init.exists(): + if args.dry_run: + print(f" [DRY RUN] Would copy: __init__.py -> {storage_init}") + else: + # Ensure storage directory exists + storage_data.mkdir(parents=True, exist_ok=True) + shutil.copy2(init_file, storage_init) + print(" āœ… Copied __init__.py to storage") + else: + print(" ā„¹ļø __init__.py already exists in storage") + + # Move all other files + moved_files = [] + for item in project_data.iterdir(): + if item.name == '__init__.py' or item.name.startswith('__'): + continue + + target_item = storage_data / item.name + if args.dry_run: + print(f" [DRY RUN] Would move: {item.name} -> {target_item}") + moved_files.append(item.name) + else: + if not target_item.exists(): + if item.is_dir(): + shutil.copytree(item, target_item) + shutil.rmtree(item) + else: + shutil.copy2(item, target_item) + item.unlink() + moved_files.append(item.name) + print(f" āœ… Moved: {item.name}") + else: + print(f" āš ļø Already exists: {item.name}") + + # Replace data/ with symlink + if not args.dry_run and moved_files: + # Remove original directory + init_backup = project_data / "__init__.py" + if init_backup.exists(): + # Keep a reference + pass + shutil.rmtree(project_data) + + # Create symlink + project_data.symlink_to(storage_data) + print(f" āœ… Replaced data/ with symlink -> {storage_data}") + else: + print(" ā„¹ļø data/ directory doesn't exist, creating symlink...") + if not args.dry_run: + project_data.symlink_to(storage_data) + print(f" āœ… Created data/ symlink -> {storage_data}") + + # Move checkpoints + if not args.skip_checkpoints: + print(f"\nšŸ’¾ Processing checkpoints...") + + if project_checkpoints.exists(): + print(" Moving checkpoints to storage...") + success = move_and_link( + project_checkpoints, + storage_checkpoints, + "checkpoints", + args.dry_run + ) and success + + if project_checkpoints_test.exists(): + print(" Moving checkpoints_test to storage...") + success = move_and_link( + project_checkpoints_test, + storage_checkpoints_test, + "checkpoints_test", + args.dry_run + ) and success + + if args.dry_run: + print(f"\nāœ… Dry run complete. Use without --dry-run to execute.") + else: + if success: + print(f"\nāœ… Storage setup complete!") + print(f"\nšŸ“‹ Next steps:") + print(f" 1. Your data files are now in: {storage_root}/data/") + print(f" 2. Your checkpoints will be saved to: {storage_root}/checkpoints/") + print(f" 3. Links are created in your project directory") + print(f" 4. Training will automatically use the storage location") + else: + print(f"\nāŒ Some operations failed. Please check the errors above.") + return 1 + + return 0 + + +if __name__ == '__main__': + exit(main()) + diff --git a/train.py b/train.py new file mode 100644 index 0000000..b3c6fac --- /dev/null +++ b/train.py @@ -0,0 +1,292 @@ +""" +Main training script +""" +import torch +import argparse +from pathlib import Path +import sys +import os +import importlib.util + +# Ensure current directory is in path +project_root = Path(__file__).parent.absolute() +sys.path.insert(0, str(project_root)) + +# Explicitly import from local data module to avoid conflicts with stdlib 'data' module +# Python 3.12 has a standard library 'data' module that conflicts with our local data/ +data_module_path = project_root / "data" / "__init__.py" +if not data_module_path.exists(): + # Try alternative paths + alt_paths = [ + project_root / "data" / "__init__.py", + Path("data") / "__init__.py", + Path.cwd() / "data" / "__init__.py", + ] + + found = False + for alt_path in alt_paths: + if alt_path.exists(): + data_module_path = alt_path + found = True + break + + if not found: + error_msg = f"Could not find data module!\n" + error_msg += f" Searched:\n" + error_msg += f" - {project_root / 'data' / '__init__.py'}\n" + error_msg += f" - {Path('data') / '__init__.py'}\n" + error_msg += f" - {Path.cwd() / 'data' / '__init__.py'}\n" + error_msg += f" Current directory: {Path.cwd()}\n" + error_msg += f" Project root: {project_root}\n" + error_msg += f" Does data/ directory exist? {Path(project_root / 'data').exists()}\n" + error_msg += f"\n Please ensure you're running from the project root directory.\n" + error_msg += f" Try: cd && python3 train.py ..." + raise ImportError(error_msg) + +spec = importlib.util.spec_from_file_location("sheepop_data", data_module_path) +sheepop_data = importlib.util.module_from_spec(spec) +spec.loader.exec_module(sheepop_data) + +# Import from the explicitly loaded module +SimpleTokenizer = sheepop_data.SimpleTokenizer +create_dataloader = sheepop_data.create_dataloader +DataProcessor = sheepop_data.DataProcessor +extract_text_from_directory = sheepop_data.extract_text_from_directory + +from models import TransformerModel +from training import Trainer +from config import Config, get_default_config +from dataclasses import asdict + + +def set_seed(seed: int): + """Set random seed for reproducibility.""" + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def main(): + parser = argparse.ArgumentParser(description='Train SheepOp LLM') + parser.add_argument('--config', type=str, help='Path to config file') + parser.add_argument('--data', type=str, required=True, help='Path to training data') + parser.add_argument('--output', type=str, default='./checkpoints', help='Output directory') + parser.add_argument('--resume', type=str, help='Path to checkpoint to resume from') + + # Auto-detect best device + if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + default_device = 'mps' + elif torch.cuda.is_available(): + default_device = 'cuda' + else: + default_device = 'cpu' + + parser.add_argument('--device', type=str, default=default_device, + help=f'Device to use (default: {default_device})') + + args = parser.parse_args() + + # Load configuration + if args.config: + config = Config.from_json(args.config) + else: + config = get_default_config() + + config.device = args.device + config.training.save_dir = args.output + + # Set seed + set_seed(config.seed) + + # Setup device with smart detection + if config.device == 'cuda' and torch.cuda.is_available(): + device = torch.device('cuda') + elif config.device == 'mps' and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + device = torch.device('mps') + elif config.device == 'auto': + # Auto-detect best available device + if torch.cuda.is_available(): + device = torch.device('cuda') + elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + device = torch.device('mps') + print("Auto-detected MPS (Apple Silicon GPU)") + else: + device = torch.device('cpu') + print("Auto-detected CPU") + else: + device = torch.device(config.device) + + print(f"Using device: {device}") + + # Load data - supports multiple file types (PDFs, images, code, text, etc.) + data_path = Path(args.data) + texts = [] + + if data_path.is_file(): + # Single file - try to process it + print(f"Processing single file: {data_path}") + processor = DataProcessor() + texts = list(processor.process_file(data_path)) + elif data_path.is_dir(): + # Directory - process all supported file types + print(f"Processing directory: {data_path}") + print("Supported file types:") + print(" - Text files: .txt, .md, .rst, .log, .csv, .json, .jsonl, .xml, .html, .htm") + print(" - Code files: .py, .js, .ts, .java, .cpp, .c, .go, .rs, .rb, .php, .swift, etc.") + print(" - PDF files: .pdf (requires PyPDF2 or pdfplumber)") + print(" - Images: .png, .jpg, .jpeg, .gif, .bmp, .tiff (requires pytesseract for OCR)") + print() + + # Process directory with all file types + try: + texts = extract_text_from_directory( + directory=data_path, + recursive=True, + use_ocr=True, # Enable OCR for images + use_pdf_extraction=True, # Enable PDF extraction + min_length=10, # Minimum length for text lines + ) + except KeyboardInterrupt: + print("\n\nāš ļø Data processing interrupted by user (Ctrl+C).") + print(" Note: No checkpoint is saved because training hasn't started yet.") + print(" Checkpoints are only saved during training, not during data extraction.") + print(" Please run the training command again to retry.") + raise + else: + raise ValueError(f"Data path {args.data} does not exist") + + if not texts: + raise ValueError(f"No text data extracted from {args.data}. Please check that the directory contains supported file types.") + + print(f"\nāœ… Successfully loaded {len(texts):,} text samples from {data_path}") + print(f" Sample preview (first 3 lines):") + for i, text in enumerate(texts[:3]): + preview = text[:80] + "..." if len(text) > 80 else text + print(f" {i+1}. {preview}") + + # Create tokenizer + tokenizer = SimpleTokenizer() + print(f"Vocabulary size: {tokenizer.vocab_size}") + + # Create data loaders + train_loader = create_dataloader( + texts=texts, + tokenizer=tokenizer, + batch_size=config.training.batch_size, + max_length=config.data.max_length, + shuffle=True, + num_workers=config.data.num_workers, + ) + + # Create model + model_config = config.model + model_config.vocab_size = tokenizer.vocab_size + + # Resume from checkpoint if provided + start_epoch = 0 + checkpoint = None + if args.resume: + checkpoint_path = Path(args.resume) + if not checkpoint_path.exists(): + print(f"āš ļø Warning: Checkpoint file '{args.resume}' not found!") + print(f" Starting fresh training instead...") + args.resume = None # Disable resume flag + else: + print(f"Resuming from checkpoint: {args.resume}") + checkpoint = torch.load(args.resume, map_location=device) + + # Load model config from checkpoint if available + if 'model_config' in checkpoint: + checkpoint_config = checkpoint['model_config'] + model_config.vocab_size = checkpoint_config.get('vocab_size', model_config.vocab_size) + print(f"Loaded model config from checkpoint") + + model = TransformerModel(**asdict(model_config)) + model.load_state_dict(checkpoint['model_state_dict']) + start_epoch = checkpoint.get('epoch', 0) + 1 # Start from next epoch + print(f"Resuming from epoch {start_epoch}") + + if not args.resume: + model = TransformerModel(**asdict(model_config)) + + print(f"Model created with {model.get_num_params():,} parameters") + + # Setup optimizer + optimizer = torch.optim.AdamW( + model.parameters(), + lr=config.training.learning_rate, + weight_decay=config.training.weight_decay, + betas=(0.9, 0.999), + ) + + # Load optimizer state if resuming + if args.resume: + if 'optimizer_state_dict' in checkpoint: + # Move optimizer state to correct device + optimizer_state = checkpoint['optimizer_state_dict'] + for state in optimizer_state['state'].values(): + for k, v in state.items(): + if isinstance(v, torch.Tensor): + state[k] = v.to(device) + optimizer.load_state_dict(optimizer_state) + print("Loaded optimizer state from checkpoint") + + # Setup scheduler + total_steps = len(train_loader) * config.training.max_epochs + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, + T_max=total_steps, + ) + + # Load scheduler state if resuming + if args.resume: + if 'scheduler_state_dict' in checkpoint and scheduler is not None: + # Scheduler state usually doesn't need device transfer, but let's be safe + scheduler_state = checkpoint['scheduler_state_dict'] + scheduler.load_state_dict(scheduler_state) + print("Loaded scheduler state from checkpoint") + + # Create trainer + trainer = Trainer( + model=model, + train_loader=train_loader, + val_loader=None, # Can add validation loader + optimizer=optimizer, + scheduler=scheduler, + device=device, + max_epochs=config.training.max_epochs, + gradient_accumulation_steps=config.training.gradient_accumulation_steps, + max_grad_norm=config.training.max_grad_norm, + use_amp=config.training.use_amp, + save_dir=config.training.save_dir, + log_interval=config.training.log_interval, + eval_interval=config.training.eval_interval, + ) + + # Set trainer state if resuming + if args.resume: + trainer.current_epoch = start_epoch - 1 + trainer.global_step = checkpoint.get('global_step', 0) + trainer.best_val_loss = checkpoint.get('best_val_loss', float('inf')) + print(f"Resuming from global step {trainer.global_step}") + + # Store model config for checkpoint saving + model_config_dict = asdict(model_config) + + # Override save_checkpoint to include model config + original_save_checkpoint = trainer.save_checkpoint + def save_checkpoint_with_config(is_best=False): + original_save_checkpoint(is_best=is_best, model_config=model_config_dict) + trainer.save_checkpoint = save_checkpoint_with_config + + # Train + trainer.train() + + print("Training completed!") + + +if __name__ == '__main__': + from dataclasses import asdict + main() + diff --git a/training/__init__.py b/training/__init__.py new file mode 100644 index 0000000..3213423 --- /dev/null +++ b/training/__init__.py @@ -0,0 +1,405 @@ +""" +Training utilities and training loop +""" +import torch +import torch.nn as nn +from torch.optim import AdamW +from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR +from typing import Dict, Optional, Callable +from pathlib import Path +import json +import sys +from tqdm import tqdm +import math +from .metrics import TrainingMetrics + + +class Trainer: + """ + Trainer class for language model training. + Includes gradient accumulation, mixed precision training, and checkpointing. + """ + + def __init__( + self, + model: nn.Module, + train_loader, + val_loader=None, + optimizer=None, + scheduler=None, + device: str = 'cuda', + max_epochs: int = 10, + gradient_accumulation_steps: int = 1, + max_grad_norm: float = 1.0, + use_amp: bool = True, + save_dir: str = './checkpoints', + log_interval: int = 100, + eval_interval: int = 1000, + ): + """ + Args: + model: Model to train + train_loader: Training data loader + val_loader: Validation data loader (optional) + optimizer: Optimizer (if None, AdamW is used) + scheduler: Learning rate scheduler (optional) + device: Device to train on + max_epochs: Maximum number of epochs + gradient_accumulation_steps: Gradient accumulation steps + max_grad_norm: Maximum gradient norm for clipping + use_amp: Whether to use mixed precision training + save_dir: Directory to save checkpoints + log_interval: Logging interval + eval_interval: Evaluation interval + """ + # Convert device string to torch.device if needed + if isinstance(device, str): + self.device = torch.device(device) + else: + self.device = device + + self.model = model.to(self.device) + self.train_loader = train_loader + self.val_loader = val_loader + self.max_epochs = max_epochs + self.gradient_accumulation_steps = gradient_accumulation_steps + self.max_grad_norm = max_grad_norm + self.save_dir = Path(save_dir) + self.save_dir.mkdir(parents=True, exist_ok=True) + self.log_interval = log_interval + self.eval_interval = eval_interval + + # Setup optimizer + if optimizer is None: + self.optimizer = AdamW( + self.model.parameters(), + lr=1e-4, + betas=(0.9, 0.999), + weight_decay=0.01, + ) + else: + self.optimizer = optimizer + + # Setup scheduler + self.scheduler = scheduler + + # Determine device type for AMP + # Convert device string to torch.device if needed + if isinstance(device, str): + self.device = torch.device(device) + else: + self.device = device + + device_type = self.device.type + + # Setup mixed precision training (only for CUDA) + self.device_type = device_type + self.use_amp = use_amp and device_type == 'cuda' # Only use AMP for CUDA + + if self.use_amp: + # Use new device-agnostic API + self.scaler = torch.amp.GradScaler('cuda') + self.autocast_dtype = torch.float16 + else: + self.scaler = None + self.autocast_dtype = None + + # Loss function + self.criterion = nn.CrossEntropyLoss(ignore_index=-100) + + # Training state + self.current_epoch = 0 + self.global_step = 0 + self.best_val_loss = float('inf') + + # Training metrics tracking + self.metrics = TrainingMetrics(save_dir=save_dir) + + def train_epoch(self) -> Dict[str, float]: + """Train for one epoch.""" + self.model.train() + total_loss = 0.0 + num_batches = 0 + + progress_bar = tqdm( + self.train_loader, + desc=f"Epoch {self.current_epoch + 1}", + mininterval=0.1, + maxinterval=1.0, + file=sys.stderr, # Write to stderr to avoid buffering issues + dynamic_ncols=True, # Auto-adjust to terminal width + disable=False, # Explicitly enable + ) + + for batch_idx, batch in enumerate(progress_bar): + input_ids = batch['input_ids'].to(self.device) + labels = batch['labels'].to(self.device) + + # Forward pass with mixed precision (only for CUDA) + if self.use_amp: + with torch.amp.autocast('cuda', dtype=self.autocast_dtype): + logits, _ = self.model(input_ids) + + # Reshape for loss computation + logits = logits.view(-1, logits.size(-1)) + labels = labels.view(-1) + + loss = self.criterion(logits, labels) + loss = loss / self.gradient_accumulation_steps + else: + logits, _ = self.model(input_ids) + + # Reshape for loss computation + logits = logits.view(-1, logits.size(-1)) + labels = labels.view(-1) + + loss = self.criterion(logits, labels) + loss = loss / self.gradient_accumulation_steps + + # Backward pass + if self.use_amp: + self.scaler.scale(loss).backward() + else: + loss.backward() + + # Gradient accumulation + if (batch_idx + 1) % self.gradient_accumulation_steps == 0: + # Gradient clipping + if self.use_amp: + self.scaler.unscale_(self.optimizer) + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), + self.max_grad_norm + ) + self.scaler.step(self.optimizer) + self.scaler.update() + else: + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), + self.max_grad_norm + ) + self.optimizer.step() + + if self.scheduler is not None: + self.scheduler.step() + + self.optimizer.zero_grad() + self.global_step += 1 + + total_loss += loss.item() * self.gradient_accumulation_steps + num_batches += 1 + + # Logging + if self.global_step % self.log_interval == 0: + avg_loss = total_loss / num_batches + lr = self.optimizer.param_groups[0]['lr'] + progress_bar.set_postfix({ + 'loss': f'{avg_loss:.4f}', + 'lr': f'{lr:.2e}', + }) + progress_bar.refresh() # Force immediate refresh + sys.stderr.flush() # Force flush stderr to ensure progress bar displays + + # Log metrics + self.metrics.log( + epoch=self.current_epoch, + step=self.global_step, + train_loss=avg_loss, + lr=lr, + ) + + # Evaluation + if self.val_loader is not None and self.global_step % self.eval_interval == 0: + val_loss = self.evaluate() + if val_loss < self.best_val_loss: + self.best_val_loss = val_loss + self.save_checkpoint(is_best=True) + + avg_loss = total_loss / num_batches + + # Log epoch metrics + self.metrics.log( + epoch=self.current_epoch, + step=self.global_step, + train_loss=avg_loss, + lr=self.optimizer.param_groups[0]['lr'], + ) + + return {'loss': avg_loss} + + @torch.no_grad() + def evaluate(self) -> float: + """Evaluate on validation set.""" + if self.val_loader is None: + return float('inf') + + self.model.eval() + total_loss = 0.0 + num_batches = 0 + + for batch in tqdm( + self.val_loader, + desc="Evaluating", + mininterval=0.1, + maxinterval=1.0, + file=sys.stderr, # Write to stderr to avoid buffering issues + dynamic_ncols=True, # Auto-adjust to terminal width + disable=False, # Explicitly enable + ): + input_ids = batch['input_ids'].to(self.device) + labels = batch['labels'].to(self.device) + + if self.use_amp: + with torch.amp.autocast('cuda', dtype=self.autocast_dtype): + logits, _ = self.model(input_ids) + logits = logits.view(-1, logits.size(-1)) + labels = labels.view(-1) + loss = self.criterion(logits, labels) + else: + logits, _ = self.model(input_ids) + logits = logits.view(-1, logits.size(-1)) + labels = labels.view(-1) + loss = self.criterion(logits, labels) + + total_loss += loss.item() + num_batches += 1 + + avg_loss = total_loss / num_batches + return avg_loss + + def train(self): + """Main training loop.""" + try: + for epoch in range(self.current_epoch, self.max_epochs): + self.current_epoch = epoch + + # Train epoch + train_metrics = self.train_epoch() + + # Evaluation at end of epoch + if self.val_loader is not None: + val_loss = self.evaluate() + print(f"Epoch {epoch + 1}: Train Loss = {train_metrics['loss']:.4f}, " + f"Val Loss = {val_loss:.4f}") + else: + print(f"Epoch {epoch + 1}: Train Loss = {train_metrics['loss']:.4f}") + + # Save checkpoint + self.save_checkpoint() + + # Generate plots at end of training + print("\nšŸ“Š Generating training plots...") + try: + self.metrics.plot_training_curve() + self.metrics.plot_loss_by_epoch() + self.metrics.print_summary() + except Exception as e: + print(f"Warning: Could not generate plots: {e}") + + except KeyboardInterrupt: + print("\n\nāš ļø Training interrupted by user!") + print(f"šŸ’¾ Saving checkpoint at epoch {self.current_epoch + 1}...") + self.save_checkpoint() + print(f"āœ… Checkpoint saved! You can resume with:") + print(f" python3 train.py --data --resume {self.save_dir}/checkpoint_epoch_{self.current_epoch}.pt") + + # Generate plots before exiting + print("\nšŸ“Š Generating training plots...") + try: + self.metrics.plot_training_curve() + self.metrics.plot_loss_by_epoch() + self.metrics.print_summary() + except Exception as e: + print(f"Warning: Could not generate plots: {e}") + + raise + + def save_checkpoint(self, is_best: bool = False, model_config: dict = None): + """Save model checkpoint.""" + checkpoint = { + 'epoch': self.current_epoch, + 'global_step': self.global_step, + 'model_state_dict': self.model.state_dict(), + 'optimizer_state_dict': self.optimizer.state_dict(), + 'best_val_loss': self.best_val_loss, + } + + # Save model config if provided + if model_config is not None: + checkpoint['model_config'] = model_config + + if self.scheduler is not None: + checkpoint['scheduler_state_dict'] = self.scheduler.state_dict() + + # Save regular checkpoint + checkpoint_path = self.save_dir / f'checkpoint_epoch_{self.current_epoch}.pt' + torch.save(checkpoint, checkpoint_path) + + # Save best checkpoint + if is_best: + best_path = self.save_dir / 'best_checkpoint.pt' + torch.save(checkpoint, best_path) + print(f"Saved best checkpoint with val_loss = {self.best_val_loss:.4f}") + + def load_checkpoint(self, checkpoint_path: str): + """Load model checkpoint.""" + checkpoint = torch.load(checkpoint_path, map_location=self.device) + + self.model.load_state_dict(checkpoint['model_state_dict']) + self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + + if 'scheduler_state_dict' in checkpoint and self.scheduler is not None: + self.scheduler.load_state_dict(checkpoint['scheduler_state_dict']) + + self.current_epoch = checkpoint.get('epoch', 0) + self.global_step = checkpoint.get('global_step', 0) + self.best_val_loss = checkpoint.get('best_val_loss', float('inf')) + + print(f"Loaded checkpoint from epoch {self.current_epoch}") + + +def compute_perplexity(model: nn.Module, data_loader, device: str = 'cuda') -> float: + """ + Compute perplexity on a dataset. + + Args: + model: Model to evaluate + data_loader: Data loader + device: Device to use + + Returns: + Perplexity score + """ + model.eval() + total_loss = 0.0 + num_tokens = 0 + + criterion = nn.CrossEntropyLoss(ignore_index=-100, reduction='sum') + + with torch.no_grad(): + for batch in tqdm( + data_loader, + desc="Computing perplexity", + mininterval=0.1, + maxinterval=1.0, + file=sys.stderr, # Write to stderr to avoid buffering issues + dynamic_ncols=True, # Auto-adjust to terminal width + disable=False, # Explicitly enable + ): + input_ids = batch['input_ids'].to(device) + labels = batch['labels'].to(device) + + logits, _ = model(input_ids) + logits = logits.view(-1, logits.size(-1)) + labels = labels.view(-1) + + loss = criterion(logits, labels) + total_loss += loss.item() + num_tokens += (labels != -100).sum().item() + + avg_loss = total_loss / num_tokens + perplexity = math.exp(avg_loss) + + return perplexity + + diff --git a/training/metrics.py b/training/metrics.py new file mode 100644 index 0000000..8ca4399 --- /dev/null +++ b/training/metrics.py @@ -0,0 +1,190 @@ +""" +Training metrics tracking and plotting utilities +""" +import json +import matplotlib.pyplot as plt +from pathlib import Path +from typing import Dict, List, Optional +import numpy as np + + +class TrainingMetrics: + """ + Track and plot training metrics during training. + """ + + def __init__(self, save_dir: str = './checkpoints'): + """ + Args: + save_dir: Directory to save metrics and plots + """ + self.save_dir = Path(save_dir) + self.save_dir.mkdir(parents=True, exist_ok=True) + + self.metrics_file = self.save_dir / 'training_metrics.json' + + # Load existing metrics if available + if self.metrics_file.exists(): + with open(self.metrics_file, 'r') as f: + self.metrics = json.load(f) + else: + self.metrics = { + 'train_loss': [], + 'val_loss': [], + 'learning_rate': [], + 'epochs': [], + 'steps': [], + } + + def log(self, epoch: int, step: int, train_loss: float, + val_loss: Optional[float] = None, lr: Optional[float] = None): + """ + Log training metrics. + + Args: + epoch: Current epoch + step: Current global step + train_loss: Training loss + val_loss: Validation loss (optional) + lr: Learning rate (optional) + """ + self.metrics['train_loss'].append(train_loss) + self.metrics['epochs'].append(epoch) + self.metrics['steps'].append(step) + + if val_loss is not None: + self.metrics['val_loss'].append(val_loss) + else: + self.metrics['val_loss'].append(None) + + if lr is not None: + self.metrics['learning_rate'].append(lr) + else: + self.metrics['learning_rate'].append(None) + + # Save to file + self.save() + + def save(self): + """Save metrics to JSON file.""" + with open(self.metrics_file, 'w') as f: + json.dump(self.metrics, f, indent=2) + + def plot_training_curve(self, save_path: Optional[str] = None): + """ + Plot training and validation loss curves. + + Args: + save_path: Path to save plot (default: save_dir/training_curve.png) + """ + if save_path is None: + save_path = self.save_dir / 'training_curve.png' + + fig, axes = plt.subplots(2, 1, figsize=(12, 8)) + + # Plot 1: Loss curves + ax1 = axes[0] + steps = self.metrics['steps'] + train_loss = self.metrics['train_loss'] + val_loss = [v for v in self.metrics['val_loss'] if v is not None] + val_steps = [steps[i] for i, v in enumerate(self.metrics['val_loss']) if v is not None] + + ax1.plot(steps, train_loss, label='Train Loss', color='blue', alpha=0.7) + if val_loss: + ax1.plot(val_steps, val_loss, label='Val Loss', color='red', alpha=0.7) + + ax1.set_xlabel('Step') + ax1.set_ylabel('Loss') + ax1.set_title('Training and Validation Loss') + ax1.legend() + ax1.grid(True, alpha=0.3) + + # Plot 2: Learning rate + ax2 = axes[1] + lr = [v for v in self.metrics['learning_rate'] if v is not None] + lr_steps = [steps[i] for i, v in enumerate(self.metrics['learning_rate']) if v is not None] + + if lr: + ax2.plot(lr_steps, lr, label='Learning Rate', color='green', alpha=0.7) + ax2.set_xlabel('Step') + ax2.set_ylabel('Learning Rate') + ax2.set_title('Learning Rate Schedule') + ax2.legend() + ax2.grid(True, alpha=0.3) + ax2.set_yscale('log') + + plt.tight_layout() + plt.savefig(save_path, dpi=150, bbox_inches='tight') + print(f"šŸ“Š Training curve saved to: {save_path}") + plt.close() + + def plot_loss_by_epoch(self, save_path: Optional[str] = None): + """ + Plot loss averaged by epoch. + + Args: + save_path: Path to save plot (default: save_dir/loss_by_epoch.png) + """ + if save_path is None: + save_path = self.save_dir / 'loss_by_epoch.png' + + # Group losses by epoch + epochs = self.metrics['epochs'] + train_loss = self.metrics['train_loss'] + + epoch_losses = {} + for epoch, loss in zip(epochs, train_loss): + if epoch not in epoch_losses: + epoch_losses[epoch] = [] + epoch_losses[epoch].append(loss) + + # Average losses per epoch + epoch_nums = sorted(epoch_losses.keys()) + avg_losses = [np.mean(epoch_losses[e]) for e in epoch_nums] + + plt.figure(figsize=(10, 6)) + plt.plot(epoch_nums, avg_losses, marker='o', label='Average Train Loss', color='blue') + plt.xlabel('Epoch') + plt.ylabel('Loss') + plt.title('Training Loss by Epoch') + plt.legend() + plt.grid(True, alpha=0.3) + plt.savefig(save_path, dpi=150, bbox_inches='tight') + print(f"šŸ“Š Loss by epoch plot saved to: {save_path}") + plt.close() + + def get_summary(self) -> Dict: + """ + Get summary statistics of training. + + Returns: + Dictionary with summary statistics + """ + train_loss = self.metrics['train_loss'] + val_loss = [v for v in self.metrics['val_loss'] if v is not None] + + summary = { + 'total_steps': len(train_loss), + 'total_epochs': max(self.metrics['epochs']) + 1 if self.metrics['epochs'] else 0, + 'final_train_loss': train_loss[-1] if train_loss else None, + 'best_train_loss': min(train_loss) if train_loss else None, + 'final_val_loss': val_loss[-1] if val_loss else None, + 'best_val_loss': min(val_loss) if val_loss else None, + } + + return summary + + def print_summary(self): + """Print training summary.""" + summary = self.get_summary() + print("\n" + "=" * 60) + print("Training Summary") + print("=" * 60) + print(f"Total Steps: {summary['total_steps']}") + print(f"Total Epochs: {summary['total_epochs']}") + print(f"Final Train Loss: {summary['final_train_loss']:.4f}" if summary['final_train_loss'] else "Final Train Loss: N/A") + print(f"Best Train Loss: {summary['best_train_loss']:.4f}" if summary['best_train_loss'] else "Best Train Loss: N/A") + print(f"Final Val Loss: {summary['final_val_loss']:.4f}" if summary['final_val_loss'] else "Final Val Loss: N/A") + print(f"Best Val Loss: {summary['best_val_loss']:.4f}" if summary['best_val_loss'] else "Best Val Loss: N/A") + print("=" * 60) + diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..ce2e6d2 --- /dev/null +++ b/utils.py @@ -0,0 +1,109 @@ +""" +Utility functions for model evaluation and metrics +""" +import torch +import torch.nn as nn +from typing import List, Dict +import numpy as np +import sys +from tqdm import tqdm + + +def compute_accuracy(model: nn.Module, data_loader, device: str = 'cuda') -> float: + """ + Compute token-level accuracy. + + Args: + model: Model to evaluate + data_loader: Data loader + device: Device to use + + Returns: + Accuracy score + """ + model.eval() + correct = 0 + total = 0 + + with torch.no_grad(): + for batch in tqdm( + data_loader, + desc="Computing accuracy", + mininterval=0.1, + maxinterval=1.0, + file=sys.stderr, # Write to stderr to avoid buffering issues + dynamic_ncols=True, # Auto-adjust to terminal width + disable=False, # Explicitly enable + ): + input_ids = batch['input_ids'].to(device) + labels = batch['labels'].to(device) + + logits, _ = model(input_ids) + predictions = torch.argmax(logits, dim=-1) + + # Mask out padding tokens + mask = (labels != -100) + correct += ((predictions == labels) * mask).sum().item() + total += mask.sum().item() + + accuracy = correct / total if total > 0 else 0.0 + return accuracy + + +def compute_metrics(model: nn.Module, data_loader, device: str = 'cuda') -> Dict[str, float]: + """ + Compute various evaluation metrics. + + Args: + model: Model to evaluate + data_loader: Data loader + device: Device to use + + Returns: + Dictionary of metrics + """ + model.eval() + total_loss = 0.0 + correct = 0 + total_tokens = 0 + + criterion = nn.CrossEntropyLoss(ignore_index=-100, reduction='sum') + + with torch.no_grad(): + for batch in tqdm( + data_loader, + desc="Computing metrics", + mininterval=0.1, + maxinterval=1.0, + file=sys.stderr, # Write to stderr to avoid buffering issues + dynamic_ncols=True, # Auto-adjust to terminal width + disable=False, # Explicitly enable + ): + input_ids = batch['input_ids'].to(device) + labels = batch['labels'].to(device) + + logits, _ = model(input_ids) + logits = logits.view(-1, logits.size(-1)) + labels_flat = labels.view(-1) + + # Loss + loss = criterion(logits, labels_flat) + total_loss += loss.item() + + # Accuracy + predictions = torch.argmax(logits, dim=-1) + mask = (labels_flat != -100) + correct += ((predictions == labels_flat) * mask).sum().item() + total_tokens += mask.sum().item() + + avg_loss = total_loss / total_tokens if total_tokens > 0 else 0.0 + accuracy = correct / total_tokens if total_tokens > 0 else 0.0 + perplexity = np.exp(avg_loss) if avg_loss > 0 else float('inf') + + return { + 'loss': avg_loss, + 'accuracy': accuracy, + 'perplexity': perplexity, + } + +