Merge pull request #118 from Modalities/gradient_accumulation_loss_logging

refactor: fixed loss logging during gradient accumulation and gradient clipping
fromm-m authored Apr 30, 2024
2 parents bd06dab + 4352a19 commit a599005
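
For readers skimming the diff below, here is a minimal single-process sketch (not part of the commit) of the training-step pattern it moves to: the unscaled loss is kept for logging, a copy scaled by the accumulation factor drives backward(), and gradient clipping plus the optimizer and scheduler steps run only at an accumulation boundary or on the last batch. Everything in the sketch is illustrative — run_epoch, grad_acc_steps, and the toy model and data are made up, and torch.nn.utils.clip_grad_norm_ stands in for the FSDP clip_grad_norm_ method used in the trainer, since this sketch is not sharded.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def run_epoch(model, loader, optimizer, scheduler, grad_acc_steps=4, max_norm=1.0):
    for step_id, (x, y) in enumerate(loader):
        loss = nn.functional.mse_loss(model(x), y)  # unscaled loss, kept for logging
        (loss / grad_acc_steps).backward()          # scaled copy drives the gradients
        at_boundary = (step_id + 1) % grad_acc_steps == 0 or (step_id + 1) == len(loader)
        if at_boundary:
            # clip once per effective optimizer step, not once per micro-batch
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(), max_norm=max_norm, norm_type=2
            )
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
        else:
            grad_norm = None  # no optimizer step on this micro-batch
        yield loss.item(), grad_norm


if __name__ == "__main__":
    torch.manual_seed(0)
    model = nn.Linear(8, 1)
    data = TensorDataset(torch.randn(32, 8), torch.randn(32, 1))
    loader = DataLoader(data, batch_size=4, drop_last=True)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
    for loss, grad_norm in run_epoch(model, loader, optimizer, scheduler):
        norm_str = "-" if grad_norm is None else f"{grad_norm.item():.3f}"
        print(f"loss={loss:.4f} grad_norm={norm_str}")

The key difference from the previous behaviour is that the logged loss is no longer pre-divided by gradient_acc_steps, and the gradient norm is only defined on steps where the optimizer actually steps.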
Showing 1 changed file with 37 additions and 25 deletions.
62 changes: 37 additions & 25 deletions src/modalities/trainer.py
@@ -4,6 +4,7 @@
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler

@@ -40,23 +41,26 @@ def __init__(
def _train_batch(
self,
batch: DatasetBatch,
model: nn.Module,
model: FSDP,
optimizer: Optimizer,
scheduler: LRScheduler,
loss_fun: Loss,
train_step_id: int,
data_loader: LLMDataLoader,
) -> Tuple[torch.Tensor, torch.Tensor]:
result_batch = model_predict_batch(model=model, batch=batch)
loss = loss_fun(result_batch) / self.gradient_acc_steps
loss.backward()
gradient_norm_score = self.gradient_clipper(model)
loss = loss_fun(result_batch)
(loss / self.gradient_acc_steps).backward()

if (train_step_id + 1) % self.gradient_acc_steps == 0 or (train_step_id + 1) == len(data_loader):
# gradient_norm_score = self.gradient_clipper(model)
gradient_norm_score = model.clip_grad_norm_(max_norm=1, norm_type=2).sum()
optimizer.step()
scheduler.step()
optimizer.zero_grad()
return loss, gradient_norm_score
return loss, gradient_norm_score
else:
return loss, None

def train(
self,
@@ -82,6 +86,7 @@ def train(
dist.barrier()
forward_backward_time_recorder = TimeRecorder()
forward_backward_time_recorder.start()
gradient_norm_scores = []
for batch_id, batch in enumerate(train_loader):
# Because we might resume training, we add the starting batch id of the data loader
train_step_id = batch_id + train_loader.fast_forward_batch_id
@@ -98,9 +103,13 @@
forward_backward_time_recorder.stop()
# Save the batch loss
cumulated_loss_and_gradient_norm[0] += batch_loss.item()
cumulated_loss_and_gradient_norm[1] += gradient_norm_score.item()
# This works because we always drop the last batch in case it has fewer samples than the batch size
cumulated_loss_and_gradient_norm[-1] += 1 # number of local batches

# gradient norm is already synced across all ranks
if gradient_norm_score is not None:
gradient_norm_scores.append(gradient_norm_score.item())

batch_length_tensor = torch.tensor(len(batch)).to(device)
thoughput_aggregator.add_value(key=ThroughputAggregationKeys.NUM_SAMPLES, value=batch_length_tensor)

@@ -124,35 +133,37 @@
)
synced_num_samples_per_second = synced_num_samples / synced_forward_backward_time
# TODO: insert reducer from outside so Trainer is independent of FSDP
cumulated_loss_and_gradient_norm[2] = batch_loss.item()
cumulated_loss_and_gradient_norm[3] = gradient_norm_score.item()
# add the loss for the LAST batch
cumulated_loss_and_gradient_norm[1] = batch_loss.item()

reduced_loss_and_gradient_norm = Reducer.reduce(
tensor=cumulated_loss_and_gradient_norm,
operation=dist.ReduceOp.SUM,
# divide the first two elements by the last one
# i.e., summed batch loss / (num batches * world size)
# and summed gradient norm/ (num batches * world size).
# keep the other elements as is
post_processing_fun=lambda t: torch.cat((t[:2] / t[-1], t[2:-1] / dist.get_world_size())),
# 1.) summed batch loss / (num batches * world size)
# 2.) last batch loss / world size
post_processing_fun=lambda t: torch.stack([t[0] / t[-1], t[1] / dist.get_world_size()]),
)

train_loss_avg, train_gradient_norm_avg, train_loss_last_batch, train_gradient_norm_last_batch = (
train_loss_avg, train_loss_last_batch = (
reduced_loss_and_gradient_norm[0],
reduced_loss_and_gradient_norm[1],
reduced_loss_and_gradient_norm[2],
reduced_loss_and_gradient_norm[3],
)
losses = {
f"{loss_fun.tag} average": train_loss_avg,
f"{loss_fun.tag} last step": train_loss_last_batch,
}
if len(gradient_norm_scores) > 0:
metrics = {
"grad_norm_avg": torch.mean(torch.Tensor(gradient_norm_scores)),
"grad_norm_last_batch": gradient_norm_scores[-1],
}
gradient_norm_scores = []
else:
metrics = {}

training_metrics = EvaluationResultBatch(
losses={
f"{loss_fun.tag} interval average": train_loss_avg,
f"{loss_fun.tag} last batch": train_loss_last_batch,
},
metrics={
"grad_norm_avg": train_gradient_norm_avg,
"grad_norm_last_batch": train_gradient_norm_last_batch,
},
losses=losses,
metrics=metrics,
# TODO: hardcoded metric key
throughput_metrics={
"training_synced_num_samples_per_second": synced_num_samples_per_second,
@@ -181,7 +192,8 @@ def train(

def _reset_loss_and_gradient_norm(self):
# TODO: we should handle the device assignment more centrally.
cumulated_loss_and_gradient_norm = torch.zeros(5)
# summed local losses, loss of last local batch, number of local batches (i.e., number of steps)
cumulated_loss_and_gradient_norm = torch.zeros(3)
if torch.cuda.is_available():
cumulated_loss_and_gradient_norm = cumulated_loss_and_gradient_norm.to(torch.device(self.local_rank))
else:

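The logging side of the change collapses the per-rank bookkeeping into a three-element tensor — summed local losses, loss of the last local batch, number of local batches — that is SUM-reduced across ranks and then post-processed into the two logged losses. Below is a single-process sketch of just that arithmetic, with the all-reduce simulated by summing hand-written per-rank tensors (all numbers invented for illustration).

import torch

world_size = 2
per_rank = [
    # [summed local losses, loss of last local batch, number of local batches]
    torch.tensor([10.0, 2.4, 4.0]),  # rank 0
    torch.tensor([9.0, 2.0, 4.0]),   # rank 1
]

# stands in for dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
reduced = torch.stack(per_rank).sum(dim=0)

post_processed = torch.stack([
    reduced[0] / reduced[-1],  # summed batch loss / (num batches * world size) -> interval average
    reduced[1] / world_size,   # summed last-batch losses / world size -> average last-batch loss
])
print(post_processed)  # tensor([2.3750, 2.2000])

The gradient norm no longer goes through this reducer at all: as the in-code comment notes, the value returned by FSDP's clip_grad_norm_ is already synced across ranks, so each rank simply collects it locally and logs the mean and the last value per logging interval.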