-
Notifications
You must be signed in to change notification settings - Fork 2
/
utils.py
76 lines (52 loc) · 1.43 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os
import shutil
import time
import pprint
import torch
def set_gpu(x):
    """Select the visible GPU(s) by setting ``CUDA_VISIBLE_DEVICES``.

    Args:
        x: Device selection to store in the environment variable, e.g.
            ``'0'``, ``'0,1'`` or an integer index such as ``0``.

    NOTE(review): this only writes the environment variable; presumably it
    has to run before any CUDA work starts to have an effect — confirm
    against callers.
    """
    # os.environ accepts only str values, so coerce to let callers pass an int.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(x)
    print('using gpu:', x)
def ensure_path(path):
    """Make sure ``path`` is an existing directory, optionally wiping it.

    If ``path`` does not exist it is created, together with any missing
    parent directories. If it already exists, the user is asked on stdin
    whether to remove it: any answer other than ``'n'`` deletes the existing
    entry and recreates it as an empty directory, while ``'n'`` leaves it
    untouched.

    Args:
        path: Filesystem path of the directory to prepare.
    """
    if not os.path.exists(path):
        # makedirs (rather than mkdir) so nested output paths work in one call.
        os.makedirs(path)
        return

    if input('{} exists, remove? ([y]/n)'.format(path)) == 'n':
        return

    # rmtree only handles real directories; plain files and symlinks must be
    # unlinked instead.
    if os.path.isdir(path) and not os.path.islink(path):
        shutil.rmtree(path)
    else:
        os.remove(path)
    os.makedirs(path)
class Averager:
    """Running arithmetic mean of the values passed to ``add``."""

    def __init__(self):
        # n: how many values have been folded in; v: their current mean.
        self.n = 0
        self.v = 0

    def add(self, x):
        """Fold ``x`` into the running mean."""
        weighted_total = self.v * self.n + x
        new_count = self.n + 1
        self.v = weighted_total / new_count
        self.n = new_count

    def item(self):
        """Return the current mean (``0`` if nothing has been added)."""
        return self.v
def count_acc(logits, label):
    """Return the fraction of correct predictions as a Python float.

    Args:
        logits: Score tensor; the predicted class is the argmax over dim 1.
        label: Tensor of ground-truth class indices, compared element-wise
            with the predictions.

    Returns:
        Mean of the element-wise ``pred == label`` matches.
    """
    pred = torch.argmax(logits, dim=1)
    # .float() converts on whatever device the tensors live on, so this works
    # for CPU tensors as well as CUDA ones.
    return (pred == label).float().mean().item()
def dot_metric(a, b):
    """Return the matrix of dot products between the rows of ``a`` and ``b``.

    Entry ``[i, j]`` of the result is the dot product of ``a[i]`` and ``b[j]``.
    """
    b_transposed = b.t()
    return torch.mm(a, b_transposed)
def euclidean_metric(a, b):
    """Return negative squared Euclidean distances between rows of ``a`` and ``b``.

    Entry ``[i, j]`` of the result is ``-sum((a[i] - b[j]) ** 2)``, so larger
    (less negative) values mean closer rows.

    Presumably both inputs are 2-D ``(rows, features)`` tensors — confirm
    against callers.
    """
    num_a, num_b = a.shape[0], b.shape[0]
    # Tile both inputs to a common (num_a, num_b, features) shape so every
    # row of ``a`` is paired with every row of ``b``.
    a_tiled = a.unsqueeze(1).expand(num_a, num_b, -1)
    b_tiled = b.unsqueeze(0).expand(num_a, num_b, -1)
    squared_diff = (a_tiled - b_tiled) ** 2
    return -squared_diff.sum(dim=2)
class Timer:
    """Wall-clock stopwatch that starts when the instance is created."""

    def __init__(self):
        # o: start timestamp as returned by time.time().
        self.o = time.time()

    def measure(self, p=1):
        """Return the elapsed time divided by ``p`` as a short string.

        The scaled value is truncated to whole seconds and then rendered as
        seconds (``'42s'``), rounded minutes (``'5m'``) or hours with one
        decimal (``'1.5h'``), depending on its size.
        """
        elapsed = int((time.time() - self.o) / p)
        if elapsed < 60:
            return '{}s'.format(elapsed)
        if elapsed < 3600:
            return '{}m'.format(round(elapsed / 60))
        return '{:.1f}h'.format(elapsed / 3600)
# Shared printer instance. It is created here, while ``pprint`` still refers
# to the standard-library module, because the function below rebinds that name.
_utils_pp = pprint.PrettyPrinter()


def pprint(x):
    """Pretty-print ``x`` using the module-wide ``PrettyPrinter`` instance."""
    printer = _utils_pp
    printer.pprint(x)
def l2_loss(pred, label):
    """Return half the summed squared error, averaged over ``len(pred)``.

    Computes ``sum((pred - label) ** 2) / len(pred) / 2``.
    """
    squared_error = (pred - label) ** 2
    num_items = len(pred)
    return squared_error.sum() / num_items / 2