From 4398afe4c765b7ebf7e9a3e5d9aef088c8274d58 Mon Sep 17 00:00:00 2001 From: Alexander Nikitin <1243786+AlexanderVNikitin@users.noreply.github.com> Date: Mon, 20 Nov 2023 10:43:29 +0200 Subject: [PATCH] add graph and eager execution for timegan tests --- tests/test_timegan.py | 94 ++++++++++++++++++++++++++++--------------- 1 file changed, 61 insertions(+), 33 deletions(-) diff --git a/tests/test_timegan.py b/tests/test_timegan.py index 6a8d179..166f3ac 100644 --- a/tests/test_timegan.py +++ b/tests/test_timegan.py @@ -23,12 +23,16 @@ def test_timegan(): batch_size=batch_size, ) timegan.compile() - timegan.fit(dataset, epochs=1) - _check_internals(timegan) + try: + tf.config.experimental_run_functions_eagerly(True) + timegan.fit(dataset, epochs=1) - # Check generation - generated_samples = timegan.generate(1) + _check_internals(timegan) + + generated_samples = timegan.generate(1) + finally: + tf.config.experimental_run_functions_eagerly(False) assert generated_samples.shape == (1, seq_len, feature_dim) @@ -48,12 +52,15 @@ def test_timegan_on_dataset(): batch_size=batch_size, ) timegan.compile() - timegan.fit(dataset, epochs=1) + try: + tf.config.experimental_run_functions_eagerly(True) + timegan.fit(dataset, epochs=1) - _check_internals(timegan) + _check_internals(timegan) - # Check generation - generated_samples = timegan.generate(1) + generated_samples = timegan.generate(1) + finally: + tf.config.experimental_run_functions_eagerly(False) assert generated_samples.shape == (1, seq_len, feature_dim) @@ -153,17 +160,21 @@ def test_train_timegan(mocked_gradienttape): batch_size=batch_size, ) timegan.compile() - timegan.fit(dataset, epochs=1) - batches = timegan._get_data_batch(dataset, n_windows=len(dataset)) - assert timegan._train_autoencoder(next(batches), timegan.autoencoder_opt) - assert timegan._train_supervisor(next(batches), timegan.adversarialsup_opt) - assert timegan._train_generator( - next(batches), next(timegan.get_noise_batch()), timegan.generator_opt - ) - assert timegan._train_embedder(next(batches), timegan.embedder_opt) - assert timegan._train_discriminator( - next(batches), next(timegan.get_noise_batch()), timegan.discriminator_opt - ) + try: + tf.config.experimental_run_functions_eagerly(True) + timegan.fit(dataset, epochs=1) + batches = timegan._get_data_batch(dataset, n_windows=len(dataset)) + assert timegan._train_autoencoder(next(batches), timegan.autoencoder_opt) + assert timegan._train_supervisor(next(batches), timegan.adversarialsup_opt) + assert timegan._train_generator( + next(batches), next(timegan.get_noise_batch()), timegan.generator_opt + ) + assert timegan._train_embedder(next(batches), timegan.embedder_opt) + assert timegan._train_discriminator( + next(batches), next(timegan.get_noise_batch()), timegan.discriminator_opt + ) + finally: + tf.config.experimental_run_functions_eagerly(False) @pytest.fixture @@ -204,7 +215,11 @@ def test_timegan_train_autoencoder(mocked_data, mocked_timegan): mocked_timegan._define_timegan() X_ = next(batches) - loss = mocked_timegan._train_autoencoder(X_, mocked_timegan.autoencoder_opt) + try: + tf.config.experimental_run_functions_eagerly(True) + loss = mocked_timegan._train_autoencoder(X_, mocked_timegan.autoencoder_opt) + finally: + tf.config.experimental_run_functions_eagerly(False) # Assert that the loss is a float assert loss.dtype in [tf.float32, tf.float64] @@ -215,8 +230,11 @@ def test_timegan_train_embedder(mocked_data, mocked_timegan): mocked_timegan._define_timegan() X_ = next(batches) - _, loss = mocked_timegan._train_embedder(X_, mocked_timegan.embedder_opt) - + try: + tf.config.experimental_run_functions_eagerly(True) + _, loss = mocked_timegan._train_embedder(X_, mocked_timegan.embedder_opt) + finally: + tf.config.experimental_run_functions_eagerly(False) # Assert that the loss is a float assert loss.dtype in [tf.float32, tf.float64] @@ -227,13 +245,17 @@ def test_timegan_train_generator(mocked_data, mocked_timegan): mocked_timegan._define_timegan() X_ = next(batches) Z_ = next(mocked_timegan.get_noise_batch()) - ( - step_g_loss_u, - step_g_loss_u_e, - step_g_loss_s, - step_g_loss_v, - step_g_loss, - ) = mocked_timegan._train_generator(X_, Z_, mocked_timegan.generator_opt) + try: + tf.config.experimental_run_functions_eagerly(True) + ( + step_g_loss_u, + step_g_loss_u_e, + step_g_loss_s, + step_g_loss_v, + step_g_loss, + ) = mocked_timegan._train_generator(X_, Z_, mocked_timegan.generator_opt) + finally: + tf.config.experimental_run_functions_eagerly(False) # Assert that the loss is a float for loss in ( @@ -252,7 +274,11 @@ def test_timegan_check_discriminator_loss(mocked_data, mocked_timegan): mocked_timegan._define_timegan() X_ = next(batches) Z_ = next(mocked_timegan.get_noise_batch()) - loss = mocked_timegan._check_discriminator_loss(X_, Z_) + try: + tf.config.experimental_run_functions_eagerly(True) + loss = mocked_timegan._check_discriminator_loss(X_, Z_) + finally: + tf.config.experimental_run_functions_eagerly(False) # Assert that the loss is a float assert loss.dtype in [tf.float32, tf.float64] @@ -260,11 +286,13 @@ def test_timegan_check_discriminator_loss(mocked_data, mocked_timegan): def test_timegan_train_discriminator(mocked_data, mocked_timegan): batches = iter(mocked_data.repeat()) - mocked_timegan._define_timegan() X_ = next(batches) Z_ = next(mocked_timegan.get_noise_batch()) - loss = mocked_timegan._train_discriminator(X_, Z_, mocked_timegan.discriminator_opt) - + try: + tf.config.experimental_run_functions_eagerly(True) + loss = mocked_timegan._train_discriminator(X_, Z_, mocked_timegan.discriminator_opt) + finally: + tf.config.experimental_run_functions_eagerly(False) # Assert that the loss is a float assert loss.dtype in [tf.float32, tf.float64]