diff --git a/swirl_dynamics/templates/callbacks_test.py b/swirl_dynamics/templates/callbacks_test.py index 5b5a6b6..7c965ac 100644 --- a/swirl_dynamics/templates/callbacks_test.py +++ b/swirl_dynamics/templates/callbacks_test.py @@ -185,25 +185,6 @@ def test_reports_monitors(self): ) -class ParameterOverviewTest(absltest.TestCase): - - def test_logs_parameter_overview(self): - work_dir = self.create_tempdir().full_path - callback = callbacks.ParameterOverview() - callback.metric_writer = metric_writers.create_default_writer(work_dir) - trainer = mock.Mock(spec=trainers.BaseTrainer) - trainer.unreplicated_train_state = train_states.BasicTrainState( - step=jnp.array(0), - params={"bias": jnp.ones((10,)), "weights": jnp.ones((10, 10))}, - opt_state={}, - ) - buffer = io.StringIO() - logging.use_python_logging() - logging.get_absl_handler().python_handler.stream = buffer - callback.on_train_begin(trainer) - self.assertRegex(buffer.getvalue(), r"bias.*(10,)") - self.assertRegex(buffer.getvalue(), r"weights.*(10, 10)") - if __name__ == "__main__": absltest.main()