diff --git a/tests/test_TFEngine.py b/tests/test_TFEngine.py index ebf8ae2d9f..72f27257bf 100644 --- a/tests/test_TFEngine.py +++ b/tests/test_TFEngine.py @@ -5050,6 +5050,7 @@ def test_grad_summaries(): print("extern data:", engine.config.typed_value("extern_data")) engine.init_train_from_config() + engine.init_train_epoch() def extra_fetches_cb(summary_proto): """ @@ -5071,7 +5072,9 @@ def extra_fetches_cb(summary_proto): max_seqs=100, used_data_keys=engine.network.used_data_keys, ) - forwarder = Runner( + + engine.updater.set_learning_rate(engine.learning_rate, session=engine.tf_session) + trainer = Runner( engine=engine, dataset=train_data, batches=batches, @@ -5082,8 +5085,8 @@ def extra_fetches_cb(summary_proto): }, extra_fetches_callback=extra_fetches_cb, ) - forwarder.run(report_prefix="test_grad_summaries") - if not forwarder.finalized: + trainer.run(report_prefix="test_grad_summaries") + if not trainer.finalized: raise Exception("Error happened. Exit now.")