-
Notifications
You must be signed in to change notification settings - Fork 22
/
rf_predict.py
51 lines (40 loc) · 1.63 KB
/
rf_predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import rpy2.robjects as robjects
import convert_to_r_fmt
import game
import pymongo
import random
class WinPredictor(object):
    """Score game states with a pre-trained R randomForest model.

    Constructing a predictor loads the model into R's global environment
    once; ``predict_all_turns`` then scores every state of a game with it.
    """

    def __init__(self):
        # This is expensive: it attaches the randomForest package and loads
        # the serialized model file into R's global environment.
        # NOTE(review): predict_all_turns refers to the loaded model as
        # ``r5``; confirm the .rt file defines an object with that name.
        robjects.r('library(randomForest)')
        robjects.r('load("r5_100kn_500t.rt")')

    @staticmethod
    def _to_columns(headers, encoded_states):
        """Transpose row-major encoded states into ``{header: column_values}``.

        Column ``i`` collects element ``i`` of every encoded state (in the
        order of ``encoded_states``) and is keyed by the ``i``-th header.
        """
        columns = {}
        for ind, header_val in enumerate(headers):
            columns[header_val] = [state[ind] for state in encoded_states]
        return columns

    def predict_all_turns(self, game_val):
        """Return the model's prediction for every state of ``game_val``.

        Args:
            game_val: game object exposing ``game_state_iterator()``; each
                yielded state is encoded with
                ``convert_to_r_fmt.encode_state_r_fmt(game_val, state)``.

        Returns:
            A list with one entry per game state, in iteration order, built
            from the vector returned by R's ``predict``. Returns an empty
            list, without calling R, when the game yields no states.
        """
        encoded_states = [
            convert_to_r_fmt.encode_state_r_fmt(game_val, game_state)
            for game_state in game_val.game_state_iterator()
        ]
        if not encoded_states:
            # Nothing to score; skip the R round-trip entirely.
            return []

        # R data frames are column-oriented, so build one FloatVector per
        # header; the header names become the data frame's column names.
        columns = self._to_columns(convert_to_r_fmt.make_header(),
                                   encoded_states)
        turn_data = robjects.DataFrame(
            dict((name, robjects.FloatVector(values))
                 for name, values in columns.items()))

        # predict() is issued as R source text, so the data frame must be
        # bound to a name in R's global environment first. The random
        # suffix makes a clash with an existing R variable unlikely.
        var_name = 'rf_pred_dataframe' + str(random.randint(0, 1000000))
        robjects.globalenv[var_name] = turn_data
        try:
            predictions = list(robjects.r('predict(r5, %s)' % var_name))
        finally:
            # Always drop the temporary binding, even when predict() fails,
            # so failed calls do not leak data frames into R's global env.
            robjects.r.rm(var_name)
        return predictions
def main():
    """Score every turn of one stored game and print the predictions.

    Uses the first document of the ``test.games`` collection on a MongoDB
    client created with default connection settings.
    """
    # pymongo.Connection was deprecated in PyMongo 2.4 and removed in 3.0;
    # MongoClient is its drop-in replacement.
    client = pymongo.MongoClient()
    try:
        # NOTE(review): find_one() returns None on an empty collection;
        # confirm how game.Game handles that.
        g = game.Game(client.test.games.find_one())
        predictor = WinPredictor()
        print(predictor.predict_all_turns(g))
    finally:
        # Release the client's pooled connections even if scoring fails.
        client.close()


if __name__ == '__main__':
    main()