-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnikhil_mean_model.py
66 lines (50 loc) · 1.99 KB
/
nikhil_mean_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from collections import defaultdict
import numpy as np
class MeanModel():
    '''Benchmark model that predicts the mean training truth value per drug.

    For a given (cell line, drug) pair c1, d1 in the test set the model
    predicts the mean truth value of all training pairs that include d1,
    regardless of cell line. Drugs listed in `drugs` that have no training
    values are stored as nan; call `replace_nan` to fill them.

    Inputs
    -----
    y_train: pd.Series (or any object whose ``items()`` yields
        (index, value) pairs, e.g. a dict)
        training truth values whose index gives cell line and drug
        separated by the string '::', e.g. for cl1 and d1
        index = 'cl1::d1'. DataFrame.items() iterates over columns, so a
        single-column DataFrame should be converted to a Series first.
    drugs: list like
        all drug names the model should know about
    verb: int
        if > 0, replace_nan prints how many values it replaced

    Methods
    ------
    predict(y_index): gives the model's prediction for cell line drug pairs
    replace_nan(re='mean'): replace missing (nan) values with re
    '''

    def __init__(self, y_train, drugs, verb=1):
        self.verb = verb
        # group training values by drug (index format 'cell_line::drug')
        grouped = defaultdict(list)
        for ind, val in y_train.items():
            _cl, d = ind.split('::')
            grouped[d].append(val)
        # average every drug seen in training -- including drugs missing
        # from `drugs`, which previously stayed as raw lists and broke
        # predict / replace_nan
        model = {d: np.mean(vals) for d, vals in grouped.items()}
        # drugs with no training data get nan explicitly (same value as
        # np.mean([]) but without the "Mean of empty slice" warning)
        for d in drugs:
            if d not in model:
                model[d] = np.nan
        # plain dict: looking up an unknown drug must not insert an entry
        self.model = model

    def predict(self, y_index, reformat=True):
        '''Return the per-drug mean prediction for each entry of y_index.

        y_index: iterable of 'cell_line::drug' strings, or of bare drug
            names when reformat=False.
        Returns a numpy array with one prediction per entry.
        Raises KeyError for a drug that was neither in y_train nor in
        `drugs`.
        '''
        if reformat:
            # keep only the drug part of 'cell_line::drug'
            y_index = [y.split('::')[1] for y in y_index]
        return np.array([self.model[y] for y in y_index])

    def replace_nan(self, re='mean'):
        '''Replace nan entries of the model in place.

        re: 'mean' (default) fills every nan with the mean of the non-nan
            drug means; any other value is used directly as the fill value.
        Prints a summary when self.verb > 0.
        '''
        nan_keys = [k for k, v in self.model.items() if np.isnan(v)]
        num_nan = len(nan_keys)
        if re == 'mean':
            # compute the fill value once, from the original non-nan values
            known = [v for v in self.model.values() if not np.isnan(v)]
            fill = np.mean(known) if known else np.nan
        else:
            fill = re
        for k in nan_keys:
            self.model[k] = fill
        if self.verb > 0:
            print(f'{num_nan} nan values replaced out of {len(self.model)}')