refactor: now looping over multiple loggers, if any are supplied
This still only functions for wandb/tensorboard, but now supports multiple loggers.
laserkelvin committed Jul 1, 2024
1 parent af745c0 · commit 515e4b8
Showing 1 changed file with 17 additions and 15 deletions.
matsciml/lightning/callbacks.py
@@ -1056,21 +1056,23 @@ def encoder_head_comparison(
             " encoder median norm: {encoder_median:.3e},"
             " output head: {output_median:.3e}"
         )
-        # optionally record to service as well
-        if log_history and pl_module.logger is not None:
-            log_service = pl_module.logger.experiment
-            encoder_norm_vals = torch.from_numpy(encoder_norm_vals).float()
-            output_norm_vals = torch.from_numpy(output_norm_vals).float()
-            if isinstance(log_service, pl_loggers.TensorBoardLogger):
-                log_service.add_histogram(
-                    "encoder_weight_norm", encoder_norm_vals, global_step
-                )
-                log_service.add_histogram(
-                    "outputhead_weight_norm", output_norm_vals, global_step
-                )
-            elif isinstance(log_service, pl_loggers.WandbLogger):
-                log_service.log({"encoder_weight_norm": encoder_norm_vals})
-                log_service.log({"outputhead_weight_norm": output_norm_vals})
+        # optionally record to a supported service as well
+        # this nominally should work for multiple loggers
+        if log_history and len(pl_module.loggers) > 0:
+            for pl_logger in pl_module.loggers:
+                log_service = pl_logger.experiment
+                encoder_norm_vals = torch.from_numpy(encoder_norm_vals).float()
+                output_norm_vals = torch.from_numpy(output_norm_vals).float()
+                if isinstance(log_service, pl_loggers.TensorBoardLogger):
+                    log_service.add_histogram(
+                        "encoder_weight_norm", encoder_norm_vals, global_step
+                    )
+                    log_service.add_histogram(
+                        "outputhead_weight_norm", output_norm_vals, global_step
+                    )
+                elif isinstance(log_service, pl_loggers.WandbLogger):
+                    log_service.log({"encoder_weight_norm": encoder_norm_vals})
+                    log_service.log({"outputhead_weight_norm": output_norm_vals})
 
     def on_before_optimizer_step(
         self,
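For context, pl_module.loggers is the list PyTorch Lightning populates when one or more loggers are passed to the Trainer; the new loop dispatches the weight-norm histograms to each logger's experiment object in turn. Below is a minimal sketch of a multi-logger setup that would exercise this path; the save directory, project name, and the commented-out fit call are illustrative assumptions, not taken from the commit.

import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers

# Two loggers attached to the same Trainer; Lightning exposes them to the
# LightningModule as pl_module.loggers, which the callback loop above iterates.
tb_logger = pl_loggers.TensorBoardLogger(save_dir="lightning_logs")
wandb_logger = pl_loggers.WandbLogger(project="matsciml-debug")  # project name is hypothetical

trainer = pl.Trainer(
    logger=[tb_logger, wandb_logger],
    max_steps=100,
)
# trainer.fit(task, datamodule=dm)  # task/datamodule setup omitted; depends on the matsciml workflow

Per the commit message, only the TensorBoard and wandb branches are handled; any other logger in the list is simply skipped by the isinstance dispatch, since there is no fallback else branch.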
