-
Notifications
You must be signed in to change notification settings - Fork 59
/
Copy pathdata_parser.py
41 lines (34 loc) · 1.35 KB
/
data_parser.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import os
import csv
from collections import namedtuple
# One dataset sample: the id and label columns read from the input CSV, plus
# the path obtained by joining the data root with the id.
ListDataJpeg = namedtuple('ListDataJpeg', ['id', 'label', 'path'])


class JpegDataset(object):
    """Index of labelled samples described by two CSV files.

    Attributes:
        classes: label names in the order they appear in the labels CSV.
        classes_dict: two-way lookup (label -> index and index -> label).
        csv_data: list of ``ListDataJpeg`` records whose label is one of
            ``classes``.
    """

    def __init__(self, csv_path_input, csv_path_labels, data_root):
        """Load the labels first, then the samples that use those labels.

        Args:
            csv_path_input: path of the ``;``-delimited samples CSV
                (first column id, second column label).
            csv_path_labels: path of the labels CSV (label in first column).
            data_root: prefix joined with each sample id to form its path.
        """
        # Order matters: read_csv_input filters on self.classes.
        self.classes = self.read_csv_labels(csv_path_labels)
        self.classes_dict = self.get_two_way_dict(self.classes)
        self.csv_data = self.read_csv_input(csv_path_input, data_root)

    def read_csv_input(self, csv_path, data_root):
        """Return the samples listed in ``csv_path`` that have a known label.

        Each kept row becomes ``ListDataJpeg(id, label, path)`` where ``path``
        is ``os.path.join(data_root, id)``. Rows whose label is not in
        ``self.classes`` are dropped, as are blank rows and rows that lack a
        label column.

        Args:
            csv_path: path of the ``;``-delimited samples CSV.
            data_root: prefix joined with each sample id.

        Returns:
            A list of ``ListDataJpeg`` records in file order.
        """
        # Build the lookup once; testing membership against the list would
        # rescan it for every row.
        known_labels = set(self.classes)
        csv_data = []
        # newline='' lets the csv module do its own newline handling.
        with open(csv_path, newline='') as csvfile:
            for row in csv.reader(csvfile, delimiter=';'):
                # csv.reader yields [] for blank lines; a row without a
                # label column cannot be matched against the known classes.
                if len(row) < 2:
                    continue
                sample_id, label = row[0], row[1]
                if label in known_labels:
                    csv_data.append(ListDataJpeg(
                        sample_id, label, os.path.join(data_root, sample_id)))
        return csv_data

    def read_csv_labels(self, csv_path):
        """Return the label names stored in the first column of ``csv_path``.

        The file is parsed with the default comma delimiter and only the
        first field of each row is used. Blank rows are skipped.

        Args:
            csv_path: path of the labels CSV.

        Returns:
            A list of label strings in file order.
        """
        with open(csv_path, newline='') as csvfile:
            # `if row` filters out the empty lists produced by blank lines.
            return [row[0] for row in csv.reader(csvfile) if row]

    def get_two_way_dict(self, classes):
        """Return a dict mapping each label to its index and back.

        Args:
            classes: sequence of labels; the position is the class index.

        Returns:
            A dict containing both ``label -> index`` and ``index -> label``
            entries. If a label occurs more than once, its ``label -> index``
            entry holds the last index at which it appears.
        """
        classes_dict = {}
        for index, label in enumerate(classes):
            classes_dict[label] = index
            classes_dict[index] = label
        return classes_dict