Initial commit: SheepOp LLM - Transformer-based language model implementation
- Complete transformer implementation from scratch - Training pipeline with gradient accumulation and mixed precision - Optimized inference with KV caching - Multi-format data processing (PDFs, images, code, text) - Comprehensive documentation - Apache 2.0 license - Example training plots included in docs/images/
This commit is contained in:
292
train.py
Normal file
292
train.py
Normal file
@@ -0,0 +1,292 @@
|
||||
"""
|
||||
Main training script
|
||||
"""
|
||||
import argparse
import importlib.util
import os
import random
import sys
from pathlib import Path

import torch
|
||||
|
||||
# Ensure current directory is in path so sibling project modules resolve.
project_root = Path(__file__).parent.absolute()
sys.path.insert(0, str(project_root))

# Explicitly load the local `data` package by file path so another installed
# package named `data` on sys.path cannot shadow it.
# NOTE(review): the original comment blamed a "Python 3.12 standard library
# 'data' module" — no such stdlib module exists; the conflict is with
# third-party packages of the same name.
data_module_path = project_root / "data" / "__init__.py"
if not data_module_path.exists():
    # Fall back to CWD-relative locations (the project_root path was already
    # checked above, so it is not repeated here).
    alt_paths = [
        Path("data") / "__init__.py",
        Path.cwd() / "data" / "__init__.py",
    ]
    data_module_path = next((p for p in alt_paths if p.exists()), None)

    if data_module_path is None:
        searched = [
            project_root / 'data' / '__init__.py',
            Path('data') / '__init__.py',
            Path.cwd() / 'data' / '__init__.py',
        ]
        error_msg = (
            "Could not find data module!\n"
            " Searched:\n"
            + "".join(f" - {p}\n" for p in searched)
            + f" Current directory: {Path.cwd()}\n"
            + f" Project root: {project_root}\n"
            + f" Does data/ directory exist? {(project_root / 'data').exists()}\n"
            + "\n Please ensure you're running from the project root directory.\n"
            + " Try: cd <project_root> && python3 train.py ..."
        )
        raise ImportError(error_msg)

# Load the package under a unique module name to sidestep import-cache clashes.
spec = importlib.util.spec_from_file_location("sheepop_data", data_module_path)
sheepop_data = importlib.util.module_from_spec(spec)
spec.loader.exec_module(sheepop_data)

# Import from the explicitly loaded module
SimpleTokenizer = sheepop_data.SimpleTokenizer
create_dataloader = sheepop_data.create_dataloader
DataProcessor = sheepop_data.DataProcessor
extract_text_from_directory = sheepop_data.extract_text_from_directory
|
||||
|
||||
from models import TransformerModel
|
||||
from training import Trainer
|
||||
from config import Config, get_default_config
|
||||
from dataclasses import asdict
|
||||
|
||||
|
||||
def set_seed(seed: int) -> None:
    """Seed all known RNG sources for a reproducible run.

    Seeds Python's ``random`` module and PyTorch (CPU and every CUDA
    device), and forces deterministic cuDNN kernels at some speed cost.

    Args:
        seed: Seed value applied to every generator.
    """
    # Python-level RNG; previously unseeded, which broke the function's
    # reproducibility contract for any stdlib-random-based shuffling.
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # no-op when CUDA is unavailable
    # Trade cuDNN autotuning throughput for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
||||
|
||||
|
||||
def main():
    """CLI entry point: parse args, extract text data, build the model,
    optionally resume from a checkpoint, and run training.

    Relies on the module-level project imports (``TransformerModel``,
    ``Trainer``, ``Config``/``get_default_config``) and the explicitly
    loaded ``sheepop_data`` helpers.
    """
    parser = argparse.ArgumentParser(description='Train SheepOp LLM')
    parser.add_argument('--config', type=str, help='Path to config file')
    parser.add_argument('--data', type=str, required=True, help='Path to training data')
    parser.add_argument('--output', type=str, default='./checkpoints', help='Output directory')
    parser.add_argument('--resume', type=str, help='Path to checkpoint to resume from')

    # Auto-detect best device
    # NOTE(review): MPS is preferred over CUDA here; on a machine with both,
    # the default would be 'mps' — confirm that ordering is intentional.
    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        default_device = 'mps'
    elif torch.cuda.is_available():
        default_device = 'cuda'
    else:
        default_device = 'cpu'

    parser.add_argument('--device', type=str, default=default_device,
                        help=f'Device to use (default: {default_device})')

    args = parser.parse_args()

    # Load configuration
    if args.config:
        config = Config.from_json(args.config)
    else:
        config = get_default_config()

    # CLI flags override whatever the config (file or defaults) said.
    config.device = args.device
    config.training.save_dir = args.output

    # Set seed
    set_seed(config.seed)

    # Setup device with smart detection: honor the requested device only if
    # its backend is actually available, else fall through.
    if config.device == 'cuda' and torch.cuda.is_available():
        device = torch.device('cuda')
    elif config.device == 'mps' and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        device = torch.device('mps')
    elif config.device == 'auto':
        # Auto-detect best available device
        if torch.cuda.is_available():
            device = torch.device('cuda')
        elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            device = torch.device('mps')
            print("Auto-detected MPS (Apple Silicon GPU)")
        else:
            device = torch.device('cpu')
            print("Auto-detected CPU")
    else:
        # Last resort: pass the string straight to torch (e.g. 'cpu').
        device = torch.device(config.device)

    print(f"Using device: {device}")

    # Load data - supports multiple file types (PDFs, images, code, text, etc.)
    data_path = Path(args.data)
    texts = []

    if data_path.is_file():
        # Single file - try to process it
        print(f"Processing single file: {data_path}")
        processor = DataProcessor()
        texts = list(processor.process_file(data_path))
    elif data_path.is_dir():
        # Directory - process all supported file types
        print(f"Processing directory: {data_path}")
        print("Supported file types:")
        print(" - Text files: .txt, .md, .rst, .log, .csv, .json, .jsonl, .xml, .html, .htm")
        print(" - Code files: .py, .js, .ts, .java, .cpp, .c, .go, .rs, .rb, .php, .swift, etc.")
        print(" - PDF files: .pdf (requires PyPDF2 or pdfplumber)")
        print(" - Images: .png, .jpg, .jpeg, .gif, .bmp, .tiff (requires pytesseract for OCR)")
        print()

        # Process directory with all file types
        try:
            texts = extract_text_from_directory(
                directory=data_path,
                recursive=True,
                use_ocr=True,  # Enable OCR for images
                use_pdf_extraction=True,  # Enable PDF extraction
                min_length=10,  # Minimum length for text lines
            )
        except KeyboardInterrupt:
            # OCR/PDF extraction can be slow; tell the user nothing was lost
            # before re-raising so the process still exits on Ctrl+C.
            print("\n\n⚠️ Data processing interrupted by user (Ctrl+C).")
            print(" Note: No checkpoint is saved because training hasn't started yet.")
            print(" Checkpoints are only saved during training, not during data extraction.")
            print(" Please run the training command again to retry.")
            raise
    else:
        raise ValueError(f"Data path {args.data} does not exist")

    if not texts:
        raise ValueError(f"No text data extracted from {args.data}. Please check that the directory contains supported file types.")

    print(f"\n✅ Successfully loaded {len(texts):,} text samples from {data_path}")
    print(f" Sample preview (first 3 lines):")
    for i, text in enumerate(texts[:3]):
        # Truncate long samples so each preview stays on one line.
        preview = text[:80] + "..." if len(text) > 80 else text
        print(f" {i+1}. {preview}")

    # Create tokenizer
    tokenizer = SimpleTokenizer()
    print(f"Vocabulary size: {tokenizer.vocab_size}")

    # Create data loaders
    train_loader = create_dataloader(
        texts=texts,
        tokenizer=tokenizer,
        batch_size=config.training.batch_size,
        max_length=config.data.max_length,
        shuffle=True,
        num_workers=config.data.num_workers,
    )

    # Create model
    model_config = config.model
    model_config.vocab_size = tokenizer.vocab_size

    # Resume from checkpoint if provided
    start_epoch = 0
    checkpoint = None
    if args.resume:
        checkpoint_path = Path(args.resume)
        if not checkpoint_path.exists():
            print(f"⚠️ Warning: Checkpoint file '{args.resume}' not found!")
            print(f" Starting fresh training instead...")
            args.resume = None  # Disable resume flag
        else:
            print(f"Resuming from checkpoint: {args.resume}")
            checkpoint = torch.load(args.resume, map_location=device)

            # Load model config from checkpoint if available
            # NOTE(review): only vocab_size is restored; all other
            # architecture fields come from the current config and must match
            # the checkpoint, or load_state_dict below will fail — confirm.
            if 'model_config' in checkpoint:
                checkpoint_config = checkpoint['model_config']
                model_config.vocab_size = checkpoint_config.get('vocab_size', model_config.vocab_size)
                print(f"Loaded model config from checkpoint")

            model = TransformerModel(**asdict(model_config))
            model.load_state_dict(checkpoint['model_state_dict'])
            start_epoch = checkpoint.get('epoch', 0) + 1  # Start from next epoch
            print(f"Resuming from epoch {start_epoch}")

    if not args.resume:
        # Fresh run, or resume was requested but the checkpoint file was
        # missing (args.resume cleared above).
        model = TransformerModel(**asdict(model_config))

    # NOTE(review): the model is never moved to `device` in this function —
    # presumably Trainer does model.to(device); confirm.
    print(f"Model created with {model.get_num_params():,} parameters")

    # Setup optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.training.learning_rate,
        weight_decay=config.training.weight_decay,
        betas=(0.9, 0.999),
    )

    # Load optimizer state if resuming
    if args.resume:
        if 'optimizer_state_dict' in checkpoint:
            # Move optimizer state to correct device (tensors were saved on
            # whatever device the checkpointing run used).
            optimizer_state = checkpoint['optimizer_state_dict']
            for state in optimizer_state['state'].values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.to(device)
            optimizer.load_state_dict(optimizer_state)
            print("Loaded optimizer state from checkpoint")

    # Setup scheduler
    # NOTE(review): total_steps counts dataloader batches; if the scheduler is
    # stepped once per optimizer update and gradient_accumulation_steps > 1,
    # the cosine horizon is longer than the actual step count — confirm.
    total_steps = len(train_loader) * config.training.max_epochs
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=total_steps,
    )

    # Load scheduler state if resuming
    if args.resume:
        if 'scheduler_state_dict' in checkpoint and scheduler is not None:
            # Scheduler state usually doesn't need device transfer, but let's be safe
            scheduler_state = checkpoint['scheduler_state_dict']
            scheduler.load_state_dict(scheduler_state)
            print("Loaded scheduler state from checkpoint")

    # Create trainer
    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=None,  # Can add validation loader
        optimizer=optimizer,
        scheduler=scheduler,
        device=device,
        max_epochs=config.training.max_epochs,
        gradient_accumulation_steps=config.training.gradient_accumulation_steps,
        max_grad_norm=config.training.max_grad_norm,
        use_amp=config.training.use_amp,
        save_dir=config.training.save_dir,
        log_interval=config.training.log_interval,
        eval_interval=config.training.eval_interval,
    )

    # Set trainer state if resuming
    if args.resume:
        trainer.current_epoch = start_epoch - 1
        trainer.global_step = checkpoint.get('global_step', 0)
        trainer.best_val_loss = checkpoint.get('best_val_loss', float('inf'))
        print(f"Resuming from global step {trainer.global_step}")

    # Store model config for checkpoint saving
    model_config_dict = asdict(model_config)

    # Override save_checkpoint to include model config
    # NOTE(review): assumes Trainer calls save_checkpoint with at most the
    # `is_best` keyword and that it accepts a `model_config` kwarg — confirm
    # against Trainer's implementation.
    original_save_checkpoint = trainer.save_checkpoint
    def save_checkpoint_with_config(is_best=False):
        original_save_checkpoint(is_best=is_best, model_config=model_config_dict)
    trainer.save_checkpoint = save_checkpoint_with_config

    # Train
    trainer.train()

    print("Training completed!")
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # The redundant local `from dataclasses import asdict` was removed:
    # asdict is already imported at module level and is resolved there by
    # main() — a guard-local import had no effect on it.
    main()
|
||||
|
||||
Reference in New Issue
Block a user