Skip to content

Commit

Permalink
add graph and eager execution for timegan tests
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderVNikitin committed Nov 20, 2023
1 parent fa80378 commit 4398afe
Showing 1 changed file with 61 additions and 33 deletions.
94 changes: 61 additions & 33 deletions tests/test_timegan.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,16 @@ def test_timegan():
batch_size=batch_size,
)
timegan.compile()
timegan.fit(dataset, epochs=1)

_check_internals(timegan)
try:
tf.config.experimental_run_functions_eagerly(True)
timegan.fit(dataset, epochs=1)

# Check generation
generated_samples = timegan.generate(1)
_check_internals(timegan)

generated_samples = timegan.generate(1)
finally:
tf.config.experimental_run_functions_eagerly(False)
assert generated_samples.shape == (1, seq_len, feature_dim)


Expand All @@ -48,12 +52,15 @@ def test_timegan_on_dataset():
batch_size=batch_size,
)
timegan.compile()
timegan.fit(dataset, epochs=1)
try:
tf.config.experimental_run_functions_eagerly(True)
timegan.fit(dataset, epochs=1)

_check_internals(timegan)
_check_internals(timegan)

# Check generation
generated_samples = timegan.generate(1)
generated_samples = timegan.generate(1)
finally:
tf.config.experimental_run_functions_eagerly(False)
assert generated_samples.shape == (1, seq_len, feature_dim)


Expand Down Expand Up @@ -153,17 +160,21 @@ def test_train_timegan(mocked_gradienttape):
batch_size=batch_size,
)
timegan.compile()
timegan.fit(dataset, epochs=1)
batches = timegan._get_data_batch(dataset, n_windows=len(dataset))
assert timegan._train_autoencoder(next(batches), timegan.autoencoder_opt)
assert timegan._train_supervisor(next(batches), timegan.adversarialsup_opt)
assert timegan._train_generator(
next(batches), next(timegan.get_noise_batch()), timegan.generator_opt
)
assert timegan._train_embedder(next(batches), timegan.embedder_opt)
assert timegan._train_discriminator(
next(batches), next(timegan.get_noise_batch()), timegan.discriminator_opt
)
try:
tf.config.experimental_run_functions_eagerly(True)
timegan.fit(dataset, epochs=1)
batches = timegan._get_data_batch(dataset, n_windows=len(dataset))
assert timegan._train_autoencoder(next(batches), timegan.autoencoder_opt)
assert timegan._train_supervisor(next(batches), timegan.adversarialsup_opt)
assert timegan._train_generator(
next(batches), next(timegan.get_noise_batch()), timegan.generator_opt
)
assert timegan._train_embedder(next(batches), timegan.embedder_opt)
assert timegan._train_discriminator(
next(batches), next(timegan.get_noise_batch()), timegan.discriminator_opt
)
finally:
tf.config.experimental_run_functions_eagerly(False)


@pytest.fixture
Expand Down Expand Up @@ -204,7 +215,11 @@ def test_timegan_train_autoencoder(mocked_data, mocked_timegan):

mocked_timegan._define_timegan()
X_ = next(batches)
loss = mocked_timegan._train_autoencoder(X_, mocked_timegan.autoencoder_opt)
try:
tf.config.experimental_run_functions_eagerly(True)
loss = mocked_timegan._train_autoencoder(X_, mocked_timegan.autoencoder_opt)
finally:
tf.config.experimental_run_functions_eagerly(False)

# Assert that the loss is a float
assert loss.dtype in [tf.float32, tf.float64]
Expand All @@ -215,8 +230,11 @@ def test_timegan_train_embedder(mocked_data, mocked_timegan):

mocked_timegan._define_timegan()
X_ = next(batches)
_, loss = mocked_timegan._train_embedder(X_, mocked_timegan.embedder_opt)

try:
tf.config.experimental_run_functions_eagerly(True)
_, loss = mocked_timegan._train_embedder(X_, mocked_timegan.embedder_opt)
finally:
tf.config.experimental_run_functions_eagerly(False)
# Assert that the loss is a float
assert loss.dtype in [tf.float32, tf.float64]

Expand All @@ -227,13 +245,17 @@ def test_timegan_train_generator(mocked_data, mocked_timegan):
mocked_timegan._define_timegan()
X_ = next(batches)
Z_ = next(mocked_timegan.get_noise_batch())
(
step_g_loss_u,
step_g_loss_u_e,
step_g_loss_s,
step_g_loss_v,
step_g_loss,
) = mocked_timegan._train_generator(X_, Z_, mocked_timegan.generator_opt)
try:
tf.config.experimental_run_functions_eagerly(True)
(
step_g_loss_u,
step_g_loss_u_e,
step_g_loss_s,
step_g_loss_v,
step_g_loss,
) = mocked_timegan._train_generator(X_, Z_, mocked_timegan.generator_opt)
finally:
tf.config.experimental_run_functions_eagerly(False)

# Assert that the loss is a float
for loss in (
Expand All @@ -252,19 +274,25 @@ def test_timegan_check_discriminator_loss(mocked_data, mocked_timegan):
mocked_timegan._define_timegan()
X_ = next(batches)
Z_ = next(mocked_timegan.get_noise_batch())
loss = mocked_timegan._check_discriminator_loss(X_, Z_)
try:
tf.config.experimental_run_functions_eagerly(True)
loss = mocked_timegan._check_discriminator_loss(X_, Z_)
finally:
tf.config.experimental_run_functions_eagerly(False)

# Assert that the loss is a float
assert loss.dtype in [tf.float32, tf.float64]


def test_timegan_train_discriminator(mocked_data, mocked_timegan):
batches = iter(mocked_data.repeat())

mocked_timegan._define_timegan()
X_ = next(batches)
Z_ = next(mocked_timegan.get_noise_batch())
loss = mocked_timegan._train_discriminator(X_, Z_, mocked_timegan.discriminator_opt)

try:
tf.config.experimental_run_functions_eagerly(True)
loss = mocked_timegan._train_discriminator(X_, Z_, mocked_timegan.discriminator_opt)
finally:
tf.config.experimental_run_functions_eagerly(False)
# Assert that the loss is a float
assert loss.dtype in [tf.float32, tf.float64]

0 comments on commit 4398afe

Please sign in to comment.