fixing memory

Carlos Gutierrez
2025-11-16 16:52:26 -05:00
parent 9f17e1db24
commit fb1ca67be9

@@ -131,13 +131,23 @@ class Trainer:
             disable=False,  # Explicitly enable
         )
-        for batch_idx, batch in enumerate(progress_bar):
-            input_ids = batch['input_ids'].to(self.device)
-            labels = batch['labels'].to(self.device)
-            # Forward pass with mixed precision (only for CUDA)
-            if self.use_amp:
-                with torch.amp.autocast('cuda', dtype=self.autocast_dtype):
-                    logits, _ = self.model(input_ids)
-                    # Reshape for loss computation
-                    logits = logits.view(-1, logits.size(-1))
-                    labels = labels.view(-1)
-                    loss = self.criterion(logits, labels)
+        try:
+            for batch_idx, batch in enumerate(progress_bar):
+                input_ids = batch['input_ids'].to(self.device)
+                labels = batch['labels'].to(self.device)
+                # Forward pass with mixed precision (only for CUDA)
+                if self.use_amp:
+                    with torch.amp.autocast('cuda', dtype=self.autocast_dtype):
+                        logits, _ = self.model(input_ids)
+                        # Reshape for loss computation
+                        logits = logits.view(-1, logits.size(-1))
+                        labels = labels.view(-1)
+                        loss = self.criterion(logits, labels)
+                        loss = loss / self.gradient_accumulation_steps
+                else:
+                    logits, _ = self.model(input_ids)
+                    # Reshape for loss computation
@@ -146,76 +156,70 @@ class Trainer:
-                    loss = loss / self.gradient_accumulation_steps
-            else:
-                logits, _ = self.model(input_ids)
-                # Reshape for loss computation
-                logits = logits.view(-1, logits.size(-1))
-                labels = labels.view(-1)
-                loss = self.criterion(logits, labels)
-                loss = loss / self.gradient_accumulation_steps
-            # Backward pass
-            if self.use_amp:
-                self.scaler.scale(loss).backward()
-            else:
-                loss.backward()
-            # Gradient accumulation
-            if (batch_idx + 1) % self.gradient_accumulation_steps == 0:
-                # Gradient clipping
-                if self.use_amp:
-                    self.scaler.unscale_(self.optimizer)
-                    torch.nn.utils.clip_grad_norm_(
-                        self.model.parameters(),
-                        self.max_grad_norm
-                    )
-                    self.scaler.step(self.optimizer)
-                    self.scaler.update()
-                else:
-                    torch.nn.utils.clip_grad_norm_(
-                        self.model.parameters(),
-                        self.max_grad_norm
-                    )
-                    self.optimizer.step()
-                if self.scheduler is not None:
-                    self.scheduler.step()
-                self.optimizer.zero_grad()
-                self.global_step += 1
-            total_loss += loss.item() * self.gradient_accumulation_steps
-            num_batches += 1
-            # Logging
-            if self.global_step % self.log_interval == 0:
-                avg_loss = total_loss / num_batches
-                lr = self.optimizer.param_groups[0]['lr']
-                progress_bar.set_postfix({
-                    'loss': f'{avg_loss:.4f}',
-                    'lr': f'{lr:.2e}',
-                })
-                progress_bar.refresh()  # Force immediate refresh
-                sys.stderr.flush()  # Force flush stderr to ensure progress bar displays
-                # Log metrics
-                self.metrics.log(
-                    epoch=self.current_epoch,
-                    step=self.global_step,
-                    train_loss=avg_loss,
-                    lr=lr,
-                )
-        avg_loss = total_loss / num_batches
-        # Evaluation (only reached if no interruption)
-        if self.val_loader is not None and self.global_step % self.eval_interval == 0:
-            val_loss = self.evaluate()
-            if val_loss < self.best_val_loss:
-                self.best_val_loss = val_loss
-                self.save_checkpoint(is_best=True)
+                    logits = logits.view(-1, logits.size(-1))
+                    labels = labels.view(-1)
+                    loss = self.criterion(logits, labels)
+                    loss = loss / self.gradient_accumulation_steps
+                # Backward pass
+                if self.use_amp:
+                    self.scaler.scale(loss).backward()
+                else:
+                    loss.backward()
+                # Gradient accumulation
+                if (batch_idx + 1) % self.gradient_accumulation_steps == 0:
+                    # Gradient clipping
+                    if self.use_amp:
+                        self.scaler.unscale_(self.optimizer)
+                        torch.nn.utils.clip_grad_norm_(
+                            self.model.parameters(),
+                            self.max_grad_norm
+                        )
+                        self.scaler.step(self.optimizer)
+                        self.scaler.update()
+                    else:
+                        torch.nn.utils.clip_grad_norm_(
+                            self.model.parameters(),
+                            self.max_grad_norm
+                        )
+                        self.optimizer.step()
+                    if self.scheduler is not None:
+                        self.scheduler.step()
+                    self.optimizer.zero_grad()
+                    self.global_step += 1
+                total_loss += loss.item() * self.gradient_accumulation_steps
+                num_batches += 1
+                # Logging
+                if self.global_step % self.log_interval == 0:
+                    avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
+                    lr = self.optimizer.param_groups[0]['lr']
+                    progress_bar.set_postfix({
+                        'loss': f'{avg_loss:.4f}',
+                        'lr': f'{lr:.2e}',
+                    })
+                    progress_bar.refresh()  # Force immediate refresh
+                    sys.stderr.flush()  # Force flush stderr to ensure progress bar displays
+                    # Log metrics
+                    self.metrics.log(
+                        epoch=self.current_epoch,
+                        step=self.global_step,
+                        train_loss=avg_loss,
+                        lr=lr,
+                    )
+                # Evaluation
+                if self.val_loader is not None and self.global_step % self.eval_interval == 0:
+                    val_loss = self.evaluate()
+                    if val_loss < self.best_val_loss:
+                        self.best_val_loss = val_loss
+                        self.save_checkpoint(is_best=True)
+        except KeyboardInterrupt:
+            # Re-raise to be handled by outer try-except
+            raise
+        avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
         # Log epoch metrics
         self.metrics.log(
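
Taken together, the two hunks wrap the whole batch loop in a try block whose except KeyboardInterrupt handler only re-raises, so a Ctrl-C during training unwinds cleanly to the caller, and they guard the average-loss division so an interrupt after zero completed batches cannot divide by zero. Below is a minimal runnable sketch of that pattern, assuming (as the re-raise comment implies) an outer handler in train() that checkpoints before exiting; MiniTrainer, train(), and interrupted.pt are illustrative names, not identifiers from this repository.

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

class MiniTrainer:
    def __init__(self):
        self.model = nn.Linear(4, 2)
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)
        self.criterion = nn.CrossEntropyLoss()
        data = TensorDataset(torch.randn(64, 4), torch.randint(0, 2, (64,)))
        self.loader = DataLoader(data, batch_size=8)

    def train_epoch(self):
        total_loss, num_batches = 0.0, 0
        try:
            for inputs, labels in self.loader:
                loss = self.criterion(self.model(inputs), labels)
                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()
                total_loss += loss.item()
                num_batches += 1
        except KeyboardInterrupt:
            # Re-raise so the outer handler in train() decides what to do;
            # kept explicit, as in the commit, to mark the boundary.
            raise
        # Guarded average: safe even if the interrupt landed on batch 0
        return total_loss / num_batches if num_batches > 0 else 0.0

    def train(self, epochs=2):
        try:
            for _ in range(epochs):
                print(f'epoch loss: {self.train_epoch():.4f}')
        except KeyboardInterrupt:
            # Outer handler: persist state before exiting
            torch.save(self.model.state_dict(), 'interrupted.pt')
            print('Interrupted; checkpoint saved to interrupted.pt')

if __name__ == '__main__':
    MiniTrainer().train()

The inner re-raise is behaviorally redundant in this sketch (the exception would propagate anyway); its value in the commit is documenting that interrupt handling is owned by the outer loop, while the epoch body stays free of checkpoint logic.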