-
Notifications
You must be signed in to change notification settings - Fork 8
/
arguments.py
99 lines (77 loc) · 3.17 KB
/
arguments.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import argparse
import os
import torch
import numpy as np
import torch
import random
import re
import yaml
import shutil
import warnings
from datetime import datetime
class Namespace:
    """Expose the entries of a config dictionary as attributes.

    Nested dictionaries are wrapped recursively, so ``cfg['a']['b']`` is
    reachable as ``ns.a.b``. Values of any other type (including lists, even
    lists that contain dicts) are stored unchanged.
    """

    def __init__(self, somedict):
        for name, entry in somedict.items():
            # Keys must be strings whose FIRST character is a letter, underscore
            # or hyphen; re.match only anchors at the start of the key.
            assert isinstance(name, str) and re.match("[A-Za-z_-]", name)
            # Write straight into the instance dict so keys never go through
            # attribute-setting machinery.
            vars(self)[name] = Namespace(entry) if isinstance(entry, dict) else entry

    def __getattr__(self, attribute):
        # Python only calls this after normal lookup fails, i.e. when the key
        # was absent from the config; point the user at the config file.
        raise AttributeError(f"Can not find {attribute} in namespace. Please write {attribute} in your config file(xxx.yaml)!")
def set_deterministic(seed):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Integer seed. ``None`` (the intended default) makes this a no-op:
            no RNG is seeded and the cuDNN flags are left as they are.
    """
    if seed is None:
        return
    print(f"Deterministic with seed = {seed}")
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for seeder in seeders:
        seeder(seed)
    # Ask cuDNN for deterministic algorithms and turn off benchmark-mode
    # autotuning, which may choose different kernels from run to run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def _apply_debug_overrides(args):
    """Shrink a run in place so a ``--debug`` launch finishes quickly.

    The ``train`` and ``eval`` config sections, when present and non-empty,
    are cut to one epoch with a batch size of 2. The dataset section is set
    to zero workers.
    """
    train_cfg = getattr(args, 'train', None)
    if train_cfg:
        # Small batches so the debug subset (8 samples by default) still
        # yields full batches; dataloader_kwargs below sets drop_last=True.
        train_cfg.batch_size = 2
        train_cfg.num_epochs = 1
        train_cfg.stop_at_epoch = 1
    eval_cfg = getattr(args, 'eval', None)
    if eval_cfg:
        eval_cfg.batch_size = 2
        eval_cfg.num_epochs = 1  # run the eval stage for a single epoch only
    # Feeds dataloader_kwargs['num_workers'] in get_args.
    args.dataset.num_workers = 0


def get_args():
    """Parse CLI flags, merge in the YAML config, and prepare the run directories.

    Every top-level key of the YAML file given by ``-c/--config-file`` becomes
    an attribute of the returned namespace (nested mappings become
    ``Namespace`` objects). These config entries override CLI flags that have
    the same name.

    The config must provide ``name``, ``model.name``, ``dataset.image_size``
    and ``dataset.num_workers``. ``seed``, ``train`` and ``eval`` are optional.

    Side effects:
        - Creates a new ``in-progress_<timestamp>_<name>`` directory under
          ``--log_dir``, failing if it already exists.
        - Creates ``--ckpt_dir`` if it is missing.
        - Copies the config file into the log directory.
        - Seeds the RNGs through ``set_deterministic``.

    Returns:
        argparse.Namespace holding the CLI flags, the config entries, and the
        derived ``aug_kwargs``, ``dataset_kwargs`` and ``dataloader_kwargs``
        dictionaries.

    Exits (via ``parser.error``) when the config is not a YAML mapping or a
    required directory/name setting is missing.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config-file', required=True, type=str, help="xxx.yaml")
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--debug_subset_size', type=int, default=8)
    parser.add_argument('--download', action='store_true', help="if can't find dataset, download from web")
    # Directory defaults come from the DATA / LOG / CHECKPOINT environment variables.
    parser.add_argument('--data_dir', type=str, default=os.getenv('DATA'))
    parser.add_argument('--log_dir', type=str, default=os.getenv('LOG'))
    parser.add_argument('--ckpt_dir', type=str, default=os.getenv('CHECKPOINT'))
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
    parser.add_argument('--eval_from', type=str, default=None)
    parser.add_argument('--hide_progress', action='store_true')
    args = parser.parse_args()

    # NOTE(review): FullLoader accepts some Python-specific YAML tags; only
    # load trusted config files (yaml.safe_load is the stricter alternative).
    with open(args.config_file, 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    if not isinstance(config, dict):
        # An empty file loads as None; reject it (and any non-mapping
        # document) with a clear CLI error instead of an AttributeError.
        parser.error(f"config file {args.config_file} must contain a YAML mapping at the top level")
    # Config entries override CLI flags of the same name.
    vars(args).update(vars(Namespace(config)))

    if args.debug:
        _apply_debug_overrides(args)

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    missing = [key for key in ('log_dir', 'data_dir', 'ckpt_dir', 'name')
               if getattr(args, key, None) is None]
    if missing:
        parser.error(
            "missing required setting(s): " + ", ".join(missing)
            + " (pass --log_dir/--data_dir/--ckpt_dir or set LOG/DATA/CHECKPOINT;"
            " `name` comes from the config file)"
        )

    args.log_dir = os.path.join(
        args.log_dir, f"in-progress_{datetime.now():%m%d%H%M%S}_{args.name}"
    )
    os.makedirs(args.log_dir, exist_ok=False)  # never reuse an existing run directory
    print(f'creating file {args.log_dir}')
    os.makedirs(args.ckpt_dir, exist_ok=True)
    shutil.copy2(args.config_file, args.log_dir)  # keep the exact config next to the logs

    # The seed is optional: set_deterministic does nothing when given None.
    set_deterministic(getattr(args, 'seed', None))

    args.aug_kwargs = {
        'name': args.model.name,
        'image_size': args.dataset.image_size,
    }
    args.dataset_kwargs = {
        'download': args.download,
        'debug_subset_size': args.debug_subset_size if args.debug else None,
    }
    args.dataloader_kwargs = {
        'drop_last': True,
        'pin_memory': True,
        'num_workers': args.dataset.num_workers,
    }
    return args