Skip to content

Commit

Permalink
Update TimeGAN (#18)
Browse files Browse the repository at this point in the history
* Update test_timegan.py

* Update timeGAN.py

* make keras-like

* fix linting
  • Loading branch information
dnaligase authored Aug 15, 2023
1 parent a4a0a27 commit a92a363
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 22 deletions.
64 changes: 48 additions & 16 deletions tests/test_timegan.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,46 @@
from tensorflow import keras


def test_timegan():
latent_dim = 24
feature_dim = 6
seq_len = 24
batch_size = 2

dataset = _gen_dataset(batch_size, seq_len, feature_dim)
timegan = tsgm.models.timeGAN.TimeGAN(
seq_len=seq_len, module="gru", hidden_dim=latent_dim, n_features=feature_dim, n_layers=3, batch_size=batch_size
)
timegan.compile()
timegan.fit(dataset, epochs=1)

_check_internals(timegan)

# Check generation
generated_samples = timegan.generate(1)
assert generated_samples.shape == (1, seq_len, feature_dim)


def test_timegan_on_dataset():
latent_dim = 24
feature_dim = 6
seq_len = 24
batch_size = 16

dataset = _gen_tf_dataset(batch_size, seq_len, feature_dim) # tf.data.Dataset
timegan = tsgm.models.timeGAN.TimeGAN(
seq_len=seq_len, module="gru", hidden_dim=latent_dim, n_features=feature_dim, n_layers=3, batch_size=batch_size
)
timegan.compile()
timegan.fit(dataset, epochs=1)

_check_internals(timegan)

# Check generation
generated_samples = timegan.generate(1)
assert generated_samples.shape == (1, seq_len, feature_dim)


def _gen_dataset(no, seq_len, dim):
"""Sine data generation.
Args:
Expand Down Expand Up @@ -42,18 +82,15 @@ def _gen_dataset(no, seq_len, dim):
return data


def test_timegan():
latent_dim = 24
feature_dim = 6
seq_len = 24
batch_size = 2
def _gen_tf_dataset(no, seq_len, dim):
dataset = _gen_dataset(no, seq_len, dim)
dataset = tf.convert_to_tensor(dataset, dtype=tf.float32)
dataset = tf.data.Dataset.from_tensors(dataset).unbatch().batch(no)

return dataset

dataset = _gen_dataset(batch_size, seq_len, feature_dim)
timegan = tsgm.models.timeGAN.TimeGAN(
seq_len=seq_len, module="gru", hidden_dim=latent_dim, n_features=feature_dim, n_layers=3, batch_size=batch_size
)
timegan.compile()
timegan.fit(dataset, epochs=1)

def _check_internals(timegan):

# Check internal nets
assert timegan.generator is not None
Expand All @@ -74,8 +111,3 @@ def test_timegan():
assert timegan.embedder_opt is not None
assert timegan.autoencoder_opt is not None
assert timegan.adversarialsup_opt is not None

# Check generation
generated_samples = timegan.generate(1)
assert generated_samples.shape == (1, seq_len, feature_dim)

21 changes: 15 additions & 6 deletions tsgm/models/timeGAN.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def labels(self) -> list:
return list(self.keys())


class TimeGAN:
class TimeGAN(keras.Model):
"""
Time-series Generative Adversarial Networks (TimeGAN)
Expand All @@ -68,6 +68,7 @@ def __init__(
batch_size: int = 256,
gamma: float = 1.0,
):
super().__init__()
self.seq_len = seq_len
self.hidden_dim = hidden_dim
self.dim = n_features
Expand Down Expand Up @@ -453,10 +454,12 @@ def _get_data_batch(self, data, n_windows: int) -> typing.Iterator:

def fit(
self,
data: TensorLike,
data: typing.Union[TensorLike, tf.data.Dataset],
epochs: int,
checkpoints_interval: typing.Optional[int] = None,
generate_synthetic: tuple = (),
*args,
**kwargs,
):
"""
:param data: TensorLike, the training data
Expand All @@ -479,14 +482,20 @@ def fit(
self._mse is None or self._bce is None
), "One of the loss functions is not defined. Please call .compile() to set them"

# take tf.data.Dataset | TensorLike
if isinstance(data, tf.data.Dataset):
batches = iter(data.repeat())
else:
batches = self._get_data_batch(data, n_windows=len(data))

# Define the model
self._define_timegan()

# 1. Embedding network training
print("Start Embedding Network Training")

for epoch in tqdm(range(epochs), desc="Autoencoder - training"):
X_ = next(self._get_data_batch(data, n_windows=len(data)))
X_ = next(batches)
step_e_loss_0 = self._train_autoencoder(X_, self.autoencoder_opt)

# Checkpoint
Expand All @@ -501,7 +510,7 @@ def fit(

# Adversarial Supervised network training
for epoch in tqdm(range(epochs), desc="Adversarial Supervised - training"):
X_ = next(self._get_data_batch(data, n_windows=len(data)))
X_ = next(batches)
step_g_loss_s = self._train_supervisor(X_, self.adversarialsup_opt)

# Checkpoint
Expand All @@ -523,7 +532,7 @@ def fit(

# Generator training (twice more than discriminator training)
for kk in range(2):
X_ = next(self._get_data_batch(data, n_windows=len(data)))
X_ = next(batches)
Z_ = next(self.get_noise_batch())
# --------------------------
# Train the generator
Expand All @@ -541,7 +550,7 @@ def fit(
# --------------------------
_, step_e_loss_t0 = self._train_embedder(X_, self.embedder_opt)

X_ = next(self._get_data_batch(data, n_windows=len(data)))
X_ = next(batches)
Z_ = next(self.get_noise_batch())
step_d_loss = self._check_discriminator_loss(X_, Z_)
if step_d_loss > 0.15:
Expand Down

0 comments on commit a92a363

Please sign in to comment.