forked from davikawasaki/machine-learning-snippets
-
Notifications
You must be signed in to change notification settings - Fork 0
/
split-supervised-learning.py
executable file
·136 lines (113 loc) · 3.76 KB
/
split-supervised-learning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
#!/usr/bin/python
import csv, sys
class SplitBase:
    """Split a ';'-delimited CSV dataset into stratified training/testing files.

    The wrapped reader is consumed once by split(); its first row is taken
    as the header.
    """

    def __init__(self, reader):
        # Any iterable of rows (lists of strings), e.g. a csv.reader.
        self.reader = reader

    def split(self, trueLabel):
        """Group the data rows by the value of the true-label column.

        Side effects: stores the header row in self.header and the label's
        column index in self.tli. The label value is removed from each data
        row (it becomes the dict key instead); blank lines are skipped.

        @param trueLabel  name of the label column in the header row
        @return dict mapping label value -> list of rows (label removed),
                each list in input order
        @raise ValueError if the header has no column named trueLabel
        """
        tlgb = {}
        for n, row in enumerate(self.reader):
            if n == 0:
                # Save header row and locate the label column.
                self.header = row
                tli = self.getTrueLabelIndex(self.header, trueLabel)
                if tli == -1:
                    raise ValueError('There is not any truelabel as ' + trueLabel)
                self.tli = tli
                continue
            # csv.reader yields [] for blank lines; skip them instead of
            # failing with IndexError on row[self.tli].
            if not row:
                continue
            label = row.pop(self.tli)
            tlgb.setdefault(label, []).append(row)
        return tlgb

    def extractSaveBaseRates(self, tlgb, trRate, dfn, outDir='data/output/'):
        """Write training and testing CSV files from a grouped dataset.

        For every label group, the first floor(len(group) * trRate / 100)
        rows (in input order) go to the training file and the rest to the
        testing file, so every label keeps the same proportion. Each written
        row has its label re-inserted at the original column so it lines up
        with self.header. Requires split() to have been called first.

        @param tlgb    dict returned by split() (label -> rows without label)
        @param trRate  training percentage, e.g. 70 for 70%
        @param dfn     base name of the output files
        @param outDir  directory prefix for the output files
        @return True if both files were written, False if tlgb or trRate
                is empty/zero
        """
        if not tlgb or not trRate:
            return False
        trGroup, ttGroup = self._partition(tlgb, trRate)
        # Generate training and testing sample files.
        self._writeCsv(outDir + dfn + '_training' + '.csv', trGroup)
        self._writeCsv(outDir + dfn + '_testing' + '.csv', ttGroup)
        return True

    def _partition(self, tlgb, trRate):
        """Split each label group into (training rows, testing rows), labels restored."""
        trGroup = []
        ttGroup = []
        for label, rows in tlgb.items():
            # Integer arithmetic avoids float truncation such as
            # int(100 * 0.29) == 28.
            trQty = int(len(rows) * trRate // 100)
            for i, row in enumerate(rows):
                # Put the label back where it was so columns match the header.
                fullRow = row[:self.tli] + [label] + row[self.tli:]
                if i < trQty:
                    trGroup.append(fullRow)
                else:
                    ttGroup.append(fullRow)
        return trGroup, ttGroup

    def _writeCsv(self, path, rows):
        """Write self.header followed by rows to a ';'-delimited CSV file."""
        # The csv module needs binary mode on Python 2 and text mode with
        # newline='' on Python 3.
        if sys.version_info[0] < 3:
            outFile = open(path, 'wb')
        else:
            outFile = open(path, 'w', newline='')
        with outFile:
            writer = csv.writer(outFile, delimiter=';')
            # Header as the first row, then the data rows.
            writer.writerow(self.header)
            writer.writerows(rows)

    def getTrueLabelIndex(self, header, trueLabel):
        """Return the column index of trueLabel in header.

        @param header     header row (list of column names)
        @param trueLabel  column name to look for
        @return index of the first matching column, or -1 if absent
        """
        for colnum, col in enumerate(header):
            if col == trueLabel:
                return colnum
        return -1
################
# Main process #
################
def main(argv=None):
    """Split data/raw/<datafile>.csv into training and testing CSV files.

    Usage: <script> <truelabel> <datafile> <trainingRate> <testingRate>

    Rates are integer percentages that must be non-negative and sum to
    100, with the testing rate not larger than the training rate.

    @param argv  argument vector (defaults to sys.argv)
    @return process exit status: 0 on success, 1 on any error
    """
    if argv is None:
        argv = sys.argv
    if len(argv) < 5:
        print('Usage: split-supervised-learning.py <truelabel> <datafile> <trainingRate> <testingRate>')
        return 1
    # truelabel and datafile
    tl = argv[1]
    df = argv[2]
    # training and testing rates (integer percentages)
    try:
        trRate = int(argv[3])
        ttRate = int(argv[4])
    except ValueError:
        print('Training and Testing rates must be integers! Try it again.')
        return 1
    if trRate + ttRate != 100:
        print('Training and Testing rates must complete 100%! Try it again.')
        return 1
    # Without this, e.g. 120/-20 passes the sum check above.
    if trRate < 0 or ttRate < 0:
        print('Training and Testing rates must not be negative! Try it again.')
        return 1
    if ttRate > trRate:
        print('Testing rate must not be bigger than Training rate! Try it again.')
        return 1

    path = 'data/raw/' + df + '.csv'
    # Universal-newline text mode on Python 2; on Python 3 csv wants
    # newline='' ('rU' was removed in Python 3.11).
    if sys.version_info[0] < 3:
        idf = open(path, 'rU')
    else:
        idf = open(path, newline='')
    with idf:
        splitBase = SplitBase(csv.reader(idf, delimiter=';'))
        # split() consumes the whole reader, so the file can close afterwards.
        tlgb = splitBase.split(tl)
    if splitBase.extractSaveBaseRates(tlgb, trRate, df):
        print('Successful extraction of ' + df + ' file!')
        return 0
    print('There was an error in the extraction of ' + df + ' file!')
    return 1


if __name__ == '__main__':
    sys.exit(main())