forked from phuccuongngo99/Fence_GAN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
29 lines (23 loc) · 1.61 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import argparse
from fgan_train import training_pipeline
def build_parser():
    """Build the command-line parser for Fence GAN training.

    Returns:
        argparse.ArgumentParser: parser whose ``parse_args()`` result is the
        namespace handed to ``training_pipeline``.
    """
    # The first positional parameter of ArgumentParser is ``prog`` (the
    # program name shown in the usage line), so the title must be passed
    # explicitly as ``description``.
    parser = argparse.ArgumentParser(description='Train your Fence GAN')

    # Training hyperparameters
    parser.add_argument('--dataset', type=str, default='mnist',
                        choices=['mnist', 'cifar10'], help='mnist | cifar10')
    parser.add_argument('--ano_class', type=int, default=2,
                        help='1 anomaly class')
    parser.add_argument('--epochs', type=int, default=100,
                        help='number of epochs to train')

    # FenceGAN hyperparameters.
    # argparse only applies ``type`` to string defaults, so float options use
    # float literals to keep the parsed type identical whether or not the
    # option is given on the command line.
    parser.add_argument('--beta', type=float, default=30.0, help='beta')
    parser.add_argument('--gamma', type=float, default=0.1, help='gamma')
    parser.add_argument('--alpha', type=float, default=0.5, help='alpha')

    # Other hyperparameters
    parser.add_argument('--batch_size', type=int, default=200,
                        help='mini-batch size')
    parser.add_argument('--pretrain', type=int, default=15,
                        help='number of pretrain epoch')
    parser.add_argument('--d_l2', type=float, default=0.0,
                        help='L2 Regularizer for Discriminator')
    parser.add_argument('--d_lr', type=float, default=1e-5,
                        help='learning_rate of discriminator')
    parser.add_argument('--g_lr', type=float, default=2e-5,
                        help='learning rate of generator')
    parser.add_argument('--v_freq', type=int, default=4,
                        help='epoch frequency to evaluate performance')
    parser.add_argument('--seed', type=int, default=0,
                        help='numpy and tensorflow seed')
    parser.add_argument('--evaluation', type=str, default='auprc',
                        choices=['auprc', 'auroc'],
                        help="'auprc' or 'auroc'")
    parser.add_argument('--latent_dim', type=int, default=200,
                        help='Latent dimension of Gaussian noise input to Generator')
    return parser


def main(argv=None):
    """Parse command-line arguments and run the Fence GAN training pipeline.

    Args:
        argv: list of argument strings to parse; ``None`` means
            ``sys.argv[1:]`` (argparse's default behaviour).
    """
    args = build_parser().parse_args(argv)
    # ``training_pipeline`` is imported from fgan_train at the top of the file.
    training_pipeline(args)


# Guard so that importing this module does not parse sys.argv or start a
# training run as a side effect.
if __name__ == '__main__':
    main()