From fb1ca67be96c15c2c3a8e6d0c13d0507fa2abcdc Mon Sep 17 00:00:00 2001
From: Carlos Gutierrez
Date: Sun, 16 Nov 2025 16:52:26 -0500
Subject: [PATCH] fixing memory

---
 training/__init__.py | 148 ++++++++++++++++++++++---------------------
 1 file changed, 76 insertions(+), 72 deletions(-)

diff --git a/training/__init__.py b/training/__init__.py
index 3213423..1b761e7 100644
--- a/training/__init__.py
+++ b/training/__init__.py
@@ -131,13 +131,23 @@ class Trainer:
             disable=False,  # Explicitly enable
         )
 
-        for batch_idx, batch in enumerate(progress_bar):
-            input_ids = batch['input_ids'].to(self.device)
-            labels = batch['labels'].to(self.device)
-
-            # Forward pass with mixed precision (only for CUDA)
-            if self.use_amp:
-                with torch.amp.autocast('cuda', dtype=self.autocast_dtype):
+        try:
+            for batch_idx, batch in enumerate(progress_bar):
+                input_ids = batch['input_ids'].to(self.device)
+                labels = batch['labels'].to(self.device)
+
+                # Forward pass with mixed precision (only for CUDA)
+                if self.use_amp:
+                    with torch.amp.autocast('cuda', dtype=self.autocast_dtype):
+                        logits, _ = self.model(input_ids)
+
+                    # Reshape for loss computation
+                    logits = logits.view(-1, logits.size(-1))
+                    labels = labels.view(-1)
+
+                    loss = self.criterion(logits, labels)
+                    loss = loss / self.gradient_accumulation_steps
+                else:
                     logits, _ = self.model(input_ids)
 
                     # Reshape for loss computation
@@ -146,76 +156,70 @@ class Trainer:
                     loss = self.criterion(logits, labels)
                     loss = loss / self.gradient_accumulation_steps
-            else:
-                logits, _ = self.model(input_ids)
 
-                # Reshape for loss computation
-                logits = logits.view(-1, logits.size(-1))
-                labels = labels.view(-1)
-
-                loss = self.criterion(logits, labels)
-                loss = loss / self.gradient_accumulation_steps
-
-            # Backward pass
-            if self.use_amp:
-                self.scaler.scale(loss).backward()
-            else:
-                loss.backward()
-
-            # Gradient accumulation
-            if (batch_idx + 1) % self.gradient_accumulation_steps == 0:
-                # Gradient clipping
+                # Backward pass
                 if self.use_amp:
-                    self.scaler.unscale_(self.optimizer)
-                    torch.nn.utils.clip_grad_norm_(
-                        self.model.parameters(),
-                        self.max_grad_norm
-                    )
-                    self.scaler.step(self.optimizer)
-                    self.scaler.update()
+                    self.scaler.scale(loss).backward()
                 else:
-                    torch.nn.utils.clip_grad_norm_(
-                        self.model.parameters(),
-                        self.max_grad_norm
+                    loss.backward()
+
+                # Gradient accumulation
+                if (batch_idx + 1) % self.gradient_accumulation_steps == 0:
+                    # Gradient clipping
+                    if self.use_amp:
+                        self.scaler.unscale_(self.optimizer)
+                        torch.nn.utils.clip_grad_norm_(
+                            self.model.parameters(),
+                            self.max_grad_norm
+                        )
+                        self.scaler.step(self.optimizer)
+                        self.scaler.update()
+                    else:
+                        torch.nn.utils.clip_grad_norm_(
+                            self.model.parameters(),
+                            self.max_grad_norm
+                        )
+                        self.optimizer.step()
+
+                    if self.scheduler is not None:
+                        self.scheduler.step()
+
+                    self.optimizer.zero_grad()
+                    self.global_step += 1
+
+                total_loss += loss.item() * self.gradient_accumulation_steps
+                num_batches += 1
+
+                # Logging
+                if self.global_step % self.log_interval == 0:
+                    avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
+                    lr = self.optimizer.param_groups[0]['lr']
+                    progress_bar.set_postfix({
+                        'loss': f'{avg_loss:.4f}',
+                        'lr': f'{lr:.2e}',
+                    })
+                    progress_bar.refresh()  # Force immediate refresh
+                    sys.stderr.flush()  # Force flush stderr to ensure progress bar displays
+
+                    # Log metrics
+                    self.metrics.log(
+                        epoch=self.current_epoch,
+                        step=self.global_step,
+                        train_loss=avg_loss,
+                        lr=lr,
                     )
-                    self.optimizer.step()
-
-                if self.scheduler is not None:
-                    self.scheduler.step()
-
-                self.optimizer.zero_grad()
-                self.global_step += 1
-
-            total_loss += loss.item() * self.gradient_accumulation_steps
-            num_batches += 1
-
-            # Logging
-            if self.global_step % self.log_interval == 0:
-                avg_loss = total_loss / num_batches
-                lr = self.optimizer.param_groups[0]['lr']
-                progress_bar.set_postfix({
-                    'loss': f'{avg_loss:.4f}',
-                    'lr': f'{lr:.2e}',
-                })
-                progress_bar.refresh()  # Force immediate refresh
-                sys.stderr.flush()  # Force flush stderr to ensure progress bar displays
-
-                # Log metrics
-                self.metrics.log(
-                    epoch=self.current_epoch,
-                    step=self.global_step,
-                    train_loss=avg_loss,
-                    lr=lr,
-                )
-
-            # Evaluation
-            if self.val_loader is not None and self.global_step % self.eval_interval == 0:
-                val_loss = self.evaluate()
-                if val_loss < self.best_val_loss:
-                    self.best_val_loss = val_loss
-                    self.save_checkpoint(is_best=True)
+        except KeyboardInterrupt:
+            # Re-raise to be handled by outer try-except
+            raise
 
-        avg_loss = total_loss / num_batches
+        # Evaluation (only reached if no interruption)
+        if self.val_loader is not None and self.global_step % self.eval_interval == 0:
+            val_loss = self.evaluate()
+            if val_loss < self.best_val_loss:
+                self.best_val_loss = val_loss
+                self.save_checkpoint(is_best=True)
+
+        avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
 
         # Log epoch metrics
         self.metrics.log(
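
The bare re-raise in the new `except KeyboardInterrupt` block only pays off if some
caller wraps the epoch loop in its own handler, and that caller is not shown in this
patch. A minimal sketch of what it could look like, assuming a driver function `fit`
and a method name `train_epoch` (both hypothetical, not taken from this repository):

    def fit(trainer, num_epochs):
        # Hypothetical outer handler: catch the KeyboardInterrupt re-raised
        # by the patched method and checkpoint partial progress before exiting.
        try:
            for epoch in range(num_epochs):
                trainer.current_epoch = epoch
                trainer.train_epoch()
        except KeyboardInterrupt:
            # Assumed call; mirrors save_checkpoint(is_best=True) in the diff.
            trainer.save_checkpoint(is_best=False)
            raise

This keeps interrupt handling out of the inner batch loop, so a Ctrl-C during a
forward or backward pass unwinds cleanly through the except block added above.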