Skip to content

Commit

Permalink
Code update
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 642330504
  • Loading branch information
zhong1wan authored and The swirl_dynamics Authors committed Jun 11, 2024
1 parent a2fa0a0 commit 774636f
Showing 1 changed file with 0 additions and 19 deletions.
19 changes: 0 additions & 19 deletions swirl_dynamics/templates/callbacks_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,25 +185,6 @@ def test_reports_monitors(self):
)


class ParameterOverviewTest(absltest.TestCase):

def test_logs_parameter_overview(self):
work_dir = self.create_tempdir().full_path
callback = callbacks.ParameterOverview()
callback.metric_writer = metric_writers.create_default_writer(work_dir)
trainer = mock.Mock(spec=trainers.BaseTrainer)
trainer.unreplicated_train_state = train_states.BasicTrainState(
step=jnp.array(0),
params={"bias": jnp.ones((10,)), "weights": jnp.ones((10, 10))},
opt_state={},
)
buffer = io.StringIO()
logging.use_python_logging()
logging.get_absl_handler().python_handler.stream = buffer
callback.on_train_begin(trainer)
self.assertRegex(buffer.getvalue(), r"bias.*(10,)")
self.assertRegex(buffer.getvalue(), r"weights.*(10, 10)")


if __name__ == "__main__":
absltest.main()

0 comments on commit 774636f

Please sign in to comment.