From a92a363d70a08ea71502a50b4c2c5369f60aeeb9 Mon Sep 17 00:00:00 2001 From: dnaligase <60392997+dnaligase@users.noreply.github.com> Date: Tue, 15 Aug 2023 14:03:43 +0300 Subject: [PATCH] Update TimeGAN (#18) * Update test_timegan.py * Update timeGAN.py * make keras-like * fix linting --- tests/test_timegan.py | 64 +++++++++++++++++++++++++++++++----------- tsgm/models/timeGAN.py | 21 ++++++++++---- 2 files changed, 63 insertions(+), 22 deletions(-) diff --git a/tests/test_timegan.py b/tests/test_timegan.py index b60c183..7296936 100644 --- a/tests/test_timegan.py +++ b/tests/test_timegan.py @@ -6,6 +6,46 @@ from tensorflow import keras +def test_timegan(): + latent_dim = 24 + feature_dim = 6 + seq_len = 24 + batch_size = 2 + + dataset = _gen_dataset(batch_size, seq_len, feature_dim) + timegan = tsgm.models.timeGAN.TimeGAN( + seq_len=seq_len, module="gru", hidden_dim=latent_dim, n_features=feature_dim, n_layers=3, batch_size=batch_size + ) + timegan.compile() + timegan.fit(dataset, epochs=1) + + _check_internals(timegan) + + # Check generation + generated_samples = timegan.generate(1) + assert generated_samples.shape == (1, seq_len, feature_dim) + + +def test_timegan_on_dataset(): + latent_dim = 24 + feature_dim = 6 + seq_len = 24 + batch_size = 16 + + dataset = _gen_tf_dataset(batch_size, seq_len, feature_dim) # tf.data.Dataset + timegan = tsgm.models.timeGAN.TimeGAN( + seq_len=seq_len, module="gru", hidden_dim=latent_dim, n_features=feature_dim, n_layers=3, batch_size=batch_size + ) + timegan.compile() + timegan.fit(dataset, epochs=1) + + _check_internals(timegan) + + # Check generation + generated_samples = timegan.generate(1) + assert generated_samples.shape == (1, seq_len, feature_dim) + + def _gen_dataset(no, seq_len, dim): """Sine data generation. Args: @@ -42,18 +82,15 @@ def _gen_dataset(no, seq_len, dim): return data -def test_timegan(): - latent_dim = 24 - feature_dim = 6 - seq_len = 24 - batch_size = 2 +def _gen_tf_dataset(no, seq_len, dim): + dataset = _gen_dataset(no, seq_len, dim) + dataset = tf.convert_to_tensor(dataset, dtype=tf.float32) + dataset = tf.data.Dataset.from_tensors(dataset).unbatch().batch(no) + + return dataset - dataset = _gen_dataset(batch_size, seq_len, feature_dim) - timegan = tsgm.models.timeGAN.TimeGAN( - seq_len=seq_len, module="gru", hidden_dim=latent_dim, n_features=feature_dim, n_layers=3, batch_size=batch_size - ) - timegan.compile() - timegan.fit(dataset, epochs=1) + +def _check_internals(timegan): # Check internal nets assert timegan.generator is not None @@ -74,8 +111,3 @@ def test_timegan(): assert timegan.embedder_opt is not None assert timegan.autoencoder_opt is not None assert timegan.adversarialsup_opt is not None - - # Check generation - generated_samples = timegan.generate(1) - assert generated_samples.shape == (1, seq_len, feature_dim) - diff --git a/tsgm/models/timeGAN.py b/tsgm/models/timeGAN.py index b9de8e0..460e15d 100644 --- a/tsgm/models/timeGAN.py +++ b/tsgm/models/timeGAN.py @@ -47,7 +47,7 @@ def labels(self) -> list: return list(self.keys()) -class TimeGAN: +class TimeGAN(keras.Model): """ Time-series Generative Adversarial Networks (TimeGAN) @@ -68,6 +68,7 @@ def __init__( batch_size: int = 256, gamma: float = 1.0, ): + super().__init__() self.seq_len = seq_len self.hidden_dim = hidden_dim self.dim = n_features @@ -453,10 +454,12 @@ def _get_data_batch(self, data, n_windows: int) -> typing.Iterator: def fit( self, - data: TensorLike, + data: typing.Union[TensorLike, tf.data.Dataset], epochs: int, checkpoints_interval: typing.Optional[int] = None, generate_synthetic: tuple = (), + *args, + **kwargs, ): """ :param data: TensorLike, the training data @@ -479,6 +482,12 @@ def fit( self._mse is None or self._bce is None ), "One of the loss functions is not defined. Please call .compile() to set them" + # take tf.data.Dataset | TensorLike + if isinstance(data, tf.data.Dataset): + batches = iter(data.repeat()) + else: + batches = self._get_data_batch(data, n_windows=len(data)) + # Define the model self._define_timegan() @@ -486,7 +495,7 @@ def fit( print("Start Embedding Network Training") for epoch in tqdm(range(epochs), desc="Autoencoder - training"): - X_ = next(self._get_data_batch(data, n_windows=len(data))) + X_ = next(batches) step_e_loss_0 = self._train_autoencoder(X_, self.autoencoder_opt) # Checkpoint @@ -501,7 +510,7 @@ def fit( # Adversarial Supervised network training for epoch in tqdm(range(epochs), desc="Adversarial Supervised - training"): - X_ = next(self._get_data_batch(data, n_windows=len(data))) + X_ = next(batches) step_g_loss_s = self._train_supervisor(X_, self.adversarialsup_opt) # Checkpoint @@ -523,7 +532,7 @@ def fit( # Generator training (twice more than discriminator training) for kk in range(2): - X_ = next(self._get_data_batch(data, n_windows=len(data))) + X_ = next(batches) Z_ = next(self.get_noise_batch()) # -------------------------- # Train the generator @@ -541,7 +550,7 @@ def fit( # -------------------------- _, step_e_loss_t0 = self._train_embedder(X_, self.embedder_opt) - X_ = next(self._get_data_batch(data, n_windows=len(data))) + X_ = next(batches) Z_ = next(self.get_noise_batch()) step_d_loss = self._check_discriminator_loss(X_, Z_) if step_d_loss > 0.15: