-
Notifications
You must be signed in to change notification settings - Fork 36
/
Copy pathutil.py
52 lines (42 loc) · 1.37 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
class AverageMeter(object):
    r"""
    Computes and stores the average and current value.

    Adapted from
    https://github.com/pytorch/examples/blob/ec10eee2d55379f0b9c87f4b36fcf8d0723f45fc/imagenet/main.py#L359-L380

    Args:
        name: optional label prepended (followed by a space) to the string form.
        fmt: format spec applied to both the latest value and the running
            average when the meter is rendered with ``str()``.

    Values passed to :meth:`update` may be Python numbers or tensors; tensors
    are converted with ``.item()`` when the average or string form is read,
    so they are assumed to hold a single element.
    """

    def __init__(self, name=None, fmt='.6f'):
        # Doubled braces survive the f-string as literal braces, yielding a
        # template like '{val:.6f} ({avg:.6f})' that __str__ fills via format().
        fmtstr = f'{{val:{fmt}}} ({{avg:{fmt}}})'
        if name is not None:
            fmtstr = name + ' ' + fmtstr
        self.fmtstr = fmtstr
        self.reset()

    def reset(self):
        """Clear the latest value, the weighted running sum and the weight count."""
        self.val = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value and add it to the running sum.

        Args:
            val: the new value (number or tensor).
            n: weight of ``val`` in the average, i.e. how many samples it
                stands for; ``sum`` grows by ``val * n`` and ``count`` by ``n``.
        """
        # NOTE(review): if ``val`` is a tensor that requires grad, ``self.sum``
        # becomes part of its autograd graph and may keep that graph alive;
        # callers may prefer to pass ``val.detach()`` or ``val.item()``.
        self.val = val
        self.sum += val * n
        self.count += n

    @property
    def avg(self):
        """Weighted running average ``sum / count`` as a Python number.

        Returns ``0.0`` while no weight has been recorded (``count == 0``)
        instead of raising ``ZeroDivisionError``, so a freshly reset meter can
        still be printed.
        """
        if self.count == 0:
            return 0.0
        avg = self.sum / self.count
        if isinstance(avg, torch.Tensor):
            avg = avg.item()
        return avg

    def __str__(self):
        """Render as ``'[name ]<val> (<avg>)'`` using the configured format spec."""
        val = self.val
        if isinstance(val, torch.Tensor):
            val = val.item()
        return self.fmtstr.format(val=val, avg=self.avg)
class TwoAugUnsupervisedDataset(torch.utils.data.Dataset):
    r"""Expose a labelled dataset as an unlabelled one that yields two views per item.

    Each item of the wrapped dataset is unpacked as an ``(image, label)``
    pair (anything other than exactly two elements raises). The label is
    discarded and ``transform`` is called on the image twice, in two separate
    calls, giving the tuple ``(first_view, second_view)``. If ``transform``
    is random, the two views will generally differ.
    """

    def __init__(self, dataset, transform):
        # Indexable source dataset; presumably yields (image, label) pairs,
        # so confirm this with callers.
        self.dataset = dataset
        # Callable invoked once per view.
        self.transform = transform

    def __len__(self):
        # One output pair per item of the wrapped dataset.
        return len(self.dataset)

    def __getitem__(self, index):
        image, _label = self.dataset[index]
        # Two independent transform calls, first view produced first.
        return tuple(self.transform(image) for _ in range(2))