-
Notifications
You must be signed in to change notification settings - Fork 3
/
datatypes.py
64 lines (52 loc) · 2.2 KB
/
datatypes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import numpy as np
try:
import pandas as pd
except ImportError:
pass
class Data:
    """Common base class for the data containers in this module.

    ``convert_to_data`` returns instances of any subclass unchanged.
    """

    def __init__(self):
        """Create an empty container; subclasses attach their own fields."""
class DenseData(Data):
    """Dense data matrix whose features are partitioned into named groups.

    Parameters
    ----------
    data : array with a 2-D ``shape``
        Normally laid out as (samples, features). If the total number of
        grouped features does not equal ``data.shape[1]``, the matrix is
        taken to be transposed, i.e. laid out as (features, samples).
    group_names : sequence
        One name per feature group.
    *args
        Optional positional extras:

        - ``args[0]``: sequence of groups, each a sized collection of feature
          positions (e.g. an index array). ``None`` or omitted means one
          feature per group: ``[np.array([0]), np.array([1]), ...]``.
        - ``args[1]``: per-sample weights. ``None`` or omitted means uniform
          weights. A normalised float copy is stored; the caller's object is
          never modified.

    Attributes
    ----------
    groups, group_names, data
        As passed in (with the default for ``groups`` filled in).
    weights : numpy.ndarray
        Per-sample weights normalised to sum to 1.
    transposed : bool
        True when ``data`` was detected to be laid out (features, samples).

    Raises
    ------
    AssertionError
        If the grouped feature count or the number of weights does not match
        the data matrix.
    """

    def __init__(self, data, group_names, *args):
        groups = args[0] if len(args) > 0 else None
        # Identity test: `groups != None` compares element-wise when groups
        # is an ndarray and then fails when truth-tested.
        if groups is None:
            groups = [np.array([i]) for i in range(len(group_names))]
        self.groups = groups

        num_features = sum(len(g) for g in groups)
        # Orientation is inferred: a mismatch with the column count means the
        # samples run along axis 1 instead of axis 0.
        transposed = num_features != data.shape[1]
        feature_axis, sample_axis = (0, 1) if transposed else (1, 0)
        num_samples = data.shape[sample_axis]
        # Checks raise explicitly (rather than via `assert`) so they survive
        # `python -O`; AssertionError is kept for existing callers.
        if num_features != data.shape[feature_axis]:
            raise AssertionError("# of names must match data matrix!")

        weights = args[1] if len(args) > 1 else None
        if weights is None:
            weights = np.ones(num_samples)
        else:
            # Fresh float copy: normalising in place would rescale the
            # caller's array, and fails for integer arrays or plain lists.
            weights = np.array(weights, dtype=float)
        self.weights = weights / np.sum(weights)
        if len(self.weights) != num_samples:
            raise AssertionError("# weights must match data matrix!")

        self.transposed = transposed
        self.group_names = group_names
        self.data = data
class DenseDataWithIndex(DenseData):
    """DenseData that also keeps a row index so it can become a DataFrame again.

    Parameters
    ----------
    data, group_names, *args
        Passed straight through to ``DenseData.__init__``.
    index : array-like
        Row labels, one per row of ``data``.
    index_name : hashable or None
        Name given to the row index by ``convert_to_df``.
    """

    def __init__(self, data, group_names, index, index_name, *args):
        DenseData.__init__(self, data, group_names, *args)
        self.index_value = index
        self.index_name = index_name

    def convert_to_df(self):
        """Return the stored data as a pandas DataFrame.

        Columns are ``group_names``; the row index is built from
        ``index_value`` and named ``index_name``. Requires pandas to be
        importable.

        NOTE(review): assumes ``data`` is laid out (samples, features), i.e.
        ``self.transposed`` is False -- confirm for callers other than
        ``convert_to_data``.
        """
        # Build the index directly and hand it to the constructor: a data
        # column that happens to share the index name cannot collide with it,
        # and no temporary frames are concatenated. tupleize_cols=False keeps
        # tuple labels as plain values instead of building a MultiIndex.
        index = pd.Index(
            self.index_value, name=self.index_name, tupleize_cols=False
        )
        return pd.DataFrame(self.data, columns=self.group_names, index=index)
def convert_to_data(val, keep_index=False):
    """Wrap *val* in a ``Data`` object.

    Parameters
    ----------
    val : Data, numpy.ndarray, pandas.Series or pandas.DataFrame
        - ``Data`` instances are returned unchanged.
        - A 2-D array becomes ``DenseData`` with names "0", "1", ...
        - A 1-D array or a Series is treated as a single row (sample); a
          Series uses its index labels as the names.
        - A DataFrame uses its column labels as the names.
    keep_index : bool, default False
        DataFrames only: return a ``DenseDataWithIndex`` that also keeps the
        row index values and the index name.

    Returns
    -------
    Data

    Raises
    ------
    AssertionError
        If *val* is of an unsupported type.
    """
    if isinstance(val, Data):
        return val

    if isinstance(val, np.ndarray):
        if val.ndim == 1:
            # Same single-row layout the Series branch uses.
            val = val.reshape((1, len(val)))
        return DenseData(val, [str(i) for i in range(val.shape[1])])

    # Match on the type's repr so pandas can remain an optional import.
    type_repr = str(type(val))
    if type_repr.endswith("'pandas.core.series.Series'>"):
        # np.asarray replaces `.as_matrix()`, which pandas 1.0 removed.
        row = np.asarray(val).reshape((1, len(val)))
        return DenseData(row, list(val.index))
    if type_repr.endswith("'pandas.core.frame.DataFrame'>"):
        matrix = np.asarray(val)
        if keep_index:
            return DenseDataWithIndex(
                matrix, list(val.columns), val.index.values, val.index.name
            )
        return DenseData(matrix, list(val.columns))

    # Raised explicitly so the check survives `python -O`; the exception
    # type is unchanged for callers that catch AssertionError.
    raise AssertionError(
        "Unknown type passed as data object: " + str(type(val))
    )