diff --git a/QEfficient/finetune/utils/train_utils.py b/QEfficient/finetune/utils/train_utils.py index 44bea984..1289934e 100644 --- a/QEfficient/finetune/utils/train_utils.py +++ b/QEfficient/finetune/utils/train_utils.py @@ -115,7 +115,6 @@ def train( # enable profile for qaic qaic_profile.start_profiling(device, 1) if train_config.use_profiler else None - for step, batch in enumerate(train_dataloader): total_train_steps += 1 # stop when the maximum number of training steps is reached @@ -146,16 +145,16 @@ def train( else: loss = model(**batch).loss # Forward call + total_loss += loss.detach().float() + # Accumulate gradients + loss = loss / train_config.gradient_accumulation_steps + if train_config.enable_ddp: if local_rank == 0: tensorboard_updates.add_scalars("loss", {"train": loss}, total_train_steps) else: tensorboard_updates.add_scalars("loss", {"train": loss}, total_train_steps) - total_loss += loss.detach().float() - # Accumalate graidents - loss = loss / train_config.gradient_accumulation_steps - if train_config.save_metrics: train_step_loss.append(loss.detach().float().item()) train_step_perplexity.append(float(torch.exp(loss.detach().float()))) @@ -222,12 +221,13 @@ def train( dist.barrier() dist.all_reduce(eval_epoch_loss, op=dist.ReduceOp.SUM) if local_rank == 0: - tensorboard_updates.add_scalars("loss", {"val": eval_epoch_loss}, total_train_steps) + tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps) + else: eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation( model, train_config, eval_dataloader, local_rank, tokenizer, device ) - tensorboard_updates.add_scalars("loss", {"val": eval_epoch_loss}, total_train_steps) + tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps) if train_config.save_metrics: val_step_loss.extend(temp_val_loss)