Skip to content

Commit

Permalink
[QEff. Finetune] : To plot correct loss values on tensorboard (#207)
Browse files Browse the repository at this point in the history
[QEff. Finetune] : Plotting the correct (scaled) loss on tensorboard for the training.
Removing this change for now: Plotting the loss for each sample of eval rather the average

Signed-off-by: Swati Allabadi <[email protected]>
Co-authored-by: Swati Allabadi <[email protected]>
  • Loading branch information
quic-swatia authored and quic-rishinr committed Jan 10, 2025
1 parent 7350c96 commit 74aa67c
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions QEfficient/finetune/utils/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ def train(

# enable profile for qaic
qaic_profile.start_profiling(device, 1) if train_config.use_profiler else None

for step, batch in enumerate(train_dataloader):
total_train_steps += 1
# stop when the maximum number of training steps is reached
Expand Down Expand Up @@ -146,16 +145,16 @@ def train(
else:
loss = model(**batch).loss # Forward call

total_loss += loss.detach().float()
# Accumalate graidents
loss = loss / train_config.gradient_accumulation_steps

if train_config.enable_ddp:
if local_rank == 0:
tensorboard_updates.add_scalars("loss", {"train": loss}, total_train_steps)
else:
tensorboard_updates.add_scalars("loss", {"train": loss}, total_train_steps)

total_loss += loss.detach().float()
# Accumalate graidents
loss = loss / train_config.gradient_accumulation_steps

if train_config.save_metrics:
train_step_loss.append(loss.detach().float().item())
train_step_perplexity.append(float(torch.exp(loss.detach().float())))
Expand Down Expand Up @@ -222,12 +221,13 @@ def train(
dist.barrier()
dist.all_reduce(eval_epoch_loss, op=dist.ReduceOp.SUM)
if local_rank == 0:
tensorboard_updates.add_scalars("loss", {"val": eval_epoch_loss}, total_train_steps)
tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps)

else:
eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(
model, train_config, eval_dataloader, local_rank, tokenizer, device
)
tensorboard_updates.add_scalars("loss", {"val": eval_epoch_loss}, total_train_steps)
tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps)

if train_config.save_metrics:
val_step_loss.extend(temp_val_loss)
Expand Down

0 comments on commit 74aa67c

Please sign in to comment.