data.py
import torch
from torchvision import transforms
import numpy as np


class SingleObjectData(torch.utils.data.Dataset):
    """Single-object observations with the action taken and the resulting
    pixel-delta effect."""

    def __init__(self, transform=None):
        self.transform = transform
        # Add a channel dimension so each observation is (1, H, W).
        self.observation = torch.load("data/img/obs_prev_z.pt").unsqueeze(1)
        self.action = torch.load("data/img/action.pt")
        self.effect = torch.load("data/img/delta_pix_1.pt")
        # Standardize effects; the small epsilon guards against a zero std.
        self.eff_mu = self.effect.mean(dim=0)
        self.eff_std = self.effect.std(dim=0)
        self.effect = (self.effect - self.eff_mu) / (self.eff_std + 1e-6)

    def __len__(self):
        return len(self.observation)

    def __getitem__(self, idx):
        sample = {}
        sample["observation"] = self.observation[idx]
        sample["effect"] = self.effect[idx]
        sample["action"] = self.action[idx]
        if self.transform:
            sample["observation"] = self.transform(self.observation[idx])
        return sample
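

# Usage sketch (assumes the data/img/*.pt files exist with the shapes the
# loader expects; the batch size is arbitrary, not from the original code):
#
#   dataset = SingleObjectData(transform=default_transform(size=42, affine=True))
#   loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
#   batch = next(iter(loader))
#   # batch is a dict with "observation", "effect", and "action" tensors.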


class PairedObjectData(torch.utils.data.Dataset):
    """Object pairs: every (object, size) combination paired with every other,
    with the absolute pixel-delta effect as the target."""

    def __init__(self, transform=None):
        self.transform = transform
        self.train = True
        # Reshape to (object, size, repetition, x, y, H, W) and keep only the
        # first repetition.
        self.observation = torch.load("data/img/obs_prev_z.pt")
        self.observation = self.observation.reshape(5, 10, 3, 4, 4, 42, 42)
        self.observation = self.observation[:, :, 0]
        self.effect = torch.load("data/img/delta_pix_3.pt")
        self.effect = self.effect.abs()
        # Standardize effects; the small epsilon guards against a zero std.
        self.eff_mu = self.effect.mean(dim=0)
        self.eff_std = self.effect.std(dim=0)
        self.effect = (self.effect - self.eff_mu) / (self.eff_std + 1e-6)

    def __len__(self):
        return len(self.effect)

    def __getitem__(self, idx):
        sample = {}
        # Decode idx = obj_i*500 + size_i*50 + obj_j*10 + size_j, enumerating
        # all (5*10) x (5*10) = 2500 (object, size) pairs.
        obj_i = idx // 500
        size_i = (idx // 50) % 10
        obj_j = (idx // 10) % 5
        size_j = idx % 10
        if self.train:
            # Sample a random grid position for each object during training.
            ix = np.random.randint(0, 4)
            iy = np.random.randint(0, 4)
            jx = np.random.randint(0, 4)
            jy = np.random.randint(0, 4)
        else:
            # Use the fixed central position at evaluation time.
            ix, iy, jx, jy = 2, 2, 2, 2
        img_i = self.observation[obj_i, size_i, ix, iy]
        img_j = self.observation[obj_j, size_j, jx, jy]
        if self.transform:
            img_i = self.transform(img_i)
            img_j = self.transform(img_j)
            sample["observation"] = torch.cat([img_i, img_j])
        else:
            sample["observation"] = torch.stack([img_i, img_j])
        sample["effect"] = self.effect[idx]
        return sample
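

# Usage sketch (assumes data/img/obs_prev_z.pt reshapes to
# (5, 10, 3, 4, 4, 42, 42) as above; the index below is illustrative):
#
#   dataset = PairedObjectData(transform=default_transform(size=42, affine=False))
#   dataset.train = False   # switch to the fixed (2, 2) grid positions
#   sample = dataset[1234]  # decodes to obj_i=2, size_i=4, obj_j=3, size_j=4
#   # sample["observation"] concatenates the two object crops channel-wise.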


def default_transform(size, affine, mean=None, std=None):
    """Build the preprocessing pipeline: tensor -> PIL -> (resize) ->
    (random translation) -> tensor -> (normalize)."""
    transform = [transforms.ToPILImage()]
    if size:
        transform.append(transforms.Resize(size))
    if affine:
        transform.append(
            transforms.RandomAffine(
                degrees=0,               # no rotation
                translate=(0.1, 0.1),    # shift up to 10% along each axis
                fill=int(0.285 * 255),   # constant gray for revealed border pixels
            )
        )
    transform.append(transforms.ToTensor())
    if mean is not None:
        transform.append(transforms.Normalize([mean], [std]))
    return transforms.Compose(transform)
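

# Usage sketch (the mean/std values are placeholders, not dataset statistics
# from the original code):
#
#   tf = default_transform(size=32, affine=True, mean=0.5, std=0.5)
#   img = torch.rand(1, 42, 42)   # a fake single-channel observation
#   out = tf(img)                 # -> float tensor of shape (1, 32, 32)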