MMD_GAN.jl
MMD_GAN.jl is a Julia module that implements a Maximum Mean Discrepancy (MMD) Generative Adversarial Network (GAN). It trains GAN models using MMD as the measure of discrepancy between the generated and the real data distributions, and it is designed for easy experimentation with different hyperparameters and model architectures.
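Here is what the MMD quantity looks like in practice. For samples x_1..x_n and y_1..y_m and a kernel k, the biased estimate of squared MMD is the mean of k(x_i, x_j), plus the mean of k(y_i, y_j), minus twice the mean of k(x_i, y_j). It is close to zero when both samples come from the same distribution and grows as the distributions move apart. The sketch below computes this estimate for 1-D samples with a Gaussian kernel of bandwidth σ. The names `gaussian_kernel` and `mmd2` and the fixed Gaussian kernel are illustrative assumptions. They are not part of MMD_GAN.jl's API.

```julia
using Random

# Gaussian (RBF) kernel on scalars with bandwidth σ (illustrative choice).
gaussian_kernel(a, b; σ = 1.0) = exp(-(a - b)^2 / (2σ^2))

# Biased (V-statistic) estimate of squared MMD between two 1-D samples.
function mmd2(xs::AbstractVector, ys::AbstractVector; σ = 1.0)
    k(a, b) = gaussian_kernel(a, b; σ = σ)
    kxx = sum(k(a, b) for a in xs, b in xs) / length(xs)^2
    kyy = sum(k(a, b) for a in ys, b in ys) / length(ys)^2
    kxy = sum(k(a, b) for a in xs, b in ys) / (length(xs) * length(ys))
    return kxx + kyy - 2kxy
end

# Usage: samples from the same distribution give a small value,
# samples from distant distributions give a large one.
rng = MersenneTwister(0)
ys = 23 .+ randn(rng, 200)          # "real" data, roughly N(23, 1)
xs_close = 23 .+ randn(rng, 200)    # drawn from the same distribution
xs_far = randn(rng, 200)            # roughly N(0, 1), far from the data
println(mmd2(ys, xs_close), " < ", mmd2(ys, xs_far))
```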
To use MMD_GAN.jl, clone this repository onto your local machine. Make sure Julia is installed and set up on your system.
git clone git@github.com:josemanuel22/MMD_GAN.jl.git
The module provides the core functionality for defining hyperparameters, setting up models, and training an MMD GAN. Key components include:
HyperParamsMMD: a structure that defines the hyperparameters of the MMD GAN.
train_mmd_gan: a function that trains the MMD GAN from the specified encoder, decoder, and generator models and the given hyperparameters.
using MMD_GAN
using Distributions  # provides Normal, used for target_model and noise_model below
# Define your models (encoder, decoder, generator)
# enc = ...
# dec = ...
# gen = ...
# Define hyperparameters
hparams = HyperParamsMMD(
target_model = Normal(23.0f0, 1.0f0),
noise_model = Normal(0.0f0, 1.0f0),
# Other hyperparameters...
)
# Train the model
losses_gen, losses_dscr = train_mmd_gan(enc, dec, gen, hparams)
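The following toy sketch shows how minimizing MMD can pull generated samples toward the data, using the same setup as the example (noise N(0, 1), target N(23, 1)). It is not the package's training loop. It leaves out the encoder and decoder that `train_mmd_gan` takes, it uses a fixed mixture of Gaussian kernels, and its generator has a single parameter, a shift μ added to the noise (x = z + μ). It draws samples with the standard-library `randn` instead of `Distributions.Normal` so that it stays self-contained. The names `SIGMA2S`, `mmd2_grad` and `fit_shift`, the bandwidths, and the learning rate are all hypothetical choices made for this illustration.

```julia
using Random

# Squared bandwidths of a fixed Gaussian-kernel mixture (an assumption of this
# sketch): the wide kernel keeps the gradient informative while the samples
# are far apart, and the narrow one sharpens the fit near the optimum.
const SIGMA2S = (1.0, 25.0, 400.0)

# Gradient of the biased squared-MMD estimate with respect to μ for the
# generator x = z + μ. Only the cross term -2 mean k(x, y) depends on μ,
# because k(x, x') depends only on z - z'.
function mmd2_grad(μ, zs, ys)
    g = 0.0
    for z in zs, y in ys
        u = z + μ - y
        for s in SIGMA2S
            g += exp(-u^2 / (2s)) * u / s   # d/dx of -k(x, y) is k * u / s
        end
    end
    return 2g / (length(zs) * length(ys))
end

# Fit μ by plain gradient descent, with noise and data samples drawn once.
function fit_shift(; target_mean = 23.0, n = 200, steps = 500, lr = 1.0, seed = 1)
    rng = MersenneTwister(seed)
    zs = randn(rng, n)                   # noise, as with noise_model = Normal(0, 1)
    ys = target_mean .+ randn(rng, n)    # data, as with target_model = Normal(23, 1)
    μ = 0.0                              # the generator's single parameter
    for _ in 1:steps
        μ -= lr * mmd2_grad(μ, zs, ys)
    end
    return μ
end

mu_hat = fit_shift()
println("learned shift: ", mu_hat)
```

With these settings the learned shift should end up close to the target mean of 23. The hand-derived gradient only works because the generator has a single parameter. A real generator network would instead need automatic differentiation, for example through Flux.jl.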
Contributions to MMD_GAN.jl are welcome. Please read our contribution guidelines for more details.
This project is licensed under the MIT License.