From c2884d7e9d63acae75edc1790c52bc80e6b376d8 Mon Sep 17 00:00:00 2001
From: Alexander Nikitin <1243786+AlexanderVNikitin@users.noreply.github.com>
Date: Fri, 16 Aug 2024 16:06:14 +0300
Subject: [PATCH] add ddpm

---
 tests/test_ddpm.py               |  37 ++++
 tsgm/models/__init__.py          |   1 +
 tsgm/models/architectures/zoo.py | 130 ++++++++++++
 tsgm/models/ddpm.py              | 341 +++++++++++++++++++++++++++++++
 4 files changed, 509 insertions(+)
 create mode 100644 tests/test_ddpm.py
 create mode 100644 tsgm/models/ddpm.py

diff --git a/tests/test_ddpm.py b/tests/test_ddpm.py
new file mode 100644
index 0000000..cfe19af
--- /dev/null
+++ b/tests/test_ddpm.py
@@ -0,0 +1,37 @@
+import pytest
+import tsgm
+
+import tensorflow as tf
+import numpy as np
+from tensorflow import keras
+
+
+def test_ddpm():
+    seq_len = 12
+    feat_dim = 1
+
+    model_type = tsgm.models.architectures.zoo["ddpm_denoiser"]
+    architecture = model_type(seq_len=seq_len, feat_dim=feat_dim)
+
+    denoiser_model = architecture.model
+
+    X = tsgm.utils.gen_sine_dataset(50, seq_len, feat_dim, max_value=20)
+
+    scaler = tsgm.utils.TSFeatureWiseScaler((0, 1))
+    X = scaler.fit_transform(X).astype(np.float64)
+
+    ddpm_model = tsgm.models.ddpm.DDPM(denoiser_model, model_type(seq_len=seq_len, feat_dim=feat_dim).model, 1000)
+    ddpm_model.compile(
+        loss=keras.losses.MeanSquaredError(),
+        optimizer=keras.optimizers.Adam(0.0003))
+
+    with pytest.raises(ValueError):
+        ddpm_model.generate(7)
+
+    ddpm_model.fit(X, epochs=1, batch_size=128)
+
+    x_samples = ddpm_model.generate(7)
+    assert x_samples.shape == (7, seq_len, feat_dim)
+
+    x_decoded = ddpm_model(3)
+    assert x_decoded.shape == (3, seq_len, feat_dim)
diff --git a/tsgm/models/__init__.py b/tsgm/models/__init__.py
index 0a164e3..596dce2 100644
--- a/tsgm/models/__init__.py
+++ b/tsgm/models/__init__.py
@@ -5,5 +5,6 @@ import tsgm.models.monitors
 import tsgm.models.sts
 import tsgm.models.timeGAN
+import tsgm.models.ddpm
 
 from tsgm.models.architectures import zoo
diff --git a/tsgm/models/architectures/zoo.py b/tsgm/models/architectures/zoo.py
index 87e4b21..7716424 100644
--- a/tsgm/models/architectures/zoo.py
+++ b/tsgm/models/architectures/zoo.py
@@ -964,6 +964,135 @@ def _build_discriminator(self):
         return keras.Model(inputs, x)
 
 
+class TimeEmbedding(layers.Layer):
+    def __init__(self, dim: int, **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.dim = dim
+        self.half_dim = dim // 2
+        self.emb = math.log(10000) / (self.half_dim - 1)
+        self.emb = tf.exp(tf.range(self.half_dim, dtype=tf.float32) * -self.emb)
+
+    def call(self, inputs: tsgm.types.Tensor) -> tsgm.types.Tensor:
+        inputs = tf.cast(inputs, dtype=tf.float32)
+        emb = inputs[:, None] * self.emb[None, :]
+        emb = tf.concat([tf.sin(emb), tf.cos(emb)], axis=-1)
+
+        return emb
+
+
+class BaseDenoisingArchitecture(Architecture):
+    """
+    Base class for denoising architectures in DDPM (Denoising Diffusion Probabilistic Models, `tsgm.models.ddpm`).
+
+    Attributes:
+        arch_type: A string indicating the type of architecture, set to "ddpm:denoising".
+        _seq_len: The length of the input sequences.
+        _feat_dim: The dimensionality of the input features.
+        _n_filters: The number of filters used in the convolutional layers.
+        _n_conv_layers: The number of convolutional layers in the model.
+        _model: The Keras model instance built using the `_build_model` method.
+    """
+
+    arch_type = "ddpm:denoising"
+
+    def __init__(self, seq_len: int, feat_dim: int, n_filters: int = 64, n_conv_layers: int = 3, **kwargs) -> None:
+        """
+        Initializes the BaseDenoisingArchitecture with the specified parameters.
+
+        Args:
+            seq_len (int): The length of the input sequences.
+            feat_dim (int): The dimensionality of the input features.
+            n_filters (int, optional): The number of filters for convolutional layers. Default is 64.
+            n_conv_layers (int, optional): The number of convolutional layers. Default is 3.
+            **kwargs: Additional keyword arguments to be passed to the parent class `Architecture`.
+        """
+        self._seq_len = seq_len
+        self._feat_dim = feat_dim
+        self._n_filters = n_filters
+        self._n_conv_layers = n_conv_layers
+        self._model = self._build_model()
+
+    @property
+    def model(self) -> keras.models.Model:
+        """
+        Provides access to the Keras model instance.
+
+        Returns:
+            keras.models.Model: The Keras model instance built by `_build_model`.
+        """
+        return self._model
+
+    def get(self) -> T.Dict:
+        """
+        Returns a dictionary containing the model.
+
+        Returns:
+            dict: A dictionary containing the model.
+        """
+        return {"model": self.model}
+
+    def _build_model(self) -> keras.Model:
+        """
+        Abstract method for building the Keras model.
+        Subclasses must implement this method to define the specific architecture of the model.
+
+        Raises:
+            NotImplementedError: If the method is not overridden by a subclass.
+        """
+        raise NotImplementedError
+
+
+class DDPMConvDenoiser(BaseDenoisingArchitecture):
+    """
+    A convolutional denoising model for DDPM.
+
+    This class defines a convolutional neural network architecture used as a denoiser in DDPM.
+    It predicts the noise added to the input samples during the diffusion process.
+
+    Attributes:
+        arch_type: A string indicating the architecture type, set to "ddpm:denoiser".
+    """
+    arch_type = "ddpm:denoiser"
+
+    def __init__(self, **kwargs):
+        """
+        Initializes the DDPMConvDenoiser model with additional parameters.
+
+        Args:
+            **kwargs: Additional keyword arguments to be passed to the parent class.
+        """
+        super().__init__(**kwargs)
+
+    def _build_model(self) -> keras.Model:
+        """
+        Constructs and returns the Keras model for the DDPM denoiser.
+
+        The model consists of:
+        - A 1D convolutional layer to process input features.
+        - An additional input layer for time embedding to incorporate timestep information.
+        - `n_conv_layers` convolutional layers to process the combined features and time embeddings.
+        - A final convolutional layer to output the predicted noise.
+
+        Returns:
+            keras.Model: The Keras model instance for the DDPM denoiser.
+        """
+        inputs = keras.Input(shape=(self._seq_len, self._feat_dim))
+
+        # Input for the diffusion timestep
+        time_input = keras.Input(shape=(1,))
+
+        temb = TimeEmbedding(dim=self._seq_len)(time_input)
+        temb = layers.Reshape((temb.shape[-1], 1))(temb)
+
+        x = layers.Concatenate()([inputs, temb])
+
+        for _ in range(self._n_conv_layers):
+            x = layers.Conv1D(self._n_filters, 3, padding="same", activation="relu")(x)
+
+        outputs = layers.Conv1D(self._feat_dim, 3, padding="same")(x)
+        return keras.Model(inputs=[inputs, time_input], outputs=outputs)
+
+
 class Zoo(dict):
     """
     A collection of architectures. It supports the Python `dict` API.
@@ -995,6 +1124,7 @@ def summary(self) -> None:
     "cgan_lstm_n": cGAN_LSTMnArchitecture,
     "cgan_lstm_3": cGAN_LSTMConv3Architecture,
     "wavegan": WaveGANArchitecture,
+    "ddpm_denoiser": DDPMConvDenoiser,
 
     # Downstream models
     "clf_cn": ConvnArchitecture,
diff --git a/tsgm/models/ddpm.py b/tsgm/models/ddpm.py
new file mode 100644
index 0000000..e007c5a
--- /dev/null
+++ b/tsgm/models/ddpm.py
@@ -0,0 +1,341 @@
+"""
+This implementation is based on the Keras DDPM example: https://keras.io/examples/generative/ddpm/
+"""
+import numpy as np
+
+from tensorflow import keras
+import tensorflow as tf
+from tensorflow.python.types.core import TensorLike
+
+import typing as T
+
+
+class GaussianDiffusion:
+    """Gaussian diffusion utility for generating samples using a diffusion process.
+
+    This class implements a Gaussian diffusion process, where a sample is gradually
+    perturbed by adding Gaussian noise over a series of timesteps. It also includes
+    methods to reverse the diffusion process, predicting the original data from
+    the noisy samples.
+
+    Args:
+        beta_start (float): Start value of the scheduled variance for the diffusion process.
+        beta_end (float): End value of the scheduled variance for the diffusion process.
+        timesteps (int): Number of timesteps in the forward process.
+    """
+
+    def __init__(
+        self,
+        beta_start: float = 1e-4,
+        beta_end: float = 0.02,
+        timesteps: int = 1000,
+    ) -> None:
+        self.beta_start = beta_start
+        self.beta_end = beta_end
+        self.timesteps = timesteps
+
+        # Define the linear variance schedule
+        betas = np.linspace(
+            beta_start,
+            beta_end,
+            timesteps,
+            dtype=np.float64,  # Using float64 for better precision
+        )
+        self.num_timesteps = int(timesteps)
+
+        alphas = 1.0 - betas
+        alphas_cumprod = np.cumprod(alphas, axis=0)
+        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
+
+        self.betas = tf.constant(betas, dtype=tf.float32)
+        self.alphas_cumprod = tf.constant(alphas_cumprod, dtype=tf.float32)
+        self.alphas_cumprod_prev = tf.constant(alphas_cumprod_prev, dtype=tf.float32)
+
+        # Calculations for diffusion q(x_t | x_{t-1}) and others
+        self.sqrt_alphas_cumprod = tf.constant(
+            np.sqrt(alphas_cumprod), dtype=tf.float32
+        )
+
+        self.sqrt_one_minus_alphas_cumprod = tf.constant(
+            np.sqrt(1.0 - alphas_cumprod), dtype=tf.float32
+        )
+
+        self.log_one_minus_alphas_cumprod = tf.constant(
+            np.log(1.0 - alphas_cumprod), dtype=tf.float32
+        )
+
+        self.sqrt_recip_alphas_cumprod = tf.constant(
+            np.sqrt(1.0 / alphas_cumprod), dtype=tf.float32
+        )
+        self.sqrt_recipm1_alphas_cumprod = tf.constant(
+            np.sqrt(1.0 / alphas_cumprod - 1), dtype=tf.float32
+        )
+
+        # Calculations for posterior q(x_{t-1} | x_t, x_0)
+        posterior_variance = (
+            betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
+        )
+        self.posterior_variance = tf.constant(posterior_variance, dtype=tf.float32)
+
+        # Log calculation clipped because the posterior variance is 0 at the beginning
+        # of the diffusion chain
+        self.posterior_log_variance_clipped = tf.constant(
+            np.log(np.maximum(posterior_variance, 1e-20)), dtype=tf.float32
+        )
+
+        self.posterior_mean_coef1 = tf.constant(
+            betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod),
+            dtype=tf.float32,
+        )
+
+        self.posterior_mean_coef2 = tf.constant(
+            (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod),
+            dtype=tf.float32,
+        )
+
+    def _extract(self, a: TensorLike, t: TensorLike, x_shape: tf.TensorShape) -> TensorLike:
+        """
+        Extracts coefficients for a specific timestep and reshapes them for broadcasting.
+
+        Args:
+            a: Tensor to extract from.
+            t: Timestep for which the coefficients are to be extracted.
+            x_shape: Shape of the current batched samples.
+
+        Returns:
+            Tensor reshaped to [batch_size, 1, 1] for broadcasting.
+        """
+        batch_size = x_shape[0]
+        out = tf.gather(a, t)
+        return tf.reshape(out, [batch_size, 1, 1])
+
+    def q_mean_variance(self, x_start: TensorLike, t: TensorLike) -> T.Tuple:
+        """Computes the mean and variance at a specific timestep in the forward diffusion process.
+
+        Args:
+            x_start: Initial sample (before the first diffusion step).
+            t: A timestep.
+
+        Returns:
+            mean, variance, log_variance: Tensors representing the mean, variance,
+            and log variance of the distribution at `t`.
+        """
+        x_start_shape = tf.shape(x_start)
+        mean = self._extract(self.sqrt_alphas_cumprod, t, x_start_shape) * x_start
+        variance = self._extract(1.0 - self.alphas_cumprod, t, x_start_shape)
+        log_variance = self._extract(
+            self.log_one_minus_alphas_cumprod, t, x_start_shape
+        )
+        return mean, variance, log_variance
+
+    def q_sample(self, x_start: TensorLike, t: TensorLike, noise: TensorLike) -> TensorLike:
+        """Performs the forward diffusion step by adding Gaussian noise to the sample.
+
+        Args:
+            x_start: Initial sample (before the first diffusion step)
+            t: Current timestep
+            noise: Gaussian noise to be added at timestep `t`
+
+        Returns:
+            Diffused samples at timestep `t`
+        """
+        x_start_shape = tf.shape(x_start)
+        return (
+            self._extract(self.sqrt_alphas_cumprod, t, x_start_shape) * x_start
+            + self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_start_shape)
+            * noise
+        )
+
+    def predict_start_from_noise(self, x_t: TensorLike, t: TensorLike, noise: TensorLike) -> TensorLike:
+        """Predicts the initial sample from the noisy sample at timestep `t`.
+
+        Args:
+            x_t: Noisy sample at timestep `t`.
+            t: Current timestep.
+            noise: Gaussian noise added at timestep `t`.
+
+        Returns:
+            Predicted initial sample.
+        """
+
+        x_t_shape = tf.shape(x_t)
+        return (
+            self._extract(self.sqrt_recip_alphas_cumprod, t, x_t_shape) * x_t
+            - self._extract(self.sqrt_recipm1_alphas_cumprod, t, x_t_shape) * noise
+        )
+
+    def q_posterior(self, x_start: TensorLike, x_t: TensorLike, t: TensorLike) -> T.Tuple:
+        """Computes the mean and variance of the posterior distribution q(x_{t-1} | x_t, x_0).
+
+        Args:
+            x_start: Initial sample (x_0) for the posterior computation.
+            x_t: Sample at timestep `t`.
+            t: Current timestep.
+
+        Returns:
+            Posterior mean, variance, and clipped log variance at the current timestep.
+        """
+
+        x_t_shape = tf.shape(x_t)
+        posterior_mean = (
+            self._extract(self.posterior_mean_coef1, t, x_t_shape) * x_start
+            + self._extract(self.posterior_mean_coef2, t, x_t_shape) * x_t
+        )
+        posterior_variance = self._extract(self.posterior_variance, t, x_t_shape)
+        posterior_log_variance_clipped = self._extract(
+            self.posterior_log_variance_clipped, t, x_t_shape
+        )
+        return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+    def p_mean_variance(self, pred_noise: TensorLike, x: TensorLike, t: TensorLike) -> T.Tuple:
+        """Predicts the mean and variance for the reverse diffusion step.
+
+        Args:
+            pred_noise: Noise predicted by the diffusion model.
+            x: Samples at a given timestep for which the noise was predicted.
+            t: Current timestep.
+
+        Returns:
+            model_mean, posterior_variance, posterior_log_variance: Tensors
+            representing the mean and variance of the model at the current timestep.
+        """
+        x_recon = self.predict_start_from_noise(x, t=t, noise=pred_noise)
+
+        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
+            x_start=x_recon, x_t=x, t=t
+        )
+        return model_mean, posterior_variance, posterior_log_variance
+
+    def p_sample(self, pred_noise: TensorLike, x: TensorLike, t: TensorLike) -> TensorLike:
+        """Generates a sample from the diffusion model by reversing the diffusion process.
+
+        Args:
+            pred_noise: Noise predicted by the diffusion model.
+            x: Samples at a given timestep for which the noise was predicted.
+            t: Current timestep.
+
+        Returns:
+            Sample generated by reversing the diffusion process at timestep `t`.
+        """
+        model_mean, _, model_log_variance = self.p_mean_variance(
+            pred_noise, x=x, t=t
+        )
+        noise = tf.random.normal(shape=x.shape, dtype=x.dtype)
+        # No noise when t == 0
+        nonzero_mask = tf.reshape(
+            1 - tf.cast(tf.equal(t, 0), tf.float32), [tf.shape(x)[0], 1, 1]
+        )
+        return model_mean + nonzero_mask * tf.exp(0.5 * model_log_variance) * noise
+
+
+class DDPM(keras.Model):
+    """
+    Denoising Diffusion Probabilistic Model
+
+    Args:
+        network (keras.Model): A Keras model that predicts the noise added to the input samples.
+        ema_network (keras.Model): An exponential moving average (EMA) copy of `network`, used for sampling.
+        timesteps (int): The number of timesteps in the diffusion process.
+        ema (float): The decay factor for the EMA, default is 0.999.
+    """
+    def __init__(self, network: keras.Model, ema_network: keras.Model, timesteps: int, ema: float = 0.999) -> None:
+        super().__init__()
+        self.network = network
+        self.ema_network = ema_network
+        self.timesteps = timesteps
+        self.gdf_util = GaussianDiffusion(timesteps=timesteps)
+        self.ema = ema
+
+        self.ema_network.set_weights(network.get_weights())  # Initially the weights are the same
+
+        # Filled in during training
+        self.seq_len = None
+        self.feat_dim = None
+
+    def train_step(self, images: TensorLike) -> T.Dict:
+        """
+        Performs a single training step on a batch of samples.
+
+        Args:
+            images: A batch of time-series samples (the name follows the Keras DDPM example).
+
+        Returns:
+            A dictionary containing the loss value for the training step.
+        """
+        self.seq_len, self.feat_dim = images.shape[1], images.shape[2]
+
+        # 1. Get the batch size
+        batch_size = tf.shape(images)[0]
+
+        # 2. Sample timesteps uniformly
+        t = tf.random.uniform(
+            minval=0, maxval=self.timesteps, shape=(batch_size,), dtype=tf.int64
+        )
+
+        with tf.GradientTape() as tape:
+            # 3. Sample random noise to be added to the samples in the batch
+            noise = tf.random.normal(shape=tf.shape(images), dtype=images.dtype)
+
+            # 4. Diffuse the samples with noise
+            images_t = self.gdf_util.q_sample(images, t, noise)
+
+            # 5. Pass the diffused samples and time steps to the network
+            pred_noise = self.network([images_t, t], training=True)
+
+            # 6. Calculate the loss
+            loss = self.loss(noise, pred_noise)
+
+        # 7. Get the gradients
+        gradients = tape.gradient(loss, self.network.trainable_weights)
+
+        # 8. Update the weights of the network
+        self.optimizer.apply_gradients(zip(gradients, self.network.trainable_weights))
+
+        # 9. Update the EMA network weights with an exponential moving average
+        for weight, ema_weight in zip(self.network.weights, self.ema_network.weights):
+            ema_weight.assign(self.ema * ema_weight + (1 - self.ema) * weight)
+
+        # 10. Return loss values
+        return {"loss": loss}
+
+    def generate(self, n_samples: int = 16) -> TensorLike:
+        """
+        Generates new samples by running the reverse diffusion process.
+
+        Args:
+            n_samples: The number of samples to generate.
+
+        Returns:
+            Generated samples after running the reverse diffusion process.
+        """
+
+        if self.seq_len is None or self.feat_dim is None:
+            raise ValueError("DDPM is not trained; call `fit` before generating samples")
+
+        # 1. Randomly sample noise (starting point for reverse process)
+        samples = tf.random.normal(
+            shape=(n_samples, self.seq_len, self.feat_dim), dtype=tf.float32
+        )
+        # 2. Sample from the model iteratively
+        for t in reversed(range(0, self.timesteps)):
+            tt = tf.cast(tf.fill([n_samples], t), dtype=tf.int64)
+            pred_noise = self.ema_network.predict(
+                [samples, tt], verbose=0, batch_size=n_samples
+            )
+            samples = self.gdf_util.p_sample(
+                pred_noise, samples, tt
+            )
+        # 3. Return generated samples
+        return samples
+
+    def call(self, n_samples: int) -> TensorLike:
+        """
+        Calls the generate method to produce samples.
+
+        Args:
+            n_samples: The number of samples to generate.
+
+        Returns:
+            Generated samples.
+        """
+        return self.generate(n_samples)
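
Reviewer note, not part of the patch: `q_sample` implements the closed-form forward step x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, which is why `sqrt_alphas_cumprod` and `sqrt_one_minus_alphas_cumprod` are precomputed in `GaussianDiffusion.__init__`. Below is a minimal end-to-end usage sketch of the new API, mirroring `tests/test_ddpm.py`; the sine dataset, single epoch, and learning rate are illustrative only and taken from the test.

    from tensorflow import keras

    import tsgm

    seq_len, feat_dim = 12, 1

    # Two copies of the registered denoiser: one trainable, one updated via EMA.
    arch = tsgm.models.architectures.zoo["ddpm_denoiser"]
    network = arch(seq_len=seq_len, feat_dim=feat_dim).model
    ema_network = arch(seq_len=seq_len, feat_dim=feat_dim).model

    # Toy data scaled to [0, 1]; DDPM infers seq_len / feat_dim during fit().
    X = tsgm.utils.gen_sine_dataset(50, seq_len, feat_dim, max_value=20)
    X = tsgm.utils.TSFeatureWiseScaler((0, 1)).fit_transform(X)

    model = tsgm.models.ddpm.DDPM(network, ema_network, timesteps=1000)
    model.compile(loss=keras.losses.MeanSquaredError(),
                  optimizer=keras.optimizers.Adam(0.0003))
    model.fit(X, epochs=1, batch_size=128)  # use many more epochs in practice

    samples = model.generate(7)  # tensor of shape (7, seq_len, feat_dim)

Sampling runs the reverse process for all `timesteps` steps through the EMA network, so `generate` is much slower than a forward pass; this matches the Keras DDPM example the module is based on.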