Skip to content

Commit

Permalink
Code update
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 623619200
  • Loading branch information
zhong1wan authored and The swirl_dynamics Authors committed Apr 10, 2024
1 parent 2f5ee6b commit c1bfc9d
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
5 changes: 3 additions & 2 deletions swirl_dynamics/templates/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,9 @@ class ParameterOverview(Callback):
log_to_tb: bool = True

def on_train_begin(self, trainer: trainers.BaseTrainer) -> None:
if isinstance(trainer.train_state, train_states.BasicTrainState):
params = trainer.train_state.params
train_state = trainer.unreplicated_train_state
if isinstance(train_state, train_states.BasicTrainState):
params = train_state.params
if self.log_to_info:
logging.info("Logging parameter overview.")
parameter_overview.log_parameter_overview(params)
Expand Down
2 changes: 1 addition & 1 deletion swirl_dynamics/templates/callbacks_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def test_logs_parameter_overview(self):
callback = callbacks.ParameterOverview()
callback.metric_writer = metric_writers.create_default_writer(work_dir)
trainer = mock.Mock(spec=trainers.BaseTrainer)
trainer.train_state = train_states.BasicTrainState(
trainer.unreplicated_train_state = train_states.BasicTrainState(
step=jnp.array(0),
params={"bias": jnp.ones((10,)), "weights": jnp.ones((10, 10))},
opt_state={},
Expand Down

0 comments on commit c1bfc9d

Please sign in to comment.