__author__ = 'Jiri Fajtl'
__email__ = '[email protected]'
__version__ = '1.8'
__status__ = "Research"
__date__ = "2/1/2020"
__license__ = "MIT License"

import random
import json
import sys

import numpy as np
import torch

# ===================================================
class Params:
    ''' Container for training hyper-parameters with JSON save/load,
    command-line overrides and diffing between configurations. '''

    def __init__(self):
        rnd_seed = 12345
        random.seed(rnd_seed)
        np.random.seed(rnd_seed)
        torch.manual_seed(rnd_seed)
        torch.cuda.manual_seed(rnd_seed)

        self.log_stdout = True
        self.use_cuda = True
        self.cuda_device = 1  # -1 for all available GPUs
        # self.dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        self.lr = [0.00005]
        self.lr_epoch = [0]
        self.batch_size = 256  # train batch size
        self.epoch_start = None
        self.iter_start = -1
        self.print_every_batch = 10
        self.keep_last_models = 15

    def diff(self, params):
        ''' Returns the attributes whose values differ between params and self,
        as a list of [name, old_value, new_value] entries. '''
        diffs = []
        vars = [attr for attr in dir(self) if not callable(getattr(self, attr)) and
                not (attr.startswith("__") or attr.startswith("_"))]
        for var in vars:
            val = getattr(self, var)
            pval = getattr(params, var)
            if pval != val:
                diffs.append([var, pval, val])
        return diffs

    def diff_str(self, params):
        ''' Returns the diff against params as a human-readable string. '''
        diff = self.diff(params)
        out = ''
        for var, past, current in diff:
            out += var + ': ' + str(past) + ' -> ' + str(current) + '\n'
        return out

    def __getattr__(self, item):
        # Unknown parameters resolve to None instead of raising AttributeError
        return None

    def save(self, filename='config.json'):
        with open(filename, 'w') as f:
            json.dump(self.__dict__, f, indent=3, sort_keys=True)
        return

    def load(self, filename='config.json'):
        try:
            with open(filename, 'r') as f:
                args = json.load(f)
            self.load_from_args(args)
        except (OSError, ValueError):
            # Missing file or malformed JSON
            return False
        return True

    def load_from_args(self, args):
        ''' Sets every key/value pair from the args dict as an attribute. '''
        for key in args:
            setattr(self, key, args[key])

    def load_from_sys_args(self, argv=None):
        ''' Parses arguments of the form key=value and sets them as attributes.
        A bare key becomes True; numeric values are converted to float. '''
        if argv is None:
            argv = sys.argv
        args = {}
        for i, val in enumerate(argv):
            if i == 0:
                continue  # skip the script name
            toks = val.split('=')
            if len(toks) == 1:
                val = True
            else:
                try:
                    val = float(toks[1])
                except ValueError:
                    val = toks[1]
            args[toks[0]] = val

        self.load_from_args(args)
        return

    def __str__(self):
        # Ignore all functions and variables starting with _ and __
        vars = [attr for attr in dir(self) if not callable(getattr(self, attr)) and
                not (attr.startswith("__") or attr.startswith("_"))]
        info_str = ''
        for i, var in enumerate(vars):
            val = getattr(self, var)
            if isinstance(val, torch.Tensor):
                val = val.data.cpu().numpy().tolist()[0]
            info_str += '[' + str(i) + '] ' + var + ': ' + str(val) + '\n'
        return info_str


# =================================================================================
if __name__ == "__main__":
    print("NOT AN EXECUTABLE!")