metric_calculators.py
import torch
from abc import ABC, abstractmethod
import pdb


class MetricCalculator(ABC):
    """Accumulates a statistic over batches of features and returns a final matrix."""

    @abstractmethod
    def update(self, batch_size, dx, *feats, **aux_params): return NotImplemented

    @abstractmethod
    def finalize(self): return NotImplemented


def compute_correlation(covariance, eps=1e-7):
    # Normalize a covariance matrix into a correlation matrix.
    std = torch.diagonal(covariance).sqrt()
    covariance = covariance / torch.clamp(torch.outer(std, std), min=eps)
    return covariance


class CovarianceMetric(MetricCalculator):
    name = 'covariance'

    def __init__(self):
        self.std = None
        self.mean = None
        self.outer = None

    def update(self, batch_size, *feats, **aux_params):
        feats = torch.cat(feats, dim=0)
        feats = torch.nan_to_num(feats, 0, 0, 0)

        std = feats.std(dim=1)
        mean = feats.mean(dim=1)
        outer = (feats @ feats.T) / feats.shape[1]

        if self.mean is None: self.mean = torch.zeros_like(mean)
        if self.outer is None: self.outer = torch.zeros_like(outer)
        if self.std is None: self.std = torch.zeros_like(std)

        # Accumulate per-batch statistics, weighted by batch size.
        self.mean += mean * batch_size
        self.outer += outer * batch_size
        self.std += std * batch_size

    def finalize(self, numel, eps=1e-4):
        self.outer /= numel
        self.mean /= numel
        self.std /= numel
        cov = self.outer - torch.outer(self.mean, self.mean)
        # Debug guards: drop into the debugger if the covariance is malformed.
        if torch.isnan(cov).any():
            breakpoint()
        if (torch.diagonal(cov) < 0).sum():
            pdb.set_trace()
        return cov


class CorrelationMetric(MetricCalculator):
    name = 'correlation'

    def __init__(self):
        self.std = None
        self.mean = None
        self.outer = None

    def update(self, batch_size, dx, *feats, **aux_params):
        feats = torch.cat(feats, dim=0)

        std = feats.std(dim=1)
        mean = feats.mean(dim=1)
        outer = (feats @ feats.T) / feats.shape[1]

        if self.std is None: self.std = torch.zeros_like(std)
        if self.mean is None: self.mean = torch.zeros_like(mean)
        if self.outer is None: self.outer = torch.zeros_like(outer)

        # Accumulate per-batch statistics, weighted by dx.
        self.std += std * dx
        self.mean += mean * dx
        self.outer += outer * dx

    def finalize(self, eps=1e-4):
        corr = self.outer - torch.outer(self.mean, self.mean)
        corr /= (torch.outer(self.std, self.std) + eps)
        return corr


class CossimMetric(MetricCalculator):
    name = 'cossim'

    def __init__(self):
        self.std = None
        self.mean = None
        self.outer = None

    def update(self, batch_size, dx, *feats, **aux_params):
        feats = torch.cat(feats, dim=0)
        # Reshape so the trailing axis indexes the batch, L2-normalize along it,
        # then flatten back and accumulate the outer product.
        feats = feats.view(feats.shape[0], -1, batch_size)
        feats = feats / feats.norm(dim=-1, keepdim=True)
        feats = feats.view(feats.shape[0], -1)
        outer = (feats @ feats.T) / (feats.shape[1] // batch_size)

        if self.outer is None: self.outer = torch.zeros_like(outer)
        self.outer += outer * dx

    def finalize(self, eps=1e-4):
        return self.outer


class MeanMetric(MetricCalculator):
    name = 'mean'

    def __init__(self):
        self.mean = None

    def update(self, batch_size, *feats, **aux_params):
        feats = torch.cat(feats, dim=0)
        mean = feats.abs().mean(dim=1)
        if self.mean is None:
            self.mean = torch.zeros_like(mean)
        self.mean += mean * batch_size

    def finalize(self, numel, eps=1e-4):
        return self.mean / numel


def get_metric_fns(names):
    # Map metric names to their calculator classes (not instances).
    metrics = {}
    for name in names:
        if name == 'mean':
            metrics[name] = MeanMetric
        elif name == 'covariance':
            metrics[name] = CovarianceMetric
        elif name == 'correlation':
            metrics[name] = CorrelationMetric
        elif name == 'cossim':
            metrics[name] = CossimMetric
        else:
            raise NotImplementedError(name)
    return metrics
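

if __name__ == '__main__':
    # Usage sketch, not part of the original module: the real call sites live
    # elsewhere in the repository. This streams random (neurons x batch)
    # feature tensors through the calculators to show the update/finalize
    # protocol; the shapes and loop below are illustrative assumptions.
    torch.manual_seed(0)
    neurons, batch_size, num_batches = 8, 16, 4

    metrics = {name: cls() for name, cls in
               get_metric_fns(['covariance', 'mean']).items()}
    numel = 0
    for _ in range(num_batches):
        # Two hypothetical feature tensors (e.g. from two models being compared).
        feats_a = torch.randn(neurons, batch_size)
        feats_b = torch.randn(neurons, batch_size)
        numel += batch_size
        for metric in metrics.values():
            metric.update(batch_size, feats_a, feats_b)

    cov = metrics['covariance'].finalize(numel)   # (2*neurons, 2*neurons)
    corr = compute_correlation(cov)               # normalized to correlations
    print(cov.shape, corr.shape, metrics['mean'].finalize(numel).shape)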