-
Notifications
You must be signed in to change notification settings - Fork 39
/
disout.py
109 lines (79 loc) · 4.35 KB
/
disout.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
#This program is free software; you can redistribute it and/or modify it under the terms of the BSD 3-Clause License.
#This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the BSD 3-Clause License for more details.
import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np
class Disout(nn.Module):
    """
    Beyond Dropout: Feature Map Distortion to Regularize Deep Neural Networks
    https://arxiv.org/abs/2002.11022

    In training mode, random ``block_size`` x ``block_size`` spatial blocks are
    selected. Elements inside the selected blocks are replaced by Gaussian
    noise scaled by the input's standard deviation. The whole result is then
    divided by the fraction of unselected mask entries. In eval mode the input
    is returned unchanged.

    Args:
        dist_prob (float): probability of an element to be distorted.
        block_size (int): size of the block to be distorted.
        alpha (float): the intensity of distortion (multiplies the noise).

    Attributes:
        weight_behind: ``None`` by default. When set to a non-empty 4-D
            tensor, it is used to derive the noise scale (see
            ``_noise_scale``). Presumably the weight of the layer that consumes
            this module's output -- NOTE(review): confirm against caller.

    Shape:
        - Input: `(N, C, H, W)`
        - Output: `(N, C, H, W)`
    """

    def __init__(self, dist_prob, block_size=6, alpha=1.0):
        super().__init__()
        self.dist_prob = dist_prob
        self.weight_behind = None
        self.alpha = alpha
        self.block_size = block_size

    def forward(self, x):
        """Distort random spatial blocks of ``x`` while training.

        Args:
            x (torch.Tensor): 4-D input of shape ``(N, C, H, W)``.

        Returns:
            torch.Tensor: ``x`` itself in eval mode; otherwise a new tensor of
            the same shape. If ``block_size`` exceeds H or W, no block fits
            and an unmodified copy of ``x`` is returned. NOTE(review): if
            every mask entry is selected, the final division is by zero and
            the output is non-finite.

        Raises:
            ValueError: if ``x`` is not 4-D while in training mode.
        """
        if not self.training:
            return x
        if x.dim() != 4:
            raise ValueError(
                "Disout expects a 4-D (N, C, H, W) input, got {}-D".format(x.dim()))

        x = x.clone()
        if self.block_size > x.size(2) or self.block_size > x.size(3):
            # No valid block centre exists. The seed-rate formula would
            # divide by zero (or turn negative), so return the copy untouched.
            return x

        # Order matters for reproducibility under a fixed seed: the mask
        # draws from the RNG first, then the noise scale, then the noise.
        block_pattern = self._block_mask(x)
        percent_ones = block_pattern.sum() / float(block_pattern.numel())
        scale = self._noise_scale()

        # Detached so the noise magnitude does not back-propagate into x.
        var = torch.var(x).detach()
        dist = self.alpha * scale * (var ** 0.5) * torch.randn(*x.shape, device=x.device)

        # Keep unselected elements, swap selected ones for noise, then rescale.
        x = x * block_pattern + dist * (1 - block_pattern)
        return x / percent_ones

    def _block_mask(self, x):
        """Build a float mask broadcastable to ``x``: 1 keeps, 0 distorts."""
        height, width = x.size(2), x.size(3)
        bs = self.block_size

        # Per-position seed rate chosen so that roughly ``dist_prob`` of the
        # map ends up covered by blocks (same form as DropBlock's gamma).
        seed_drop_rate = self.dist_prob * (height * width) / bs ** 2 / (
            (height - bs + 1) * (width - bs + 1))

        # Only positions around which a full block fits may seed a block.
        valid_block_center = torch.zeros(height, width, device=x.device).float()
        valid_block_center[bs // 2:height - (bs - 1) // 2,
                           bs // 2:width - (bs - 1) // 2] = 1.0
        valid_block_center = valid_block_center.unsqueeze(0).unsqueeze(0)

        randdist = torch.rand(x.shape, device=x.device)
        # 0 marks a seed: a valid centre whose uniform draw fell below the rate.
        block_pattern = (
            (1 - valid_block_center + float(1 - seed_drop_rate) + randdist) >= 1
        ).float()

        if bs == height and bs == width:
            # The block spans the whole map: any seed distorts the entire channel.
            block_pattern = torch.min(
                block_pattern.view(x.size(0), x.size(1), height * width), dim=2
            )[0].unsqueeze(-1).unsqueeze(-1)
        else:
            # Grow each seed into a bs x bs block (min-pool via negated max-pool).
            block_pattern = -F.max_pool2d(input=-block_pattern,
                                          kernel_size=(bs, bs),
                                          stride=(1, 1),
                                          padding=bs // 2)
            if bs % 2 == 0:
                # An even kernel with padding bs // 2 yields one extra row and
                # column. Only this branch has them; cropping the (N, C, 1, 1)
                # mask above would empty it.
                block_pattern = block_pattern[:, :, :-1, :-1]
        return block_pattern

    def _noise_scale(self):
        """Return the factor applied to the noise std (tensor or 0.01)."""
        weight = self.weight_behind
        if weight is None or len(weight) == 0:
            return 0.01
        # Max over dim 0 with a random +/-1 sign per entry, averaged over
        # dims 2 and 3 (keepdim, so the result has shape (1, size(1), 1, 1)).
        weight_max = weight.max(dim=0, keepdim=True)[0]
        sig = torch.ones(weight_max.size(), device=weight_max.device)
        sig[torch.rand(weight_max.size(), device=sig.device) < 0.5] = -1
        weight_mean = (weight_max * sig).mean(dim=(2, 3), keepdim=True)
        if weight.size(3) == 1:
            # Presumably a 1x1 kernel; its scale is damped by 10x.
            weight_mean = 0.1 * weight_mean
        return weight_mean
class LinearScheduler(nn.Module):
    """Anneal the ``dist_prob`` of a wrapped module along a linear schedule.

    Args:
        disout: module called by ``forward``; ``step`` assigns to its
            ``dist_prob`` attribute.
        start_value (float): first value assigned to ``dist_prob``.
        stop_value (float): last value assigned to ``dist_prob``.
        nr_steps (int): number of evenly spaced values from ``start_value``
            to ``stop_value`` inclusive (``numpy.linspace``). Calls to
            ``step`` beyond this leave ``dist_prob`` at its last value.
    """

    def __init__(self, disout, start_value, stop_value, nr_steps):
        super().__init__()
        self.disout = disout
        # ``i`` is the index of the next schedule entry ``step`` will apply.
        self.i = 0
        self.drop_values = np.linspace(start=start_value, stop=stop_value, num=nr_steps)

    def forward(self, x):
        """Run the wrapped module on ``x``."""
        return self.disout(x)

    def step(self):
        """Apply the next scheduled ``dist_prob``; no-op once exhausted."""
        if self.i >= len(self.drop_values):
            return
        self.disout.dist_prob = self.drop_values[self.i]
        self.i += 1