fixing memory
@@ -131,13 +131,23 @@ class Trainer:
             disable=False,  # Explicitly enable
         )
 
         for batch_idx, batch in enumerate(progress_bar):
-            input_ids = batch['input_ids'].to(self.device)
-            labels = batch['labels'].to(self.device)
-
-            # Forward pass with mixed precision (only for CUDA)
-            if self.use_amp:
-                with torch.amp.autocast('cuda', dtype=self.autocast_dtype):
-                    logits, _ = self.model(input_ids)
-
-                    # Reshape for loss computation
+            try:
+                input_ids = batch['input_ids'].to(self.device)
+                labels = batch['labels'].to(self.device)
+
+                # Forward pass with mixed precision (only for CUDA)
+                if self.use_amp:
+                    with torch.amp.autocast('cuda', dtype=self.autocast_dtype):
+                        logits, _ = self.model(input_ids)
+
+                        # Reshape for loss computation
+                        logits = logits.view(-1, logits.size(-1))
+                        labels = labels.view(-1)
+
+                        loss = self.criterion(logits, labels)
+                        loss = loss / self.gradient_accumulation_steps
+                else:
+                    logits, _ = self.model(input_ids)
+
+                    # Reshape for loss computation
@@ -146,76 +156,70 @@ class Trainer:
 
                     loss = self.criterion(logits, labels)
                     loss = loss / self.gradient_accumulation_steps
-            else:
-                logits, _ = self.model(input_ids)
-
-                # Reshape for loss computation
-                logits = logits.view(-1, logits.size(-1))
-                labels = labels.view(-1)
-
-                loss = self.criterion(logits, labels)
-                loss = loss / self.gradient_accumulation_steps
-
-            # Backward pass
-            if self.use_amp:
-                self.scaler.scale(loss).backward()
-            else:
-                loss.backward()
-
-            # Gradient accumulation
-            if (batch_idx + 1) % self.gradient_accumulation_steps == 0:
-                # Gradient clipping
-                if self.use_amp:
-                    self.scaler.unscale_(self.optimizer)
-                    torch.nn.utils.clip_grad_norm_(
-                        self.model.parameters(),
-                        self.max_grad_norm
-                    )
-                    self.scaler.step(self.optimizer)
-                    self.scaler.update()
-                else:
-                    torch.nn.utils.clip_grad_norm_(
-                        self.model.parameters(),
-                        self.max_grad_norm
-                    )
-                    self.optimizer.step()
-
-                if self.scheduler is not None:
-                    self.scheduler.step()
-
-                self.optimizer.zero_grad()
-                self.global_step += 1
-
-                total_loss += loss.item() * self.gradient_accumulation_steps
-                num_batches += 1
-
-                # Logging
-                if self.global_step % self.log_interval == 0:
-                    avg_loss = total_loss / num_batches
-                    lr = self.optimizer.param_groups[0]['lr']
-                    progress_bar.set_postfix({
-                        'loss': f'{avg_loss:.4f}',
-                        'lr': f'{lr:.2e}',
-                    })
-                    progress_bar.refresh()  # Force immediate refresh
-                    sys.stderr.flush()  # Force flush stderr to ensure progress bar displays
-
-                    # Log metrics
-                    self.metrics.log(
-                        epoch=self.current_epoch,
-                        step=self.global_step,
-                        train_loss=avg_loss,
-                        lr=lr,
-                    )
-
-                # Evaluation
-                if self.val_loader is not None and self.global_step % self.eval_interval == 0:
-                    val_loss = self.evaluate()
-                    if val_loss < self.best_val_loss:
-                        self.best_val_loss = val_loss
-                        self.save_checkpoint(is_best=True)
-
-        avg_loss = total_loss / num_batches
+
+                # Backward pass
+                if self.use_amp:
+                    self.scaler.scale(loss).backward()
+                else:
+                    loss.backward()
+
+                # Gradient accumulation
+                if (batch_idx + 1) % self.gradient_accumulation_steps == 0:
+                    # Gradient clipping
+                    if self.use_amp:
+                        self.scaler.unscale_(self.optimizer)
+                        torch.nn.utils.clip_grad_norm_(
+                            self.model.parameters(),
+                            self.max_grad_norm
+                        )
+                        self.scaler.step(self.optimizer)
+                        self.scaler.update()
+                    else:
+                        torch.nn.utils.clip_grad_norm_(
+                            self.model.parameters(),
+                            self.max_grad_norm
+                        )
+                        self.optimizer.step()
+
+                    if self.scheduler is not None:
+                        self.scheduler.step()
+
+                    self.optimizer.zero_grad()
+                    self.global_step += 1
+
+                    total_loss += loss.item() * self.gradient_accumulation_steps
+                    num_batches += 1
+
+                    # Logging
+                    if self.global_step % self.log_interval == 0:
+                        avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
+                        lr = self.optimizer.param_groups[0]['lr']
+                        progress_bar.set_postfix({
+                            'loss': f'{avg_loss:.4f}',
+                            'lr': f'{lr:.2e}',
+                        })
+                        progress_bar.refresh()  # Force immediate refresh
+                        sys.stderr.flush()  # Force flush stderr to ensure progress bar displays
+
+                        # Log metrics
+                        self.metrics.log(
+                            epoch=self.current_epoch,
+                            step=self.global_step,
+                            train_loss=avg_loss,
+                            lr=lr,
+                        )
+            except KeyboardInterrupt:
+                # Re-raise to be handled by outer try-except
+                raise
+
+            # Evaluation (only reached if no interruption)
+            if self.val_loader is not None and self.global_step % self.eval_interval == 0:
+                val_loss = self.evaluate()
+                if val_loss < self.best_val_loss:
+                    self.best_val_loss = val_loss
+                    self.save_checkpoint(is_best=True)
+
+        avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
 
         # Log epoch metrics
         self.metrics.log(
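
For reference, below is a minimal, self-contained sketch (not part of the commit) of the per-step pattern the new loop relies on: a forward pass under torch.amp.autocast, a GradScaler-scaled backward with gradient accumulation, unscaling before clipping, and a KeyboardInterrupt that is re-raised so an outer handler can react. Model, data, and hyperparameter names here are placeholders rather than the repo's own, and it assumes a recent PyTorch that exposes torch.amp.GradScaler; the scaler is a no-op on CPU.

import torch
import torch.nn as nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'
use_amp = device == 'cuda'

model = nn.Linear(16, 4).to(device)                      # stand-in for the trainer's model
criterion = nn.CrossEntropyLoss()                        # stand-in for the trainer's criterion
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.amp.GradScaler('cuda', enabled=use_amp)   # pass-through when AMP is disabled
grad_accum_steps, max_grad_norm = 4, 1.0

# Synthetic "batches" so the sketch runs end to end.
batches = [(torch.randn(8, 16), torch.randint(0, 4, (8,))) for _ in range(8)]

try:
    for batch_idx, (x, y) in enumerate(batches):
        try:
            x, y = x.to(device), y.to(device)

            # Forward pass; autocast is only active when AMP is enabled.
            with torch.amp.autocast(device_type=device, enabled=use_amp):
                logits = model(x)
                loss = criterion(logits, y) / grad_accum_steps

            # Scaled backward accumulates gradients across micro-batches.
            scaler.scale(loss).backward()

            if (batch_idx + 1) % grad_accum_steps == 0:
                scaler.unscale_(optimizer)   # so clipping sees unscaled gradients
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                scaler.step(optimizer)       # skips the step if gradients overflowed
                scaler.update()
                optimizer.zero_grad()
        except KeyboardInterrupt:
            # Mirror the commit: defer Ctrl-C handling to the outer try-except.
            raise
except KeyboardInterrupt:
    print('Interrupted; a full trainer would save a checkpoint here.')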