Skip to content

Commit

Permalink
add ddpm
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderVNikitin committed Aug 16, 2024
1 parent e487aca commit c2884d7
Show file tree
Hide file tree
Showing 4 changed files with 509 additions and 0 deletions.
37 changes: 37 additions & 0 deletions tests/test_ddpm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import pytest
import tsgm

import tensorflow as tf
import numpy as np
from tensorflow import keras


def test_ddpm():
seq_len = 12
feat_dim = 1

model_type = tsgm.models.architectures.zoo["ddpm_denoiser"]
architecture = model_type(seq_len=seq_len, feat_dim=feat_dim)

denoiser_model = architecture.model

X = tsgm.utils.gen_sine_dataset(50, seq_len, feat_dim, max_value=20)

scaler = tsgm.utils.TSFeatureWiseScaler((0, 1))
X = scaler.fit_transform(X).astype(np.float64)

ddpm_model = tsgm.models.ddpm.DDPM(denoiser_model, model_type(seq_len=seq_len, feat_dim=feat_dim).model, 1000)
ddpm_model.compile(
loss=keras.losses.MeanSquaredError(),
optimizer=keras.optimizers.Adam(0.0003))

with pytest.raises(ValueError):
ddpm_model.generate(7)

ddpm_model.fit(X, epochs=1, batch_size=128)

x_samples = ddpm_model.generate(7)
assert x_samples.shape == (7, seq_len, feat_dim)

x_decoded = ddpm_model(3)
assert x_decoded.shape == (3, seq_len, feat_dim)
1 change: 1 addition & 0 deletions tsgm/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@
import tsgm.models.monitors
import tsgm.models.sts
import tsgm.models.timeGAN
import tsgm.models.ddpm

from tsgm.models.architectures import zoo
130 changes: 130 additions & 0 deletions tsgm/models/architectures/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -964,6 +964,135 @@ def _build_discriminator(self):
return keras.Model(inputs, x)


class TimeEmbedding(layers.Layer):
def __init__(self, dim: int, **kwargs) -> None:
super().__init__(**kwargs)
self.dim = dim
self.half_dim = dim // 2
self.emb = math.log(10000) / (self.half_dim - 1)
self.emb = tf.exp(tf.range(self.half_dim, dtype=tf.float32) * -self.emb)

def call(self, inputs: tsgm.types.Tensor) -> tsgm.types.Tensor:
inputs = tf.cast(inputs, dtype=tf.float32)
emb = inputs[:, None] * self.emb[None, :]
emb = tf.concat([tf.sin(emb), tf.cos(emb)], axis=-1)

return emb


class BaseDenoisingArchitecture(Architecture):
"""
Base class for denoising architectures in DDPM (Denoising Diffusion Probabilistic Models, `tsgm.models.ddpm`).
Attributes:
arch_type: A string indicating the type of architecture, set to "ddpm:denoising".
_seq_len: The length of the input sequences.
_feat_dim: The dimensionality of the input features.
_n_filters: The number of filters used in the convolutional layers.
_n_conv_layers: The number of convolutional layers in the model.
_model: The Keras model instance built using the `_build_model` method.
"""

arch_type = "ddpm:denoising"

def __init__(self, seq_len: int, feat_dim: int, n_filters: int = 64, n_conv_layers: int = 3, **kwargs) -> None:
"""
Initializes the BaseDenoisingArchitecture with the specified parameters.
Args:
seq_len (int): The length of the input sequences.
feat_dim (int): The dimensionality of the input features.
n_filters (int, optional): The number of filters for convolutional layers. Default is 64.
n_conv_layers (int, optional): The number of convolutional layers. Default is 3.
**kwargs: Additional keyword arguments to be passed to the parent class `Architecture`.
"""
self._seq_len = seq_len
self._feat_dim = feat_dim
self._n_filters = n_filters
self._n_conv_layers = n_conv_layers
self._model = self._build_model()

@property
def model(self) -> keras.models.Model:
"""
Provides access to the Keras model instance.
Returns:
keras.models.Model: The Keras model instance built by `_build_model`.
"""
return self._model

def get(self) -> T.Dict:
"""
Returns a dictionary containing the model.
:returns: A dictionary containing the model.
:rtype: dict
"""
return {"model": self.model}

def _build_model(self) -> None:
"""
Abstract method for building the Keras model.
Subclasses must implement this method to define the specific architecture of the model.
Raises:
NotImplementedError: If the method is not overridden by a subclass.
"""
raise NotImplementedError


class DDPMConvDenoiser(BaseDenoisingArchitecture):
"""
A convolutional denoising model for DDPM.
This class defines a convolutional neural network architecture used as a denoiser in DDPM.
It predicts the noise added to the input samples during the diffusion process.
Attributes:
arch_type: A string indicating the architecture type, set to "ddpm:denoiser".
"""
arch_type = "ddpm:denoiser"

def __init__(self, **kwargs):
"""
Initializes the DDPMConvDenoiser model with additional parameters.
Args:
**kwargs: Additional keyword arguments to be passed to the parent class.
"""
super().__init__(**kwargs)

def _build_model(self) -> keras.Model:
"""
Constructs and returns the Keras model for the DDPM denoiser.
The model consists of:
- A 1D convolutional layer to process input features.
- An additional input layer for time embedding to incorporate timestep information.
- `n_conv_layers` convolutional layers to process the combined features and time embeddings.
- A final convolutional layer to output the predicted noise.
Returns:
keras.Model: The Keras model instance for the DDPM denoiser.
"""
inputs = keras.Input(shape=(self._seq_len, self._feat_dim))

# Input for the additional float parameter
time_input = keras.Input(shape=(1,))

temb = TimeEmbedding(dim=self._seq_len)(time_input)
temb = keras.layers.Reshape((temb.shape[-1], 1))(temb)

x = layers.Concatenate()([inputs, temb])

for l_id in range(self._n_conv_layers):
x = layers.Conv1D(self._n_filters, 3, padding="same", activation="relu")(x)

outputs = layers.Conv1D(self._feat_dim, 3, padding="same")(x)
return keras.Model(inputs=[inputs, time_input], outputs=outputs)


class Zoo(dict):
"""
A collection of architectures represented. It behaves like supports Python `dict` API.
Expand Down Expand Up @@ -995,6 +1124,7 @@ def summary(self) -> None:
"cgan_lstm_n": cGAN_LSTMnArchitecture,
"cgan_lstm_3": cGAN_LSTMConv3Architecture,
"wavegan": WaveGANArchitecture,
"ddpm_denoiser": DDPMConvDenoiser,

# Downstream models
"clf_cn": ConvnArchitecture,
Expand Down
Loading

0 comments on commit c2884d7

Please sign in to comment.