-
Notifications
You must be signed in to change notification settings - Fork 23
/
datasets.py
112 lines (92 loc) · 3.75 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from __future__ import print_function
import torch.utils.data as data
from PIL import Image
import os
import os.path
import errno
import torch
import json
import codecs
import numpy as np
import progressbar
import sys
import torchvision.transforms as transforms
import argparse
import json
class PartDataset(data.Dataset):
    """Dataset of point sets with per-point labels read from a directory tree.

    ``root`` must contain ``synsetoffset2category.txt``; each of its lines is
    split on whitespace and read as ``<category name> <folder name>``.  For
    every selected category, ``<root>/<folder>/points`` is listed and each
    file name found there is paired with
    ``<root>/<folder>/points_label/<base name>.seg``.

    Within each category the sorted file list is split 90% / 10%: the first
    part is used when ``train`` is true, the remainder otherwise.

    Args:
        root: Directory holding the category file and the category folders.
        npoints: Number of points sampled (with replacement) per item.
        classification: If true, items are ``(points, class index)``;
            otherwise they are ``(points, per-point labels)``.
        class_choice: Optional container of category names to keep; ``None``
            keeps every category listed in the category file.
        train: Selects the training (first 90%) or held-out (last 10%) part
            of each category.

    Attributes set by the constructor include ``classes`` (category name ->
    integer index) and ``num_seg_classes`` (see ``__init__``).
    """

    def __init__(self, root, npoints = 2048, classification = False, class_choice = None, train = True):
        self.npoints = npoints
        self.root = root
        self.catfile = os.path.join(self.root, 'synsetoffset2category.txt')
        self.cat = {}
        self.classification = classification

        # Category name -> folder name, one pair per line of the category file.
        with open(self.catfile, 'r') as f:
            for line in f:
                ls = line.strip().split()
                self.cat[ls[0]] = ls[1]

        if class_choice is not None:
            self.cat = {k: v for k, v in self.cat.items() if k in class_choice}

        # Category name -> list of (points file, label file) path pairs.
        self.meta = {}
        for item in self.cat:
            self.meta[item] = []
            dir_point = os.path.join(self.root, self.cat[item], 'points')
            dir_seg = os.path.join(self.root, self.cat[item], 'points_label')
            fns = sorted(os.listdir(dir_point))
            # Deterministic split on the sorted names: first 90% vs last 10%.
            split = int(len(fns) * 0.9)
            if train:
                fns = fns[:split]
            else:
                fns = fns[split:]
            for fn in fns:
                token = os.path.splitext(os.path.basename(fn))[0]
                self.meta[item].append((os.path.join(dir_point, token + '.pts'),
                                        os.path.join(dir_seg, token + '.seg')))

        # Flat list of (category name, points path, label path).
        self.datapath = []
        for item in self.cat:
            for fn in self.meta[item]:
                self.datapath.append((item, fn[0], fn[1]))

        self.classes = dict(zip(self.cat, range(len(self.cat))))
        print(self.classes)

        # Largest number of distinct labels seen in any inspected label file.
        # Only the first len(datapath) // 50 entries are inspected, so this
        # stays 0 when there are fewer than 50 items.
        # Floor division is required here: ``/`` yields a float on Python 3
        # and ``range`` rejects it with a TypeError.
        self.num_seg_classes = 0
        if not self.classification:
            for entry in self.datapath[:len(self.datapath) // 50]:
                l = len(np.unique(np.loadtxt(entry[-1]).astype(np.uint8)))
                if l > self.num_seg_classes:
                    self.num_seg_classes = l

    def __getitem__(self, index):
        """Load, normalise and resample the item at ``index``.

        The points are centred on their mean, divided by the largest row
        norm, resampled to ``npoints`` rows with replacement, and perturbed
        by uniform noise in ``[0, 1e-5)``.

        Returns:
            ``(points, cls)`` when ``classification`` is true, otherwise
            ``(points, seg)``.  ``points`` is a float32 tensor with
            ``npoints`` rows, ``seg`` an int64 tensor of ``npoints`` labels,
            and ``cls`` an int64 tensor holding the single class index.
        """
        fn = self.datapath[index]
        cls = self.classes[fn[0]]
        point_set = np.loadtxt(fn[1]).astype(np.float32)
        seg = np.loadtxt(fn[2]).astype(np.int64)

        # Centre on the mean, then scale so the farthest point has norm 1.
        point_set = point_set - np.expand_dims(np.mean(point_set, axis = 0), 0)
        dist = np.max(np.sqrt(np.sum(point_set ** 2, axis = 1)), 0)
        dist = np.expand_dims(np.expand_dims(dist, 0), 1)
        point_set = point_set / dist

        # Resample to a fixed size; indices may repeat (replace=True).
        choice = np.random.choice(len(seg), self.npoints, replace=True)
        point_set = point_set[choice, :]
        # Tiny jitter added after resampling, so repeated picks are not exact duplicates.
        point_set = point_set + 1e-5 * np.random.rand(*point_set.shape)
        seg = seg[choice]

        point_set = torch.from_numpy(point_set.astype(np.float32))
        seg = torch.from_numpy(seg.astype(np.int64))
        cls = torch.from_numpy(np.array([cls]).astype(np.int64))
        if self.classification:
            return point_set, cls
        else:
            return point_set, seg

    def __len__(self):
        """Return the number of (points, label) file pairs in this split."""
        return len(self.datapath)
if __name__ == '__main__':
    # Smoke test: build the dataset in both modes and print the size and
    # tensor shape/type of the first item.
    print('test')
    benchmark_root = 'shapenetcore_partanno_segmentation_benchmark_v0'

    # Segmentation mode, restricted to the 'Chair' category.
    seg_dataset = PartDataset(root=benchmark_root, class_choice=['Chair'])
    print(len(seg_dataset))
    points, labels = seg_dataset[0]
    print(points.size(), points.type(), labels.size(), labels.type())

    # Classification mode over every category.
    cls_dataset = PartDataset(root=benchmark_root, classification=True)
    print(len(cls_dataset))
    points, label = cls_dataset[0]
    print(points.size(), points.type(), label.size(), label.type())