-
Notifications
You must be signed in to change notification settings - Fork 77
/
utils.py
executable file
·177 lines (150 loc) · 6.38 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import os
import warnings
import argparse
from yacs.config import CfgNode
from .defaults import get_cfg_defaults
def load_cfg(args: argparse.Namespace, freeze=True, add_cfg_func=None):
    """Build the full configuration node from defaults, YAML files and CLI options.

    Args:
        args: parsed command-line arguments; ``config_base``, ``config_file``,
            ``opts`` and ``inference`` are read here.
        freeze: when ``True`` (default), the returned config is made immutable.
        add_cfg_func: optional callable that registers extra config keys on the
            default node before any merging happens.

    Returns:
        The merged (and possibly frozen) configuration node.
    """
    config = get_cfg_defaults()
    if add_cfg_func is not None:
        add_cfg_func(config)

    # Merge order defines priority: defaults < base YAML < main YAML < CLI opts.
    if args.config_base is not None:
        config.merge_from_file(args.config_base)
    config.merge_from_file(args.config_file)
    config.merge_from_list(args.opts)

    # Mode-specific overrides win over everything merged above.
    if args.inference:
        update_inference_cfg(config)
    overwrite_cfg(config, args)

    if freeze:
        config.freeze()
    else:
        warnings.warn("Configs are mutable during the process, "
                      "please make sure that is expected.")
    return config
def save_all_cfg(cfg: CfgNode, output_dir: str):
    r"""Dump the fully-merged configuration to ``config.yaml`` in *output_dir*.

    The written file contains all non-default options combined from YAML
    files and the command line, so an experiment can be reproduced from it.
    """
    destination = os.path.join(output_dir, "config.yaml")
    serialized = cfg.dump()
    with open(destination, "w") as config_file:
        config_file.write(serialized)
    print("Full config saved to {}".format(destination))
def update_inference_cfg(cfg: CfgNode):
    r"""Overwrite configurations (cfg) when running mode is inference. Please
    note that None type is only supported in YACS>=0.1.8.

    Values from the ``INFERENCE`` section are copied into the ``DATASET``
    and ``MODEL`` sections in place; a ``None`` in ``INFERENCE`` means
    "keep the training-time value".
    """
    # dataset configurations
    # INPUT_PATH is optional; IMAGE_NAME and OUTPUT_PATH are always taken
    # from the INFERENCE section.
    if cfg.INFERENCE.INPUT_PATH is not None:
        cfg.DATASET.INPUT_PATH = cfg.INFERENCE.INPUT_PATH
    cfg.DATASET.IMAGE_NAME = cfg.INFERENCE.IMAGE_NAME
    cfg.DATASET.OUTPUT_PATH = cfg.INFERENCE.OUTPUT_PATH
    if cfg.INFERENCE.PAD_SIZE is not None:
        cfg.DATASET.PAD_SIZE = cfg.INFERENCE.PAD_SIZE
    if cfg.INFERENCE.IS_ABSOLUTE_PATH is not None:
        cfg.DATASET.IS_ABSOLUTE_PATH = cfg.INFERENCE.IS_ABSOLUTE_PATH
    if cfg.INFERENCE.DO_CHUNK_TITLE is not None:
        cfg.DATASET.DO_CHUNK_TITLE = cfg.INFERENCE.DO_CHUNK_TITLE
    if cfg.INFERENCE.DATA_SCALE is not None:
        cfg.DATASET.DATA_SCALE = cfg.INFERENCE.DATA_SCALE

    # model configurations (inference may use different I/O sizes)
    if cfg.INFERENCE.INPUT_SIZE is not None:
        cfg.MODEL.INPUT_SIZE = cfg.INFERENCE.INPUT_SIZE
    if cfg.INFERENCE.OUTPUT_SIZE is not None:
        cfg.MODEL.OUTPUT_SIZE = cfg.INFERENCE.OUTPUT_SIZE

    # specify feature maps to return at inference time
    cfg.MODEL.RETURN_FEATS = cfg.INFERENCE.MODEL_RETURN_FEATS

    # output file name(s)
    out_name = cfg.INFERENCE.OUTPUT_NAME
    name_lst = out_name.split(".")
    if cfg.DATASET.DO_CHUNK_TITLE or cfg.INFERENCE.DO_SINGLY:
        # Chunked / one-by-one inference appends its own suffix later, so the
        # extension (at most one '.' in the name) is stripped here.
        assert len(name_lst) <= 2, \
            "Invalid output file name is given."
        if len(name_lst) == 2:
            cfg.INFERENCE.OUTPUT_NAME = name_lst[0]
    else:
        if len(name_lst) == 1:
            # No extension given: default to HDF5.
            cfg.INFERENCE.OUTPUT_NAME = name_lst[0] + '.h5'

    for topt in cfg.MODEL.TARGET_OPT:
        # For multi-class semantic segmentation and quantized distance
        # transform, no activation function is applied at the output layer
        # during training. For inference where the output is assumed to be
        # in (0,1), we apply softmax.
        if topt[0] in ['5', '9'] and cfg.MODEL.OUTPUT_ACT == 'none':
            cfg.MODEL.OUTPUT_ACT = 'softmax'
            break
def overwrite_cfg(cfg: CfgNode, args: argparse.Namespace):
    r"""Overwrite some configs given configs or args with higher priority.

    Mutates *cfg* in place: enables DDP for distributed runs, registers the
    valid mask as an augmentation target, sanity-checks input-size parity,
    restricts mixed precision to DDP, fills default scale factors, and turns
    off label reduction for semantic segmentation.
    """
    # Distributed training implies DistributedDataParallel.
    if args.distributed:
        cfg.SYSTEM.DISTRIBUTED = True
        cfg.SYSTEM.PARALLEL = 'DDP'

    # Register the valid mask as an extra augmentation target so that it is
    # transformed together with the image and label.
    if cfg.DATASET.VALID_MASK_NAME is not None:
        assert cfg.DATASET.LABEL_NAME is not None, \
            "Using valid mask is only supported when target label is given."
        assert cfg.AUGMENTOR.ADDITIONAL_TARGETS_NAME is not None
        assert cfg.AUGMENTOR.ADDITIONAL_TARGETS_TYPE is not None
        cfg.AUGMENTOR.ADDITIONAL_TARGETS_NAME.append('valid_mask')
        cfg.AUGMENTOR.ADDITIONAL_TARGETS_TYPE.append('mask')

    # Model I/O size: warn (once) when the size parity does not match the
    # downsampling scheme. The two conditions are mutually exclusive since
    # POOLING_LAYER is a single flag.
    use_pooling = cfg.MODEL.POOLING_LAYER
    for dim in cfg.MODEL.INPUT_SIZE:
        if not use_pooling and dim % 2 == 0:
            warnings.warn(
                "When downsampling by stride instead of using pooling "
                "layers, the cfg.MODEL.INPUT_SIZE are expected to contain "
                "numbers of 2n+1 to avoid feature mis-matching, "
                "but get {}".format(cfg.MODEL.INPUT_SIZE))
            break
        if use_pooling and dim % 2 == 1:
            warnings.warn(
                "When downsampling by pooling layers the cfg.MODEL.INPUT_SIZE "
                "are expected to contain even numbers to avoid feature mis-matching, "
                "but get {}".format(cfg.MODEL.INPUT_SIZE))
            break

    # Mixed-precision training (only works with DDP).
    cfg.MODEL.MIXED_PRECESION = (cfg.MODEL.MIXED_PRECESION and args.distributed)

    # Scale factors for image, label and valid mask fall back to DATA_SCALE.
    for scale_key in ("IMAGE_SCALE", "LABEL_SCALE", "VALID_MASK_SCALE"):
        if getattr(cfg.DATASET, scale_key) is None:
            setattr(cfg.DATASET, scale_key, cfg.DATASET.DATA_SCALE)

    # Disable label reducing for semantic segmentation to avoid class shift.
    if any(topt[0] == '9' for topt in cfg.MODEL.TARGET_OPT):
        cfg.DATASET.REDUCE_LABEL = False
def validate_cfg(cfg: CfgNode):
    r"""Check that inference activations match the number of learning targets.

    Raises:
        AssertionError: if ``INFERENCE.OUTPUT_ACT`` does not provide exactly
            one activation per entry in ``MODEL.TARGET_OPT``.
    """
    expected = len(cfg.MODEL.TARGET_OPT)
    assert len(cfg.INFERENCE.OUTPUT_ACT) == expected, \
        "Activations need to be specified for each learning target."
def convert_cfg_markdown(cfg):
    """Converts given cfg node to markdown for tensorboard visualization.

    Each top-level section becomes an ``##``-prefixed heading followed by
    its key/value pairs, one per line. Trailing "  \\n" sequences force
    markdown line breaks in the tensorboard text panel.
    """
    def _format_section(section):
        # One line per (key, value) pair of a single config section.
        pairs = [" \n{}:{}{} \n".format(str(key), " ", str(val))
                 for key, val in sorted(section.items())]
        return "".join(pairs)

    headings = []
    for key, val in sorted(cfg.items()):
        separator = " " if isinstance(val, str) else " \n"
        headings.append(
            "##{}:{}{} \n".format(str(key), separator, _format_section(val)))
    return "".join(" \n" + heading + " \n" for heading in headings)