fixing memory

This commit is contained in:
Carlos Gutierrez
2025-11-16 16:52:26 -05:00
parent 9f17e1db24
commit fb1ca67be9

View File

@@ -131,6 +131,7 @@ class Trainer:
disable=False, # Explicitly enable
)
try:
for batch_idx, batch in enumerate(progress_bar):
input_ids = batch['input_ids'].to(self.device)
labels = batch['labels'].to(self.device)
@@ -191,7 +192,7 @@ class Trainer:
# Logging
if self.global_step % self.log_interval == 0:
avg_loss = total_loss / num_batches
avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
lr = self.optimizer.param_groups[0]['lr']
progress_bar.set_postfix({
'loss': f'{avg_loss:.4f}',
@@ -207,15 +208,18 @@ class Trainer:
train_loss=avg_loss,
lr=lr,
)
except KeyboardInterrupt:
# Re-raise to be handled by outer try-except
raise
# Evaluation
# Evaluation (only reached if no interruption)
if self.val_loader is not None and self.global_step % self.eval_interval == 0:
val_loss = self.evaluate()
if val_loss < self.best_val_loss:
self.best_val_loss = val_loss
self.save_checkpoint(is_best=True)
avg_loss = total_loss / num_batches
avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
# Log epoch metrics
self.metrics.log(