diff --git a/src/eval/logger.py b/src/eval/logger.py
index b34cbd3..411b187 100644
--- a/src/eval/logger.py
+++ b/src/eval/logger.py
@@ -60,17 +60,13 @@ def log_LR(self, model: torch.nn.Module, schedulers: list, step: int):
         :param schedulers: List of schedulers corresponding to the optimizers
         :param step: Current step number
         """
-        lr_info = {}
-        scheduler_info = {}
         for i, (optimizer, scheduler) in enumerate(zip(model.optimizers, schedulers)):
             for j, param_group in enumerate(optimizer.param_groups):
                 lr = param_group["lr"]
-                name = param_group.get("name", f"optimizer_{i}_group_{j}")
-                lr_info[f"Learning Rate/{name}"] = lr
-                scheduler_info[f"Scheduler Type/{name}"] = scheduler.__class__.__name__
-
-        wandb.log(lr_info, step=step)
-        wandb.log(scheduler_info, step=step)
+                param_name = param_group.get("name", f"optimizer_{i}_group_{j}")
+                s_type = f"Sch:{scheduler.__class__.__name__}"
+                l_name = s_type + f" LR: {param_name}"
+                wandb.log({l_name: lr}, step=step)  # BUG: failed to show in wandb

     def log_gradients(self, model: torch.nn.Module, step: int):
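
Note on the `# BUG: failed to show in wandb` marker: the diff does not say why the metrics are missing, but one thing worth checking is whether `step` is monotonically increasing across calls, since `wandb.log` ignores data logged with a step lower than one already recorded. Below is a minimal sketch of an alternative that keeps the pre-change structure of collecting all learning rates into one dict and calling `wandb.log` once per step. It is written as a hypothetical free function (`log_learning_rates` is not a name from this repo), and it assumes `wandb.init()` has already been called and that the optimizers and schedulers are parallel lists, matching how `log_LR` zips `model.optimizers` with `schedulers`.

```python
import wandb


def log_learning_rates(optimizers, schedulers, step: int) -> None:
    """Hypothetical free-function variant of log_LR (sketch only).

    Assumes wandb.init() has already been called and that `optimizers`
    and `schedulers` are parallel lists.
    """
    lr_info = {}
    for i, (optimizer, scheduler) in enumerate(zip(optimizers, schedulers)):
        sched_name = scheduler.__class__.__name__
        for j, param_group in enumerate(optimizer.param_groups):
            group_name = param_group.get("name", f"optimizer_{i}_group_{j}")
            # A "/" in the key groups related charts into one panel section in the W&B UI.
            lr_info[f"Learning Rate/{sched_name}/{group_name}"] = param_group["lr"]

    # Single call per step; W&B drops data whose `step` is lower than one already
    # logged, so the caller must pass a monotonically increasing step.
    wandb.log(lr_info, step=step)
```

Folding the scheduler class name into the `/`-separated key also removes the need for the separate `Scheduler Type/...` log call that the diff deletes, while keeping the grouping behavior of the original `Learning Rate/{name}` keys.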