-
Notifications
You must be signed in to change notification settings - Fork 0
/
b_good_prediction.py
125 lines (116 loc) · 6.44 KB
/
b_good_prediction.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# This code was developed by SESS at Aarhus University with the support of the B-GOOD project.
# Python 3 and PyTorch 2 are required to run this code.
import numpy as np
import torch
import torch.nn as nn
# Define the model's forward propagation; this may differ between models.
class LSTMModel(nn.Module):
    """Forward pass of the LSTM-based model.

    Only ``forward`` is defined here. The sub-modules it uses (``lstm`` and
    ``linear_relu_stack``) are not created in this class; they must already
    exist on the instance -- presumably they are restored when a saved model
    object is loaded with ``torch.load`` (TODO confirm against the training
    code).
    """

    def forward(self, x):
        """Run ``x`` through the LSTM and then through the output stack.

        Args:
            x: Input tensor, passed unchanged to ``self.lstm``.

        Returns:
            The result of applying ``self.linear_relu_stack`` to the LSTM
            output sequence.
        """
        # No explicit (h0, c0) is passed: torch.nn.LSTM then starts from
        # all-zero hidden and cell states, which it shapes from the input
        # (batched or unbatched) and allocates with the input's dtype and
        # device. Hand-built 2-D float32 CPU zeros give the same result for
        # unbatched float32 CPU input but raise for anything else.
        # NOTE(review): assumes self.lstm is a torch.nn.LSTM -- confirm.
        out, _ = self.lstm(x)
        return self.linear_relu_stack(out)
# Define the predict class
class PredictHSI:
    """Load a saved PyTorch model and produce a single prediction from it.

    The model file is loaded once in the constructor; ``predict`` then runs
    the model on one input array without tracking gradients.
    """

    def __init__(self, model_path):
        """Load the model from ``model_path`` onto the CPU in eval mode.

        Args:
            model_path: Path of the saved model file given to ``torch.load``.
        """
        # map_location: although the model is trained with a GPU, it can be
        # loaded and used on the CPU.
        # weights_only=False: the loaded object is used directly as a model
        # (``eval()`` and a forward call), so the file must hold a whole
        # pickled module rather than a bare state dict. PyTorch >= 2.6
        # defaults to weights_only=True, which refuses to unpickle that.
        # WARNING: full unpickling can execute arbitrary code -- only load
        # model files from a trusted source.
        self.model = torch.load(
            model_path,
            map_location=torch.device("cpu"),
            weights_only=False,
        )
        self.model.eval()

    def predict(self, data):
        """Run the model on ``data`` and return the first output value.

        Args:
            data: A numpy array (or array-like convertible with
                ``np.asarray``); it is converted to a float32 tensor before
                being passed to the model.

        Returns:
            The first element of the flattened model output, as a numpy
            scalar (not an array).
        """
        features = torch.from_numpy(np.asarray(data)).float()
        # Inference only: skip gradient tracking.
        with torch.no_grad():
            output = self.model(features)
        return output.numpy().flatten()[0]
# Main entry point: demonstrates how to use the trained model in three steps: 1. load the trained model; 2. get the data; 3. predict from the data.
if __name__ == "__main__":
    # Usage example.
    # Step 1: load the trained model from disk.
    model_path = "./hsi.pt"
    model = PredictHSI(model_path)

    # Step 2: prepare the input. It is a numpy array of shape (1, 240); each
    # element is a float giving the relative hourly weight change,
    # (current hourly weight - previous hourly weight) / previous hourly
    # weight, covering 10 days. The values are written as one flat sequence
    # and reshaped into a single row.
    test_data = np.array([
        -6.99018885e-04, 9.79429949e-03, -5.06186113e-03,
        -3.18174111e-03, -8.34001880e-03, 1.28073115e-02,
        -4.17804392e-03, -9.13545396e-03, -4.12983587e-03,
        -1.20761544e-02, 0.00000000e+00, 0.00000000e+00,
        -3.48304249e-02, -4.73645553e-02, 1.32411858e-02,
        -2.81776935e-02, -8.48464295e-03, -8.09094310e-03,
        -3.89682944e-03, 1.84798101e-03, 2.41040994e-04,
        -8.13915115e-03, -4.57977923e-03, 2.08902202e-04,
        -8.58105998e-03, -4.90116712e-04, 1.51052361e-03,
        -6.53221132e-03, -8.07487313e-03, -2.67555518e-03,
        -7.69724278e-03, -1.16181765e-02, -1.69130433e-02,
        -3.98119390e-02, -6.29840121e-02, -1.10316435e-02,
        -6.26867265e-02, 7.63296499e-04, -5.72874118e-03,
        -3.61320451e-02, -6.88573811e-03, 1.29760401e-02,
        -4.28731591e-02, 3.18174111e-03, -3.19620371e-02,
        -1.06861512e-03, -2.39514410e-02, -1.93555914e-02,
        -2.49879174e-02, -2.65948568e-03, -2.21677367e-02,
        -2.73260139e-02, -3.45652811e-02, 5.87336579e-03,
        -4.02940214e-02, -1.53462766e-02, -3.48143540e-02,
        -7.81856626e-02, 9.43755880e-02, 3.09657343e-02,
        -7.48351961e-02, -8.71764962e-03, -1.84557065e-02,
        -2.94230711e-02, -1.41973151e-02, -2.25855410e-02,
        -3.35850450e-03, 2.11312599e-03, 1.48641947e-03,
        -1.55873178e-03, 7.79365888e-03, 4.54764022e-03,
        6.86966861e-03, 3.65578849e-03, 2.48272228e-03,
        2.65145092e-03, 1.42133841e-02, -1.45428069e-03,
        1.74352992e-02, -2.95676966e-03, -4.83046174e-02,
        -2.93909330e-02, -4.87304553e-02, -1.13996327e-01,
        -6.55872524e-02, -1.85681917e-02, -2.86356714e-02,
        -5.86533081e-03, -8.47660843e-03, -1.02522774e-02,
        -6.41972525e-03, 1.10557470e-02, -9.93088912e-03,
        1.51052361e-03, 6.02602493e-03, 2.53093056e-03,
        3.41474754e-03, 7.17498688e-03, 1.70335639e-03,
        3.28619219e-03, 1.68728700e-03, 5.80105325e-03,
        1.17306621e-03, -6.12244150e-03, -1.45829804e-02,
        -1.70496330e-02, -1.69532176e-03, -7.88204093e-03,
        3.67989251e-03, -1.09834345e-02, -1.19074257e-02,
        -1.54507281e-02, 2.08661165e-02, -1.85601565e-03,
        4.01734986e-04, -4.33070352e-03, -1.61818862e-02,
        1.32251158e-02, 1.13128573e-02, -1.09673655e-02,
        -7.83383287e-03, -1.63265113e-02, 1.10557470e-02,
        -1.27028609e-02, 5.95371285e-03, 4.57977905e-04,
        -1.25100277e-02, -2.73179787e-04, -1.05575956e-02,
        -2.23284308e-02, -7.12918937e-02, 1.32010123e-02,
        3.52160893e-02, -2.00867490e-03, -1.25100277e-02,
        0.00000000e+00, 0.00000000e+00, 2.80732419e-02,
        -3.14638838e-02, 6.02602493e-03, -1.52659300e-03,
        -1.86003298e-02, -1.41008981e-02, -3.70560363e-02,
        2.93828975e-02, -1.34982960e-03, -2.04161722e-02,
        1.24136116e-02, -9.80233401e-03, -2.03920677e-02,
        -1.69130433e-02, 2.20954255e-03, -2.58878041e-02,
        -3.52080539e-02, -5.96978189e-03, -1.32090468e-02,
        -5.17193638e-02, -5.70865422e-02, 2.10830532e-02,
        -3.34805958e-02, 1.22127437e-03, -2.22962927e-02,
        2.40317881e-02, -3.58669013e-02, -3.24441195e-02,
        8.41233134e-03, -1.15860375e-02, -2.50361245e-02,
        -1.60854701e-02, -8.68551061e-03, 1.00353407e-02,
        -5.53430133e-02, 6.16582893e-02, 1.11923374e-01,
        6.12404831e-02, -7.29872137e-02, -8.63730256e-03,
        4.72440338e-03, -6.82306737e-02, -7.07374960e-02,
        -8.78674760e-02, -6.16904274e-02, -2.77277492e-02,
        -4.13867384e-02, -6.46793330e-03, -1.27751729e-03,
        -1.21002579e-02, -9.14348848e-03, -4.58781375e-03,
        -1.67202111e-02, -2.18543829e-03, -5.48769999e-03,
        -1.68568008e-02, -6.03405992e-03, -1.78531036e-02,
        -1.61577817e-02, -1.01960339e-02, -3.51839513e-02,
        -1.50168547e-02, 5.62509336e-02, 2.88116306e-01,
        1.44166619e-01, -2.31744856e-01, -2.68254519e-01,
        -9.32426900e-02, -1.18029743e-01, -2.19909735e-02,
        -2.32363530e-02, -2.55021378e-02, -8.04273505e-03,
        -4.35480755e-03, -4.88509750e-03, -7.27943843e-03,
        -8.72568414e-03, -6.31527416e-03, -6.94198068e-03,
        -9.96302813e-04, 3.21387997e-05, -8.05076957e-03,
        -9.06314142e-03, -6.25903113e-03, -2.24007443e-02,
        -2.70046275e-02, 2.21476510e-01, 9.05350000e-02,
        1.40486732e-01, 1.39345795e-01, -2.24377036e-01,
        -2.84701556e-01, -9.75412577e-02, -7.13481382e-02,
        -6.74995184e-02, 1.10477125e-02, -1.92784593e-01,
        5.49011044e-02, 2.86919139e-02, -9.91482008e-03,
        -3.19941752e-02, -7.32443258e-02, -2.29551382e-02,
    ]).reshape(1, -1)

    # Step 3: run the prediction. The result is a number between 0 and 1;
    # the higher the number, the better the chance the colony will survive.
    pred = model.predict(test_data)
    print(pred)