import os
from collections import namedtuple
from typing import List

import torch
import torch.nn as nn
from torch.nn.utils import fuse_conv_bn_eval

# Mirrors the NamedTuple that nn.Module.load_state_dict returns
# (used by load_state_dict_mute below).
_IncompatibleKeys = namedtuple('IncompatibleKeys', ['missing_keys', 'unexpected_keys'])

class CrossEntropyLabelSmooth(nn.Module):
    """Cross entropy loss with a label smoothing regularizer.

    Reference:
        Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
    Equation: y = (1 - epsilon) * y + epsilon / K, where K is the number of classes.

    Args:
        epsilon (float): smoothing weight.
        use_gpu (bool): move the smoothed targets to the GPU.
    """

    def __init__(self, epsilon=0.1, use_gpu=True):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.epsilon = epsilon
        self.use_gpu = use_gpu
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        """
        Args:
            inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
            targets: ground truth labels with shape (batch_size,)
        """
        _, num_classes = inputs.shape
        log_probs = self.logsoftmax(inputs)
        # One-hot encode the labels, then smooth them towards the uniform distribution.
        targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1)
        if self.use_gpu:
            targets = targets.cuda()
        targets = (1 - self.epsilon) * targets + self.epsilon / num_classes
        loss = (-targets * log_probs).mean(0).sum()
        return loss
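
# A minimal usage sketch of CrossEntropyLabelSmooth (the batch size and the
# 751-class setting are illustrative assumptions; use_gpu=False keeps the
# example CPU-only):
#
#   criterion = CrossEntropyLabelSmooth(epsilon=0.1, use_gpu=False)
#   logits = torch.randn(8, 751, requires_grad=True)  # (batch_size, num_classes)
#   labels = torch.randint(0, 751, (8,))              # (batch_size,)
#   loss = criterion(logits, labels)
#   loss.backward()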


def fuse_all_conv_bn(model):
    """Recursively fuse each adjacent Conv2d + BatchNorm2d pair in place.

    The model must be in eval mode: `fuse_conv_bn_eval` folds the BatchNorm
    running statistics into the convolution weights, which is only valid
    for inference.
    """
    stack = []
    for name, module in model.named_children():
        if list(module.named_children()):
            fuse_all_conv_bn(module)
        if isinstance(module, nn.BatchNorm2d):
            if not stack:
                continue
            if isinstance(stack[-1][1], nn.Conv2d):
                # Replace the Conv2d with the fused module and the BatchNorm2d
                # with a no-op.
                setattr(model, stack[-1][0], fuse_conv_bn_eval(stack[-1][1], module))
                setattr(model, name, nn.Identity())
        else:
            stack.append((name, module))
    return model
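
# A minimal usage sketch of fuse_all_conv_bn (the torchvision ResNet-18 is an
# illustrative assumption; any model with adjacent Conv2d/BatchNorm2d children
# works). Fusion is inference-only, so do not train the model afterwards:
#
#   import torchvision
#   model = torchvision.models.resnet18().eval()
#   model = fuse_all_conv_bn(model)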


def save_network(network, dirname, epoch_label, local_rank=-1):
    """Save the model weights to ./model/<dirname>/net_<epoch_label>.pth."""
    if isinstance(epoch_label, int):
        save_filename = 'net_%03d.pth' % epoch_label
    else:
        save_filename = 'net_%s.pth' % epoch_label
    save_path = os.path.join('./model', dirname, save_filename)
    if local_rank > -1:
        if local_rank == 0:  # only the main (rank 0) process saves the model
            torch.save(network.state_dict(), save_path)
        network.cuda(local_rank)
    else:
        torch.save(network.state_dict(), save_path)
        network.cuda()
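
# A minimal usage sketch of save_network (the directory name is an
# illustrative assumption; ./model/<dirname> must already exist):
#
#   save_network(model, 'ft_ResNet50', 10)                    # -> ./model/ft_ResNet50/net_010.pth
#   save_network(model, 'ft_ResNet50', 'last', local_rank=0)  # DDP: only rank 0 writes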


def load_state_dict_mute(self, state_dict: 'OrderedDict[str, Tensor]',
                         strict: bool = True):
    r"""Copies parameters and buffers from :attr:`state_dict` into
    this module and its descendants. If :attr:`strict` is ``True``, then
    the keys of :attr:`state_dict` must exactly match the keys returned
    by this module's :meth:`~torch.nn.Module.state_dict` function. Unlike
    :meth:`~torch.nn.Module.load_state_dict`, mismatched keys are reported
    in the return value instead of raising an error (hence "mute").

    Args:
        state_dict (dict): a dict containing parameters and
            persistent buffers.
        strict (bool, optional): whether to strictly enforce that the keys
            in :attr:`state_dict` match the keys returned by this module's
            :meth:`~torch.nn.Module.state_dict` function. Default: ``True``

    Returns:
        ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:

        * **missing_keys** is a list of str containing the missing keys
        * **unexpected_keys** is a list of str containing the unexpected keys
    """
    missing_keys: List[str] = []
    unexpected_keys: List[str] = []
    error_msgs: List[str] = []

    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        # mypy isn't aware that "_metadata" exists in state_dict
        state_dict._metadata = metadata  # type: ignore[attr-defined]

    def load(module, prefix=''):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        module._load_from_state_dict(
            state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    load(self)
    del load

    if strict:
        if len(unexpected_keys) > 0:
            error_msgs.insert(
                0, 'Unexpected key(s) in state_dict: {}. '.format(
                    ', '.join('"{}"'.format(k) for k in unexpected_keys)))
        if len(missing_keys) > 0:
            error_msgs.insert(
                0, 'Missing key(s) in state_dict: {}. '.format(
                    ', '.join('"{}"'.format(k) for k in missing_keys)))

    # "Mute": collect the messages but do not raise; hand the incompatible
    # keys back to the caller instead, as the docstring promises.
    return _IncompatibleKeys(missing_keys, unexpected_keys)
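
# A minimal usage sketch of load_state_dict_mute (the checkpoint path is an
# illustrative assumption). Because the function takes `self` explicitly, it
# can be called on any nn.Module without monkey-patching:
#
#   state_dict = torch.load('./model/ft_ResNet50/net_last.pth', map_location='cpu')
#   result = load_state_dict_mute(model, state_dict, strict=True)
#   print(result.missing_keys, result.unexpected_keys)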