diff --git a/tests/reference_algorithm_tests.py b/tests/reference_algorithm_tests.py index 74c06e180..938c4fa11 100644 --- a/tests/reference_algorithm_tests.py +++ b/tests/reference_algorithm_tests.py @@ -471,6 +471,7 @@ def _test_submission(workload_name, batch=batch, loss_type=workload.loss_type, optimizer_state=optimizer_state, + train_state={}, eval_results=[], global_step=global_step, rng=update_rng)