smart.py
from argparse import ArgumentParser, Namespace
from typing import Optional, Tuple

import torch
import torch.nn.functional as F

from smart_compress.compress.base import CompressionAlgorithmBase
from smart_compress.util.globals import Globals


class SmartFP(CompressionAlgorithmBase):
    @staticmethod
    def add_argparse_args(parent_parser: ArgumentParser):
        parser = ArgumentParser(
            parents=[CompressionAlgorithmBase.add_argparse_args(parent_parser)],
            add_help=False,
        )
        parser.add_argument(
            "--num_samples",
            type=int,
            default=16,
            help="number of samples to use for mean/std dev calculation",
        )
        parser.add_argument(
            "--use_sample_stats",
            action="store_true",
            help="use sample mean and std dev for smart compression",
        )
        parser.add_argument(
            "--no_stochastic_rounding",
            action="store_false",
            help="disable stochastic rounding when quantizing",
            dest="stochastic_rounding",
        )
        parser.add_argument(
            "--num_bits_main",
            type=int,
            default=6,
            help="number of bits for main data (within 1 std dev)",
        )
        parser.add_argument(
            "--num_bits_outlier",
            type=int,
            default=8,
            help="number of bits for outlier data (more than 1 std dev)",
        )
        parser.add_argument(
            "--main_std_dev_threshold",
            type=float,
            default=1.0,
            help="std dev threshold below which a value is considered main",
        )
        parser.add_argument(
            "--outlier_std_dev_threshold",
            type=float,
            default=2.5,
            help="max std dev for outliers (everything else is clamped to this)",
        )
        parser.add_argument(
            "--min_size",
            type=int,
            default=8,
            help="minimum tensor size (in elements) worth compressing",
        )
        parser.add_argument(
            "--use_range_std_dev",
            action="store_true",
            help="estimate std dev from the range (as in the Range Batch-Norm paper)",
        )
        parser.add_argument(
            "--use_batch_norm", action="store_true", help="support BN acceleration"
        )
        parser.add_argument(
            "--bn_scalar_params",
            action="store_true",
            help="reduce BN params to scalars (their means)",
        )
        return parser
    def __init__(self, hparams: Namespace):
        super().__init__(hparams)
        # Scale factors mapping each std dev band onto its integer
        # quantization grid: (2 ** (num_bits - 2)) - 1 levels per band.
        # Outliers cover the band between the main and outlier thresholds;
        # main values cover [-main_threshold, main_threshold].
        self.range_outlier = ((2 ** (self.hparams.num_bits_outlier - 2)) - 1) / (
            self.hparams.outlier_std_dev_threshold
            - self.hparams.main_std_dev_threshold
        )
        self.range_normal = (
            (2 ** (self.hparams.num_bits_main - 2)) - 1
        ) / self.hparams.main_std_dev_threshold
        # Bounds used to keep the std dev away from 0/inf when normalizing;
        # chosen according to the float precision in use.
        self.clamped_range = (
            (1e-4, 1e4) if self.hparams.precision == 16 else (1e-38, 1e38)
        )
    def _get_sample_mean_std(self, data: torch.Tensor):
        # Estimate mean/std dev from a random sample of up to `num_samples`
        # elements instead of the full tensor.
        numel = data.numel()
        k = min(numel, self.hparams.num_samples)
        sample = data.view(-1)[torch.randperm(numel, device=data.device)[:k]]
        return sample.mean(), self._get_std(sample, unbiased=False)
    def _round_stochastic(self, data: torch.Tensor):
        # Round up with probability equal to the fractional part, down
        # otherwise: relu(frac - u + 0.5).round() is 1 exactly when
        # frac > u for u ~ Uniform[0, 1).
        probs = torch.rand_like(data)
        floored_data = data.floor()
        fractions = data - floored_data
        return floored_data + F.relu((fractions - probs) + 0.5).round()
    def _get_std(self, data: torch.Tensor, *args, **kwargs):
        if self.hparams.use_range_std_dev:
            # Range-based estimate: std dev ~= (max - min) / sqrt(2 * ln(n)),
            # following the Range Batch-Norm trick.
            range_ = data.max() - data.min()
            C = 1 / torch.sqrt(
                2.0 * torch.log(torch.tensor(data.numel()).type_as(range_))
            )
            return range_ * C
        return data.std(*args, **kwargs)
    @torch.no_grad()
    def __call__(
        self,
        data: torch.Tensor,
        tag: Optional[str] = None,
        all_positive: bool = False,
        batch_norm_stats: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **_
    ):
        with Globals.profiler.profile("smaq"):
            use_bn = self.hparams.use_batch_norm and batch_norm_stats is not None
            numel = data.numel()
            orig_size = numel * 32

            # Tensors below the size threshold are not worth compressing.
            if numel < self.hparams.min_size:
                self.log_ratio(tag, orig_size, 32, 32)
                return data

            mean, std_dev = (
                (data.mean(), self._get_std(data))
                if not self.hparams.use_sample_stats
                else self._get_sample_mean_std(data)
            )

            # Default to identity BN parameters when none are supplied.
            gamma, beta = batch_norm_stats or (
                torch.tensor(1.0).type_as(std_dev),
                torch.tensor(0.0).type_as(mean),
            )
            if use_bn and self.hparams.bn_scalar_params:
                gamma = gamma.mean()
                beta = beta.mean()

            if use_bn:
                # Undo batch norm before quantizing; the permute moves the
                # channel dimension last so per-channel gamma/beta broadcast.
                data = (
                    ((data.permute(0, 3, 2, 1).clone() - beta) / gamma)
                    .permute(0, 3, 2, 1)
                    .clone()
                )

            if std_dev == 0:  # uniform data: avoid dividing by zero
                std_dev = torch.ones_like(std_dev, device=data.device)

            # Normalize to zero mean and unit std dev.
            data = (data - mean) / std_dev.clamp(*self.clamped_range)

            # Split values into the main band (within main_std_dev_threshold)
            # and outliers; outliers are shifted toward zero by the threshold
            # so each band starts at its own origin before scaling.
            is_outlier_higher = data > self.hparams.main_std_dev_threshold
            is_outlier_lower = data < -self.hparams.main_std_dev_threshold
            is_outlier = is_outlier_higher | is_outlier_lower
            scalars = (is_outlier_higher * -self.hparams.main_std_dev_threshold) + (
                is_outlier_lower * self.hparams.main_std_dev_threshold
            )
            ranges = torch.where(is_outlier, self.range_outlier, self.range_normal)

            # Quantize onto each band's integer grid, then invert the whole
            # transform to recover the dequantized values.
            data = (data + scalars) * ranges
            if self.hparams.stochastic_rounding:
                data = self._round_stochastic(data)
            else:
                data = data.trunc()
            data = (data / ranges) - scalars
            data = (data * std_dev) + mean

            if use_bn:
                # Re-apply batch norm.
                data = (
                    ((data.permute(0, 3, 2, 1).clone() * gamma) + beta)
                    .permute(0, 3, 2, 1)
                    .clone()
                )

            if all_positive:
                data = data.clamp_min(0.0)

            # Compressed size is computed lazily, only if logging needs it.
            new_size = lambda: (
                torch.sum(is_outlier) * self.hparams.num_bits_outlier
                + torch.sum(~is_outlier) * self.hparams.num_bits_main
            )
            self.log_size(tag, orig_size, new_size)

            return data
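
A minimal usage sketch, assuming this module is importable as `smart_compress.compress.smart` (inferred from the sibling `smart_compress.compress.base` import), that `Globals.profiler` is initialized by the surrounding framework before `__call__` runs, and that the `precision` hyperparameter read in `__init__` is registered by the base class (the sketch sets a fallback in case it is not):

from argparse import ArgumentParser

import torch

from smart_compress.compress.smart import SmartFP

# Build hyperparameters from the argparse interface defined above.
parser = SmartFP.add_argparse_args(ArgumentParser())
hparams = parser.parse_args(["--num_bits_main", "6", "--num_bits_outlier", "8"])
if not hasattr(hparams, "precision"):
    hparams.precision = 32  # assumed hparam, normally set by the base class

compressor = SmartFP(hparams)

# Values within 1 std dev of the mean are quantized with 6 bits;
# outliers get 8 bits.
x = torch.randn(128, 64)
x_hat = compressor(x, tag="example")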