-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathdata.py
executable file
·109 lines (88 loc) · 3.56 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import pickle
import numpy as np
import os
from urllib.request import urlretrieve
import tarfile
import zipfile
import sys
def _load_pickle(path):
    """Unpickle one dataset file, closing the handle even if loading fails."""
    # NOTE(review): pickle can execute arbitrary code while loading; only
    # point this at dataset files obtained from a trusted source.
    with open(path, 'rb') as f:
        return pickle.load(f, encoding='latin1')


def _flatten_images(raw):
    """Turn raw image rows into float rows in height/width/channel order."""
    images = np.array(raw, dtype=float) / 255.0
    # Rows are interpreted as (channel, height, width); move the channel axis
    # last so the three values of each pixel are adjacent, then flatten again.
    images = images.reshape([-1, 3, 32, 32]).transpose([0, 2, 3, 1])
    return images.reshape(-1, 32 * 32 * 3)


def _one_hot(labels, num_classes):
    """Return a (len(labels), num_classes) matrix with a single 1 per row."""
    labels = np.asarray(labels, dtype=int).ravel()
    encoded = np.zeros((labels.shape[0], num_classes))
    encoded[np.arange(labels.shape[0]), labels] = 1
    return encoded


def get_data_set(name="train", cifar=10):
    """Load one split of the dataset stored under ``./data_set/``.

    Args:
        name: ``"train"`` reads ``data_batch_1`` .. ``data_batch_5``;
            ``"test"`` reads ``test_batch``.
        cifar: ``10`` selects the ``cifar_10`` folder; any other value
            selects ``cifar_100``.

    Returns:
        A tuple ``(x, y, l)`` where ``x`` is a float array of shape
        ``(N, 32 * 32 * 3)`` holding the pixel data divided by 255, ``y`` is
        the one-hot label matrix with one column per entry of ``l``, and
        ``l`` is the ``label_names`` list read from ``batches.meta``.

    Raises:
        ValueError: If ``name`` is neither ``"train"`` nor ``"test"``.
        OSError: If one of the expected files cannot be opened.
    """
    # Compare by value: identity checks on strings only work by accident of
    # interning and silently skip both branches for dynamically built names.
    if name not in ("train", "test"):
        raise ValueError(
            "name must be 'train' or 'test', got {!r}".format(name))

    folder_name = "cifar_10" if cifar == 10 else "cifar_100"
    data_dir = './data_set/' + folder_name + '/'

    l = _load_pickle(data_dir + 'batches.meta')['label_names']

    if name == "train":
        batch_files = ['data_batch_' + str(i + 1) for i in range(5)]
    else:
        batch_files = ['test_batch']

    images = []
    labels = []
    for batch_file in batch_files:
        datadict = _load_pickle(data_dir + batch_file)
        images.append(_flatten_images(datadict["data"]))
        labels.append(datadict['labels'])

    # Join all batches once instead of re-copying the growing array per batch.
    x = np.concatenate(images, axis=0)
    y = np.concatenate(labels, axis=0)

    # Size the encoding from the label names so it matches the loaded dataset
    # rather than assuming ten classes.
    return x, _one_hot(y, len(l)), l
if __name__ == "__main__":
    # Smoke-test the loader when this file is run as a script. Guarded so that
    # merely importing the module does not read the whole training set (and
    # does not fail when the data files are absent).
    get_data_set(name="train")
def _print_download_progress(count, block_size, total_size):
pct_complete = float(count * block_size) / total_size
msg = "\r- Download progress: {0:.1%}".format(pct_complete)
sys.stdout.write(msg)
sys.stdout.flush()
def maybe_download_and_extract():
    """Download and unpack the CIFAR-10 python archive unless already present.

    Fetches the archive into ``./data_set/``, extracts it there, renames the
    extracted ``cifar-10-batches-py`` folder to ``./data_set/cifar_10/`` and
    removes the downloaded archive. Returns immediately when
    ``./data_set/cifar_10/`` already exists.
    """
    main_directory = "./data_set/"
    cifar_10_directory = main_directory+"cifar_10/"

    # Test for the extracted dataset itself rather than its parent folder, so
    # a pre-existing ./data_set/ without the data does not suppress the
    # download.
    if os.path.exists(cifar_10_directory):
        return
    os.makedirs(main_directory, exist_ok=True)

    url = "http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = url.split('/')[-1]
    file_path = os.path.join(main_directory, filename)
    zip_cifar_10 = file_path
    file_path, _ = urlretrieve(url=url, filename=file_path, reporthook=_print_download_progress)

    print()
    print("Download finished. Extracting files.")
    # NOTE(review): the archive is fetched over plain HTTP and extracted
    # without validating member paths; consider an extraction filter where
    # the running Python version supports one.
    if file_path.endswith(".zip"):
        # Context managers close the archive handle even if extraction fails.
        with zipfile.ZipFile(file=file_path, mode="r") as archive:
            archive.extractall(main_directory)
    elif file_path.endswith((".tar.gz", ".tgz")):
        with tarfile.open(name=file_path, mode="r:gz") as archive:
            archive.extractall(main_directory)
    print("Done.")

    # Give the extracted folder the name the loader expects, then drop the
    # archive, which is no longer needed.
    os.rename(os.path.join(main_directory, "cifar-10-batches-py"), cifar_10_directory)
    os.remove(zip_cifar_10)