-
Notifications
You must be signed in to change notification settings - Fork 2
/
cifar-10.py
55 lines (37 loc) · 1.27 KB
/
cifar-10.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import numpy as np
import os, sys, cPickle
from invalidizer.invalidizer import *
def load_batch(fpath, label_key='labels'):
    """Load one pickled CIFAR batch file and reshape its images.

    Args:
        fpath: Path to a batch file (e.g. ``data_batch_1`` or ``test_batch``).
        label_key: Dictionary key holding the labels inside the pickle
            (``'labels'`` by default; pass another key to read a different
            label set stored in the same file).

    Returns:
        ``(data, labels)``: ``data`` is the pickle's ``"data"`` array reshaped
        to ``(N, 3, 32, 32)``; ``labels`` is whatever is stored under
        *label_key*, returned unchanged.

    Security: pickle can execute arbitrary code while loading. Only call this
    on trusted dataset files, never on user-supplied input.
    """
    try:
        import cPickle as pickle  # Python 2: fast C implementation
        load_kwargs = {}
    except ImportError:
        import pickle  # Python 3
        # The official batches were pickled under Python 2. 'latin1' decodes
        # their str keys to text and keeps the NumPy array bytes intact.
        load_kwargs = {'encoding': 'latin1'}

    # `with` guarantees the handle is closed even if unpickling fails.
    with open(fpath, 'rb') as f:
        d = pickle.load(f, **load_kwargs)

    data = d["data"]
    labels = d[label_key]
    # Each row of "data" is one flat image; expose it as (channels, 32, 32).
    data = data.reshape(data.shape[0], 3, 32, 32)
    return data, labels
def load_data(path):
    """Load the CIFAR-10 training and test splits from *path*.

    *path* is the directory holding the batch files ``data_batch_1`` through
    ``data_batch_5`` and ``test_batch``; each is read with ``load_batch``.

    Returns:
        ``(X_train, y_train), (X_test, y_test)`` where
        - ``X_train`` is a ``uint8`` array of shape ``(50000, 3, 32, 32)``,
        - ``y_train`` is a ``uint8`` array of shape ``(50000, 1)``,
        - ``X_test`` has shape ``(n_test, 3, 32, 32)`` (as produced by
          ``load_batch``),
        - ``y_test`` is reshaped to ``(n_test, 1)``.
    """
    nb_train_batches = 5
    nb_train_samples = 50000
    # Derived rather than hard-coded so the slice bounds stay consistent with
    # the totals above (10000 images per training batch file).
    batch_size = nb_train_samples // nb_train_batches

    X_train = np.zeros((nb_train_samples, 3, 32, 32), dtype="uint8")
    y_train = np.zeros((nb_train_samples,), dtype="uint8")
    for i in range(nb_train_batches):
        # Batch files are numbered from 1 on disk.
        fpath = os.path.join(path, 'data_batch_' + str(i + 1))
        data, labels = load_batch(fpath)
        start = i * batch_size
        X_train[start:start + batch_size] = data
        y_train[start:start + batch_size] = labels

    X_test, y_test = load_batch(os.path.join(path, 'test_batch'))

    # Return labels as column vectors of shape (n, 1).
    y_train = np.reshape(y_train, (len(y_train), 1))
    y_test = np.reshape(y_test, (len(y_test), 1))
    return (X_train, y_train), (X_test, y_test)
path = './datasets/cifar-10-batches-py'
(X_train, y_train), (X_test, y_test) = load_data(path)

# Training image to inspect. The original code wrote this index as `0172`,
# a Python 2 octal literal equal to decimal 122 (not 172). It is written in
# decimal here so the selected image is unchanged and the line parses on
# Python 3.
sample_index = 122
# Images are stored channels-first as (3, 32, 32), i.e. (C, H, W);
# matplotlib's imshow expects (H, W, C). transpose(1, 2, 0) produces exactly
# that. swapaxes(0, 2) would produce (W, H, C) and show the picture flipped
# across its diagonal.
test = X_train[sample_index].transpose(1, 2, 0)

import matplotlib.pyplot as plt

# raw_input only exists on Python 2, where input() would eval the typed text,
# so fall back to input() only when raw_input is unavailable.
try:
    _wait_for_enter = raw_input
except NameError:
    _wait_for_enter = input

plt.ion()
plt.imshow(test)
_wait_for_enter('Press Enter to show the invalidized image...')
inval = invalidizer(test, 4, 8)
plt.imshow(inval)
_wait_for_enter('Press Enter to exit...')