diff --git a/examples/cifar10/utils_cifar.py b/examples/cifar10/utils_cifar.py index 5152367..1eec02e 100644 --- a/examples/cifar10/utils_cifar.py +++ b/examples/cifar10/utils_cifar.py @@ -11,7 +11,19 @@ def generate_samples(model, parallel, savedir, step, net_="normal"): - """Generate 64 images (8 x 8) for sanity check along training.""" + """Save 64 generated images (8 x 8) for sanity check along training. + + Parameters + ---------- + model: + represents the neural network that we want to generate samples from + parallel: bool + represents the parallel training flag. Torchdyn only runs on 1 GPU, we need to send the models from several GPUs to 1 GPU. + savedir: str + represents the path where we want to save the generated images + step: int + represents the current step of training + """ model.eval() model_ = copy.deepcopy(model)