-
Notifications
You must be signed in to change notification settings - Fork 9
/
config.py
94 lines (90 loc) · 4.5 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import sys
import os
import argparse
import numpy as np
# Default hyper-parameters. Each constant is the default value of the
# matching command-line option registered on `parser` below.
BATCH_SIZE = 64            # images sent to the network in one step
CKPT = 0                   # checkpoint to restore (--ckpt)
COMPRESS_RATIO = 0.5       # compression factor in DenseNet-C
DATA_DIRECTORY = '/storage/ndong/data/cifar/cifar100'
DATA_NAME = 'cifar100'
DROP_RATE = 0              # dropout rate in DenseNet units
GROWTH_RATE = 32           # DenseNet growth rate k (an integer)
# NOTE(review): IGNORE_LABEL looks carried over from a segmentation config;
# confirm it is actually used for CIFAR classification.
IGNORE_LABEL = 255         # label index ignored during training
INPUT_SIZE = 224           # height and width of input images
IS_BOTTLENECK = 'True'     # kept as a string; the option is declared type=str
IS_TRAINING = 'True'       # string; converted to bool after parse_args()
LEARNING_RATE = 1e-2       # base learning rate for polynomial decay
MOMENTUM = 0.9             # optimiser momentum
# CIFAR-100 has 100 classes. The previous default of 10 did not match
# DATA_NAME / DATA_DIRECTORY above.
NUM_CLASSES = 100
NUM_GPUS = 1
NUM_LAYERS = 121           # DenseNet depth
NUM_STEPS = 600000         # number of training steps
POWER = 0.9                # polynomial learning-rate decay power
RANDOM_SEED = 1234
RESTORE_FROM = None        # no checkpoint path to restore from by default
SAVE_NUM_IMAGES = 1
SAVE_PRED_EVERY = 1000     # steps between summary/checkpoint saves
SNAPSHOT_DIR = './' + DATA_NAME  # e.g. './cifar100'
SPLIT_NAME = 'train'
WEIGHT_DECAY = 1e-4        # L2 regularisation strength
# Command-line interface. Every option defaults to the matching module-level
# constant defined earlier in this file; arguments are parsed at import time
# and exposed as the module attribute `args`.
parser = argparse.ArgumentParser(description="DenseNet")
parser.add_argument("--batch-size", type=int, default=BATCH_SIZE,
                    help="Number of images sent to the network in one step.")
parser.add_argument("--compress-ratio", type=float, default=COMPRESS_RATIO,
                    help="Compression factor in DenseNet-C.")
parser.add_argument("--ckpt", type=int, default=CKPT,
                    help="Checkpoint to restore.")
parser.add_argument("--data-dir", type=str, default=DATA_DIRECTORY,
                    help="Path to the directory containing the dataset.")
parser.add_argument("--data-name", type=str, default=DATA_NAME,
                    help="Name of the dataset.")
parser.add_argument("--drop-rate", type=float, default=DROP_RATE,
                    help="Dropout rate in DenseNet unit.")
parser.add_argument("--freeze-bn", action="store_true",
                    help="Whether to freeze batch norm params.")
# The growth rate is an integer count (the default GROWTH_RATE is an int),
# so parse CLI values as int rather than float to keep the type consistent.
parser.add_argument("--growth-rate", type=int, default=GROWTH_RATE,
                    help="Growth rate k: number of feature maps each DenseNet layer adds.")
parser.add_argument("--ignore-label", type=int, default=IGNORE_LABEL,
                    help="The index of the label to ignore during the training.")
parser.add_argument("--input-size", type=int, default=INPUT_SIZE,
                    help="Height and width of images.")
# NOTE(review): unlike --is-training, this option is NOT converted to bool
# below, so args.is_bottleneck stays a str and the non-empty string 'False'
# is truthy. Check how consumers read it before converting it here.
parser.add_argument("--is-bottleneck", type=str, default=IS_BOTTLENECK,
                    help="Whether to use bottleneck layers in DenseNet-B ('True' or 'False').")
parser.add_argument("--is-training", type=str, default=IS_TRAINING,
                    help="Whether to update the running means and variances during training ('True' or 'False').")
parser.add_argument("--learning-rate", type=float, default=LEARNING_RATE,
                    help="Base learning rate for training with polynomial decay.")
parser.add_argument("--momentum", type=float, default=MOMENTUM,
                    help="Momentum component of the optimiser.")
parser.add_argument("--not-restore-last", action="store_true",
                    help="Whether to not restore last (FC) layers.")
parser.add_argument("--num-classes", type=int, default=NUM_CLASSES,
                    help="Number of classes to predict.")
parser.add_argument("--num-gpus", type=int, default=NUM_GPUS,
                    help="Number of GPUs to use.")
parser.add_argument("--num-layers", type=int, default=NUM_LAYERS,
                    help="Number of layers in DenseNet.")
parser.add_argument("--num-steps", type=int, default=NUM_STEPS,
                    help="Number of training steps.")
parser.add_argument("--power", type=float, default=POWER,
                    help="Decay parameter to compute the learning rate.")
parser.add_argument("--random-mirror", action="store_true",
                    help="Whether to randomly mirror the inputs during the training.")
parser.add_argument("--random-scale", action="store_true",
                    help="Whether to randomly scale the inputs during the training.")
parser.add_argument("--random-seed", type=int, default=RANDOM_SEED,
                    help="Random seed to have reproducible results.")
parser.add_argument("--restore-from", type=str, default=RESTORE_FROM,
                    help="Where to restore model parameters from.")
parser.add_argument("--save-num-images", type=int, default=SAVE_NUM_IMAGES,
                    help="How many images to save.")
parser.add_argument("--save-pred-every", type=int, default=SAVE_PRED_EVERY,
                    help="Save summaries and checkpoint every this many steps.")
parser.add_argument("--snapshot-dir", type=str, default=SNAPSHOT_DIR,
                    help="Where to save snapshots of the model.")
parser.add_argument("--split-name", type=str, default=SPLIT_NAME,
                    help="Split name.")
parser.add_argument("--weight-decay", type=float, default=WEIGHT_DECAY,
                    help="Regularisation parameter for L2-loss.")
args = parser.parse_args()
# --is-training arrives as a string; turn it into a bool. Matching is
# case-insensitive so values like "true", "1" or "yes" are not silently
# treated as False (previously only the exact string 'True' counted).
args.is_training = args.is_training.strip().lower() in ('true', 't', 'yes', 'y', '1')