From 4352a1907082eebd21cae896c560fb9dd6e7f606 Mon Sep 17 00:00:00 2001
From: Max Luebbering
Date: Mon, 29 Apr 2024 14:08:58 +0200
Subject: [PATCH] refactor: fixed loss logging during gradient accumulation

---
 src/modalities/trainer.py | 62 +++++++++++++++++++++++----------------
 1 file changed, 37 insertions(+), 25 deletions(-)

diff --git a/src/modalities/trainer.py b/src/modalities/trainer.py
index 607f21ab..686373ef 100644
--- a/src/modalities/trainer.py
+++ b/src/modalities/trainer.py
@@ -4,6 +4,7 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LRScheduler
 
@@ -40,7 +41,7 @@ def __init__(
     def _train_batch(
         self,
         batch: DatasetBatch,
-        model: nn.Module,
+        model: FSDP,
         optimizer: Optimizer,
         scheduler: LRScheduler,
         loss_fun: Loss,
@@ -48,15 +49,18 @@
         data_loader: LLMDataLoader,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         result_batch = model_predict_batch(model=model, batch=batch)
-        loss = loss_fun(result_batch) / self.gradient_acc_steps
-        loss.backward()
-        gradient_norm_score = self.gradient_clipper(model)
+        loss = loss_fun(result_batch)
+        (loss / self.gradient_acc_steps).backward()
         if (train_step_id + 1) % self.gradient_acc_steps == 0 or (train_step_id + 1) == len(data_loader):
+            # gradient_norm_score = self.gradient_clipper(model)
+            gradient_norm_score = model.clip_grad_norm_(max_norm=1, norm_type=2).sum()
             optimizer.step()
             scheduler.step()
             optimizer.zero_grad()
-        return loss, gradient_norm_score
+            return loss, gradient_norm_score
+        else:
+            return loss, None
 
     def train(
         self,
@@ -82,6 +86,7 @@ def train(
         dist.barrier()
         forward_backward_time_recorder = TimeRecorder()
         forward_backward_time_recorder.start()
+        gradient_norm_scores = []
         for batch_id, batch in enumerate(train_loader):
             # Because we might resume training, we add the starting batch id of the data loader
             train_step_id = batch_id + train_loader.fast_forward_batch_id
@@ -98,9 +103,13 @@ def train(
             forward_backward_time_recorder.stop()
             # Save the batch loss
             cumulated_loss_and_gradient_norm[0] += batch_loss.item()
-            cumulated_loss_and_gradient_norm[1] += gradient_norm_score.item()
             # This works, because we always drop the last batch in case it has less samples than the batch size
             cumulated_loss_and_gradient_norm[-1] += 1  # number of local batches
+
+            # gradient norm is already synced across all ranks
+            if gradient_norm_score is not None:
+                gradient_norm_scores.append(gradient_norm_score.item())
+
             batch_length_tensor = torch.tensor(len(batch)).to(device)
             thoughput_aggregator.add_value(key=ThroughputAggregationKeys.NUM_SAMPLES, value=batch_length_tensor)
 
@@ -124,35 +133,37 @@ def train(
                 )
                 synced_num_samples_per_second = synced_num_samples / synced_forward_backward_time
                 # TODO: insert reducer from outside so Trainer is independent of FSDP
-                cumulated_loss_and_gradient_norm[2] = batch_loss.item()
-                cumulated_loss_and_gradient_norm[3] = gradient_norm_score.item()
+                # add the loss for the LAST batch
+                cumulated_loss_and_gradient_norm[1] = batch_loss.item()
                 reduced_loss_and_gradient_norm = Reducer.reduce(
                     tensor=cumulated_loss_and_gradient_norm,
                     operation=dist.ReduceOp.SUM,
-                    # divide the first two elements by the last one
-                    # i.e., summed batch loss / (num batches * world size)
-                    # and summed gradient norm / (num batches * world size).
-                    # keep the other elements as is
-                    post_processing_fun=lambda t: torch.cat((t[:2] / t[-1], t[2:-1] / dist.get_world_size())),
+                    # 1.) summed batch loss / (num batches * world size)
+                    # 2.) last batch loss / world size
+                    post_processing_fun=lambda t: torch.stack([t[0] / t[-1], t[1] / dist.get_world_size()]),
                 )
-                train_loss_avg, train_gradient_norm_avg, train_loss_last_batch, train_gradient_norm_last_batch = (
+                train_loss_avg, train_loss_last_batch = (
                     reduced_loss_and_gradient_norm[0],
                     reduced_loss_and_gradient_norm[1],
-                    reduced_loss_and_gradient_norm[2],
-                    reduced_loss_and_gradient_norm[3],
                 )
+                losses = {
+                    f"{loss_fun.tag} average": train_loss_avg,
+                    f"{loss_fun.tag} last step": train_loss_last_batch,
+                }
+                if len(gradient_norm_scores) > 0:
+                    metrics = {
+                        "grad_norm_avg": torch.mean(torch.Tensor(gradient_norm_scores)),
+                        "grad_norm_last_batch": gradient_norm_scores[-1],
+                    }
+                    gradient_norm_scores = []
+                else:
+                    metrics = {}
                 training_metrics = EvaluationResultBatch(
-                    losses={
-                        f"{loss_fun.tag} interval average": train_loss_avg,
-                        f"{loss_fun.tag} last batch": train_loss_last_batch,
-                    },
-                    metrics={
-                        "grad_norm_avg": train_gradient_norm_avg,
-                        "grad_norm_last_batch": train_gradient_norm_last_batch,
-                    },
+                    losses=losses,
+                    metrics=metrics,
                     # TODO: hardcoded metric key
                     throughput_metrics={
                         "training_synced_num_samples_per_second": synced_num_samples_per_second,
@@ -181,7 +192,8 @@ def train(
 
     def _reset_loss_and_gradient_norm(self):
         # TODO: we should handle the device assignment more centrally.
-        cumulated_loss_and_gradient_norm = torch.zeros(5)
+        # summed local losses, loss of last local batch, number of local batches (i.e., number of steps)
+        cumulated_loss_and_gradient_norm = torch.zeros(3)
         if torch.cuda.is_available():
             cumulated_loss_and_gradient_norm = cumulated_loss_and_gradient_norm.to(torch.device(self.local_rank))
         else:
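Note for reviewers: a minimal, self-contained sketch of the accumulation pattern this patch moves to. The toy model, SGD optimizer, and the GRAD_ACC_STEPS constant are illustrative stand-ins, not the modalities API. The point is that the unscaled loss is kept for logging, the scaled loss drives backward(), and a gradient norm only exists on update steps:

import torch
import torch.nn as nn
from typing import Optional, Tuple

GRAD_ACC_STEPS = 4  # stand-in for self.gradient_acc_steps

def train_batch(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    batch: Tuple[torch.Tensor, torch.Tensor],
    step_id: int,
    num_batches: int,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    inputs, targets = batch
    loss = nn.functional.mse_loss(model(inputs), targets)
    # Backprop the scaled loss so accumulated gradients average out across
    # micro-batches, but keep `loss` unscaled for logging (the bug this patch fixes).
    (loss / GRAD_ACC_STEPS).backward()
    is_update_step = (step_id + 1) % GRAD_ACC_STEPS == 0 or (step_id + 1) == num_batches
    if is_update_step:
        # The gradient norm is only defined on update steps; the patch uses
        # FSDP's clip_grad_norm_, here plain clip_grad_norm_ stands in for it.
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()
        return loss, grad_norm
    return loss, None

model = nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
batches = [(torch.randn(4, 8), torch.randn(4, 1)) for _ in range(10)]
losses, grad_norms = [], []
for step_id, batch in enumerate(batches):
    batch_loss, grad_norm = train_batch(model, optimizer, batch, step_id, len(batches))
    losses.append(batch_loss.item())  # unscaled, so the average is comparable across acc settings
    if grad_norm is not None:  # mirrors the None-check added in train()
        grad_norms.append(grad_norm.item())
print(f"avg loss {sum(losses) / len(losses):.4f}, {len(grad_norms)} grad norms for {len(batches)} batches")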
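Likewise, a sketch of why the new 3-element accumulator and post_processing_fun yield the intended averages. It simulates the all-reduce SUM over two ranks with plain tensors; Reducer and torch.distributed themselves are not invoked here:

import torch

# Per-rank accumulators: [summed local losses, last local batch loss, num local batches]
rank0 = torch.tensor([12.0, 2.0, 4.0])
rank1 = torch.tensor([8.0, 3.0, 4.0])
world_size = 2

# dist.all_reduce(t, op=dist.ReduceOp.SUM) would produce this element-wise sum,
# so t[-1] becomes num local batches * world size automatically.
t = rank0 + rank1

# The patch's post_processing_fun:
#   [0] summed batch loss / (num batches * world size) -> global average loss
#   [1] last batch loss   / world size                 -> last-step loss averaged over ranks
reduced = torch.stack([t[0] / t[-1], t[1] / world_size])

assert torch.allclose(reduced[0], torch.tensor((12.0 + 8.0) / 8.0))  # average over 8 global batches
assert torch.allclose(reduced[1], torch.tensor((2.0 + 3.0) / 2.0))   # mean of the last-step losses
print(reduced)  # tensor([2.5000, 2.5000])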