-
Notifications
You must be signed in to change notification settings - Fork 2
/
rnn_other.c
341 lines (292 loc) · 16.3 KB
/
rnn_other.c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
/*
 * Constant 32; not referenced by the code in this file. Presumably the input
 * sequence length (main() feeds a 32-step sequence) -- confirm before relying
 * on it. Defined without a trailing semicolon so it can be used inside
 * expressions and array declarators.
 */
#define FILTER_LENGTH 32
/*
 * Output Dense layer parameters. main() reduces each 16-value GRU output y
 * to a single number: out = sum_j y[j] * dense_kernel[j] + dense_bias.
 */
double dense_bias = 0.03699271;
double dense_kernel[16] = {-0.25279573, -0.5144569, -0.26096338, -0.29290923, -0.30193526, 0.08099364,
-0.49538139, 0.2524238, -0.41640353, 0.44727048, 0.34025627, -0.37097168,
-0.37426633, 0.08937764, -0.27386504, -0.18269601};
/*
 * GRU input-side parameters, 48 = 3 gate slices of 16 units each.
 * gru_single_cell() slices both arrays the same way: elements 0-15 feed the
 * z gate, 16-31 the r gate and 32-47 the candidate state h.
 */
/* Input bias, added to the input projection of each gate (z | r | h). */
double gru_bias[48] = {-0.12932426, -0.0059351 ,-0.31653342, -0.20833671, -0.4036957 , 0.11177188,
-0.03906357, -0.24038778, -0.11646247, -0.15417548, -0.01650343, 0.06402943,
-0.2489842 , -0.6098019 , -0.16329522, -0.0858978 , 0.2464184 , -0.1070224,
0.37133262, 0.10928588, 0.33772758, 0.4648387 , -0.06442235, 0.2987576,
0.35838997, 0.06179264, 0.00630747, 0.44123122, 0.23943591, 0.5250115,
0.20995834, 0.6182523 , 0.00917085, 0.00109357, 0.00355206, 0.00649702,
0.00389399, 0.00447197, 0.00083151, -0.00177034, 0.00446005, -0.0012764,
-0.00684698, 0.00312416, 0.00371838, -0.00848566, 0.00787158, 0.01035452};
/* Input kernel (z | r | h): the scalar input is multiplied by each slice. */
double kernel[48] = {0.2622122 , 0.16006444, 0.14029719, 0.02195744, -0.23986852, 0.19578514,
-0.02642826, 0.916514 , -0.41176504, 0.07290944, -0.0103741 , 0.33833575,
0.20360304, 0.10277016, -0.1425647 , -0.4181303 , 0.18768592, -0.05268665,
0.76827174, -0.46825245, -0.5186148 , -1.6043628 , 0.36360577, -0.60572016,
-0.33310828, 0.75808346, 0.22284964, -0.5168066 , 1.0811964 , 0.38681325,
0.7158643 , 0.22535744, -0.25953695, -0.11966277, -0.11444107, -0.16182806,
-0.12450422, 0.03427833, -0.11225897, 0.24530025, -0.11999859, 0.11550533,
0.16286881, -0.13975014, -0.12346029, 0.32939914, -0.18380979, -0.25995687};
/*
 * GRU recurrent kernel. Row l holds the weights applied to unit l of the
 * previous hidden state h_tm1. gru_single_cell() splits the 48 columns by
 * gate -- 0-15 -> z, 16-31 -> r, 32-47 -> candidate h -- giving three 16x16
 * matrices U that recurrent_dot() applies as result[j] = sum_l h[l] * U[l][j].
 */
double recurrent_kernel[16][48] = {{-0.33324564, 0.18633987, -0.46684438, -0.17672615, -0.13393861, 0.13120326,
0.12043791, -0.19275099, -0.21038288, -0.00805567, -0.2989133 , -0.51597327,
-0.14483815, -0.03519003, -0.14322378, -0.45842144, 0.5420558 , 0.12060637,
0.05337898, 0.0920797 , -0.2292931 , 0.38289747, -0.0041085 , 0.18368538,
-0.16515528, -0.01412096, 0.31565234, 0.23320113, 0.42707068, 0.14022067,
0.10122933, 0.71727544, -0.41066176, -0.07518376, -0.3666031 , -0.20516059,
-0.3480875 , 0.77200353, -0.35870215, 0.28777796, -0.20233625, 0.08825762,
0.42565137, -0.23066425, -0.40517852, -0.0745728 , -0.34185708, -0.46886605},
{-0.34761482, -0.07668183, -0.4548626 , -0.23019359, -0.05764758, 0.02803677,
0.12841158, -0.4026756 , -0.09442581, -0.3402551 , -0.07513218, -0.2086365,
-0.21345004, -0.36617422, -0.26108292, -0.20937753, -0.16408326, 0.15446824,
-0.32785314, 0.13414073, -0.31908432, 0.62030137, 0.02253046, 0.42569885,
-0.44477 , 0.13455838, -0.02335798, 0.17705794, 0.35757783, 0.23010781,
0.01467692, 0.62879086, -0.03701181, 0.24380289, -0.21328464, 0.24822766,
-0.0053475 , 0.32602036, -0.00869995, 0.09179476, -0.14592564, 0.07204366,
0.07495923, -0.22580136, -0.01973284, -0.17688696, -0.07899545, -0.0433769 },
{ 0.03159099, 0.01711542, -0.10484185, -0.06258203, 0.14017868, 0.18142754,
-0.02204198, -0.23487183, 0.00541379, -0.07643563, -0.1559733 , -0.21309927,
0.27148026, 0.33272606, -0.07097985, -0.2449444 , 0.26929748, 0.34862405,
-0.4098676 , -0.28650478, -0.16343112, -0.47433573, -0.01719139, 0.21358447,
-0.22136977, -0.27829313, -0.05840438, -0.04151871, 0.01733726, -0.3077036,
-0.0625564 , 0.22362956, 0.0774032 , -0.24106106, 0.07874846, -0.6678918,
-0.23914678, 0.17828159, -0.34047198, 0.19169956, -0.49452287, 0.25358477,
0.04473523, -0.43213138, -0.3529497 , -0.16450526, -0.29587936, -0.74040514},
{-0.2757092 , 0.1513899 , 0.10168222, 0.17029373, 0.21876213, 0.42645454,
0.26845133, 0.05051502, -0.00956955, 0.2558387 , -0.09914787, 0.01259337,
-0.00241654, 0.43294385, -0.06541437, -0.10698438, 0.17861538, 0.1288432,
-0.25038853, 0.02954194, -0.08991966, 0.05463137, -0.14920586, 0.06484849,
-0.05562418, -0.33387426, -0.05309192, -0.17910376, 0.42390588, 0.02304784,
-0.28463838, 0.36908332, -0.31654972, -0.05224141, -0.51904446, -0.2576439,
-0.02743226, 0.3901019 , 0.0722606 , 0.13808247, -0.03472245, 0.03350862,
0.08887509, -0.22572924, -0.3120373 , 0.50598854, -0.18692902, 0.03104536},
{-0.13711423, 0.27218965, -0.30982837, -0.06271633, 0.11709681, 0.3655758,
0.04613091, 0.1649118 , -0.12020878, 0.48200697, 0.13922596, -0.02678711,
0.2836235 , 0.41767967, 0.20237724, 0.26682803, -0.46025336, 0.01473906,
-0.31150863, -0.13491912, -0.4247842 , -0.32958567, -0.32901376, -0.17867817,
-0.49612495, -0.2946146 , 0.14187108, 0.15475537, 0.02988405, -0.6422889,
-0.50066817, 0.02767938, -0.48536086, -0.01244744, -0.42876065, -0.23366575,
-0.3432993 , 0.5043874 , -0.11933201, 0.2769795 , -0.15428834, 0.2784369,
0.05971728, -0.47159842, -0.22549127, 0.07489682, -0.28810132, -0.28396243},
{ 0.4052296 , 0.39721385, 0.16002338, 0.46999246, 0.44904566, -0.18677579,
0.24832928, 0.40552175, 0.3757918 , 0.3989749 , 0.7466541 , 0.12526177,
0.380102 , 0.06161096, 0.06376936, 0.12316163, 0.17968553, -0.21769358,
0.29327512, 0.21461542, -0.01066717, 0.32253686, 0.09988583, -0.43407273,
0.7402072 , 0.37065497, 0.46109968, 0.72411275, -0.02024834, 0.09893885,
0.29933167, 0.12368561, -0.3957233 , -0.02218472, -0.6735557 , -0.13150525,
-0.31393284, 0.4576719 , -0.30508226, 0.25316706, -0.36334398, 0.30185857,
0.30490535, -0.37768734, -0.41823298, 0.95337546, -0.62425405, 0.21474145},
{-0.56393266, -0.11831116, -0.4802289 , -0.43266162, -0.39150375, 0.17782006,
-0.14869174, -0.2595772 , -0.2636459 , -0.01872367, -0.05078246, -0.2644873,
-0.46024954, -0.50625855, -0.3577595 , -0.32341388, 0.22251658, 0.15820642,
0.19421425, 0.1580318 , -0.1129314 , 0.26936835, -0.17293265, 0.5696975,
-0.10716978, 0.1804149 , -0.08604706, 0.6311513 , 0.36868146, 0.32397324,
0.16717687, 0.8073754 , 0.08213035, -0.19142093, 0.06709275, -0.05189571,
-0.06387871, 0.44571698, -0.21647316, -0.12338956, -0.05081637, -0.10340592,
0.3024256 , -0.06075826, -0.02166707, -0.28597444, -0.10445018, 0.02116952},
{ 0.28735894, -0.00943804, 0.5608048 , 0.02494083, 0.11069093, -0.2260689,
0.03268875, 0.22180696, 0.12901126, -0.04739738, 0.56028235, 0.28044164,
0.38697064, 0.1871149 , 0.27396998, 0.57376647, -0.05026472, -0.14250597,
0.02356944, -0.24221162, 0.23043945, -0.16553421, 0.34873822, -0.27734458,
0.23212954, -0.0326644 , 0.06927733, 0.09688118, -0.21654509, -0.16727884,
0.28051716, -0.41790453, -0.42687538, 0.02502382, -0.14250888, -0.3302217,
-0.37665576, -0.33375698, 0.13302384, -0.21538857, 0.01647016, 0.03878022,
0.03750464, 0.14440563, -0.4221202 , 0.5718157 , 0.03735317, 0.47470593},
{-0.04338795, 0.01065533, -0.22340627, -0.16674694, 0.08048011, 0.5407966,
0.03627587, -0.11789448, -0.00462468, 0.33129472, -0.30931604, -0.33942345,
0.25519586, 0.62158775, -0.19336964, 0.00969895, -0.16092438, 0.32265547,
-0.3862282 , 0.10244315, -0.36191425, -0.2928678 , -0.21181387, 0.07916971,
-0.6252262 , -0.04262253, -0.16243209, -0.32281786, 0.22644208, -0.07446007,
-0.3674131 , 0.0252867 , 0.30513066, -0.01551689, 0.42287797, 0.13742124,
0.5540895 , 0.5426447 , 0.08805652, 0.0631345 , -0.13853103, -0.20738626,
-0.01199325, -0.00485282, -0.1025574 , -0.94176763, 0.11043977, -0.17477657},
{ 0.6057775 , 0.0810724 , 0.41553885, 0.73374546, 0.50211585, -0.66890264,
0.08621392, 0.4167605 , 0.24823412, 0.34987044, 0.5675426 , 0.3505066,
0.11479317, 0.28876683, 0.220345 , 0.30800283, -0.05972926, -0.12195904,
0.10005048, -0.21550988, 0.08965074, -0.36103082, -0.03101348, -0.46489114,
0.2522224 , -0.07621143, -0.02272156, -0.07949062, -0.2113978 , -0.3823656,
-0.06086957, -0.6945538 , 0.19045472, 0.09044611, 0.20159659, 0.01999267,
0.38128832, -0.489164 , 0.15787831, -0.30637702, -0.172217 , -0.12810084,
-0.06776062, 0.3925702 , 0.07021673, 0.06525978, 0.25342825, 0.08659276},
{-0.42430827, -0.43088093, -0.78688675, -0.4678062 , -0.9050431 , -0.44643643,
-0.09322385, -0.59906834, -0.50362754, -0.06601983, -0.38999277, -0.28211606,
-0.34804407, -0.58609164, -0.25697276, -0.12127738, 0.79577285, -0.00076221,
0.46145 , 0.3988987 , 0.9626502 , 0.9213116 , 0.23690379, 0.49355784,
0.45780322, 0.6611823 , 0.47779062, 0.55226153, 0.4985791 , 0.76731545,
0.62595975, 0.27158496, -0.04951643, 0.14671867, -0.5974002 , -0.24730986,
-0.07500714, -0.30647743, -0.18272081, 0.36255264, 0.02985331, 0.00112209,
0.05433936, 0.02013764, 0.00067295, 0.6602448 , -0.15957068, 0.1264886},
{ 0.03648907, -0.05168581, 0.04616986, -0.15576027, 0.5062357 , 0.3721129,
-0.15111613, 0.00915703, 0.278042 , 0.01206551, -0.27195582, -0.2855333,
-0.33555037, 0.01351699, 0.05177139, -0.0699712 , -0.43159798, 0.024824,
-0.42055145, -0.11886617, -0.43509626, -0.6195991 , -0.30236137, 0.1160633,
-0.31618825, -0.47493848, -0.4877523 , -0.39990705, 0.09827142, -0.16683179,
-0.2740942 , -0.5660414 , 0.29104194, 0.30937797, 0.53713495, 0.233047,
0.277571 , 0.2838329 , 0.20824887, -0.18529612, 0.1355415 , -0.07765692,
-0.14421932, 0.28495267, 0.08438616, -0.6111454 , 0.25984398, -0.37778968},
{-0.398869 , -0.3738167 , -0.2413041 , -0.33082163, -0.50837195, 0.23415886,
-0.10890647, -0.5592783 , -0.53567296, 0.00306334, -0.33149353, -0.3554444,
0.04174213, -0.5158709 , -0.39875466, -0.3362605 , 0.2597986 , 0.14698753,
-0.02452215, 0.4936351 , 0.19711255, 0.18564673, -0.02003697, 0.451195,
-0.09074799, 0.38756105, 0.39005986, 0.04105154, 0.21765305, 0.33913207,
-0.00495462, 0.76309645, -0.45248973, 0.07160244, -0.34371528, -0.01013598,
-0.29228753, 0.47487226, -0.03024807, 0.25958252, -0.2609167 , 0.28557548,
0.34593832, -0.18475756, -0.2657135 , -0.03253972, -0.19920373, -0.41044986},
{-0.01249933, -0.20828654, -0.09173236, -0.02522583, -0.11680512, 0.17592879,
-0.17254815, 0.10457627, 0.02961675, -0.11064648, 0.54663044, 0.10568127,
-0.22991191, -0.25532097, -0.2988905 , 0.33062956, -0.12431063, -0.36108628,
0.17875512, -0.10259079, 0.24313982, -0.05137338, -0.1534299 , -0.17436668,
-0.25355053, -0.08057237, -0.76685524, -0.5702398 , -0.34621766, 0.04837879,
-0.20729208, -0.46325675, 0.61350256, 0.44168544, 0.52617466, 0.51075065,
0.3989174 , -0.45767152, 0.08354512, -0.28294808, 0.18848686, -0.25274694,
-0.5948065 , 0.22851115, 0.16232683, -0.5361507 , 0.541543 , 0.36248288,},
{-0.22447987, 0.03499718, -0.25295433, -0.12751198, -0.07232181, 0.43365872,
-0.25229058, 0.07128626, -0.0099504 , 0.16978341, -0.42396438, -0.07229665,
-0.03738681, 0.1014331 , -0.36785045, 0.05694272, -0.02871237, 0.38308486,
-0.19731559, 0.23654024, -0.25995767, -0.10228045, 0.15564986, 0.03407062,
-0.37430888, -0.23239242, -0.04656844, -0.1194833 , 0.28841755, -0.28932327,
-0.11651783, 0.44139612, 0.02863682, -0.10376391, -0.1079159 , -0.00907907,
-0.34113762, 0.33423543, -0.25243354, 0.38842604, -0.3146441 , 0.27762327,
0.31690654, -0.04755968, -0.32804552, -0.2485983 , -0.06238414, -0.253345 },
{ 0.7248914 , 0.39772227, 0.12826681, -0.11950496, 0.6220402 , 0.7518849,
0.34184027, 0.49392584, 0.623124 , 0.18207619, -0.43304783, -0.11267078,
0.24862415, 0.30127215, 0.17973863, 0.30810758, -0.5255435 , 0.10269722,
0.05070333, -0.31527802, -0.71968037, -0.3175742 , -0.30189514, 0.12895653,
-0.68077767, -0.5527323 , -0.8084097 , -0.68349946, -0.34303364, -0.36693686,
-0.499456 , -0.21946928, 0.28324634, 0.3974623 , 0.26757115, 0.6331997,
0.7872374 , -0.03028867, 0.36110917, -0.45320314, 0.42889097, -0.5312182,
-0.3625467 , 0.5565551 , 0.65372646, -0.5217376 , 0.6831335 , 0.10624173}};
/*
 * Element-wise (Hadamard) product of two 16-element vectors.
 * Reads first_array[0..15] and second_array[0..15]; neither is modified.
 * Returns a newly malloc'd 16-element array; the caller must free it.
 */
double* hadamard(double* first_array, double* second_array){
    double *product = malloc(16 * sizeof *product);
    double *dst = product;
    const double *lhs = first_array;
    const double *rhs = second_array;

    for (int remaining = 16; remaining > 0; remaining--) {
        *dst++ = *lhs++ * *rhs++;
    }
    return product;
}
/*
 * Element-wise sum of two 16-element vectors.
 * Neither input is modified. Returns a newly malloc'd 16-element array that
 * the caller must free.
 */
double* sum_funct_2D(double* first_array, double* second_array){
    const int n = 16;
    double *total = malloc(n * sizeof *total);
    int k = n;

    /* Elements are independent, so filling back-to-front is equivalent. */
    while (k-- > 0) {
        total[k] = first_array[k] + second_array[k];
    }
    return total;
}
double* sub_1(double* first_array){
double* return_array = (double *)malloc(16*sizeof(double));
for(int j=0; j<16; j++){
return_array[j] = 1-first_array[j];
}
return return_array;
}
/*
 * In-place piecewise-linear ("hard") sigmoid over 16 elements:
 *   v < -2.5        -> 0
 *   v >  2.5        -> 1
 *   otherwise       -> 0.2*v + 0.5
 * Overwrites sum_stage_1[0..15] and returns the same pointer (no allocation).
 */
double* sigma(double* sum_stage_1){
    double *const end = sum_stage_1 + 16;
    for (double *cur = sum_stage_1; cur != end; ++cur) {
        const double v = *cur;
        *cur = (v < -2.5) ? 0 : (v > 2.5) ? 1 : (0.2 * v) + 0.5;
    }
    return sum_stage_1;
}
double* thanh(double* sum_stage_2){
for(int j=0; j<16; j++){
sum_stage_2[j] = 1- (2/(1+exp(2*sum_stage_2[j])));
}
return sum_stage_2;
}
/*
 * Scales a 16-element vector by a scalar: result[j] = first_array * second_array[j].
 * The vector is not modified. Returns a newly malloc'd 16-element array that
 * the caller must free.
 */
double* dot(double first_array, double second_array[]){
    double *scaled = malloc(16 * sizeof(double));
    double *dst = scaled;
    const double *src = second_array;

    while (dst < scaled + 16) {
        *dst = first_array * *src;
        ++dst;
        ++src;
    }
    return scaled;
}
/*
 * Row-vector / matrix product used for the GRU recurrent terms.
 * With v = first_array (16 values) and M = second_array (16 row pointers,
 * each to 16 values), returns r where r[j] = sum over l of v[l] * M[l][j].
 * Neither input is modified.
 *
 * Returns a newly malloc'd 16-element array that the caller must free, or
 * NULL if the allocation fails.
 */
double* recurrent_dot(double* first_array, double** second_array){
    double* return_array = malloc(16 * sizeof *return_array);
    if (return_array == NULL) {
        return NULL;
    }
    for(int j=0; j<16; j++){
        /* Sum from an explicit zero: malloc'd memory is indeterminate, so it
           must never serve as the starting value of the accumulation. */
        double acc = 0.0;
        for(int l=0; l<16; l++){
            acc += first_array[l] * second_array[l][j];
        }
        return_array[j] = acc;
    }
    return return_array;
}
/*
 * Adds a 16-element bias vector to a 16-element vector:
 * result[i] = first_array[i] + second_array[i]. Neither input is modified.
 * Returns a newly malloc'd 16-element array that the caller must free.
 */
double* bias_add(double* first_array, double* second_array){
    double *shifted = malloc(sizeof(double[16]));
    for (int idx = 0; idx < 16; ++idx) {
        const double value = first_array[idx];
        const double bias = second_array[idx];
        shifted[idx] = value + bias;
    }
    return shifted;
}
/*
 * Advance the 16-unit GRU by one time step for a single scalar input.
 *
 * Gate math (the "use bias = True" / "reset_after = False" notes suggest
 * this mirrors a Keras GRU layer -- confirm against the training code):
 *   z  = sigma(inputs*W_z + b_z + h_tm1 . U_z)        update gate
 *   r  = sigma(inputs*W_r + b_r + h_tm1 . U_r)        reset gate
 *   hh = thanh(inputs*W_h + b_h + (r*h_tm1) . U_h)    candidate state
 *   h  = z*h_tm1 + (1 - z)*hh
 * W, b and U are the globals kernel, gru_bias and recurrent_kernel. Their
 * columns 0-15, 16-31 and 32-47 hold the z, r and h parts respectively.
 *
 * h_tm1:  previous hidden state (16 values); overwritten with the new state.
 * inputs: scalar input for this step.
 *
 * Prints each new hidden value as "yN = ...". Returns the new state as a
 * newly malloc'd 16-element array that the caller must free. Every
 * intermediate vector produced by the helpers is released here, so repeated
 * calls do not accumulate memory.
 */
double* gru_single_cell(double* h_tm1, double inputs){
    /* Per-gate slices of the input kernel and input bias. */
    double kernel_z[16], kernel_r[16], kernel_h[16];
    double input_bias_z[16], input_bias_r[16], input_bias_h[16];
    /* Per-gate 16x16 slices of the recurrent kernel. They live on the stack;
       recurrent_dot() takes row pointers, so build those alongside. */
    double rk_z[16][16], rk_r[16][16], rk_h[16][16];
    double* recurrent_kernel_z[16];
    double* recurrent_kernel_r[16];
    double* recurrent_kernel_h[16];

    for(int i=0; i<16; i++){
        kernel_z[i] = kernel[i];
        kernel_r[i] = kernel[16+i];
        kernel_h[i] = kernel[32+i];
        input_bias_z[i] = gru_bias[i];
        input_bias_r[i] = gru_bias[16+i];
        input_bias_h[i] = gru_bias[32+i];
    }
    for(int j=0; j<16; j++){
        for(int i=0; i<16; i++){
            rk_z[j][i] = recurrent_kernel[j][i];
            rk_r[j][i] = recurrent_kernel[j][16+i];
            rk_h[j][i] = recurrent_kernel[j][32+i];
        }
        recurrent_kernel_z[j] = rk_z[j];
        recurrent_kernel_r[j] = rk_r[j];
        recurrent_kernel_h[j] = rk_h[j];
    }

    /* Input projections plus bias. use bias = True adds gru_bias here;
       reset_after = False means there is no separate recurrent bias. */
    double* wx_z = dot(inputs, kernel_z);
    double* wx_r = dot(inputs, kernel_r);
    double* wx_h = dot(inputs, kernel_h);
    double* x_z = bias_add(wx_z, input_bias_z);
    double* x_r = bias_add(wx_r, input_bias_r);
    double* x_h = bias_add(wx_h, input_bias_h);

    /* Update and reset gates. sigma() works in place and returns its
       argument, so z and r are the buffers sum_funct_2D() allocated. */
    double* recurrent_z = recurrent_dot(h_tm1, recurrent_kernel_z);
    double* recurrent_r = recurrent_dot(h_tm1, recurrent_kernel_r);
    double* z = sigma(sum_funct_2D(x_z, recurrent_z));
    double* r = sigma(sum_funct_2D(x_r, recurrent_r));

    /* Candidate state. The reset gate scales h_tm1 before the recurrent
       product; thanh() also works in place on the sum buffer. */
    double* reset_h = hadamard(r, h_tm1);
    double* recurrent_h = recurrent_dot(reset_h, recurrent_kernel_h);
    double* hh = thanh(sum_funct_2D(x_h, recurrent_h));

    /* Blend the old state and the candidate; h is handed to the caller. */
    double* keep = hadamard(z, h_tm1);
    double* one_minus_z = sub_1(z);
    double* take = hadamard(one_minus_z, hh);
    double* h = sum_funct_2D(keep, take);

    for(int i=0; i<16; i++){
        h_tm1[i] = h[i];
        printf("y%d = %f\n", i+1, h[i]);
    }

    double* temporaries[] = {
        wx_z, wx_r, wx_h, x_z, x_r, x_h,
        recurrent_z, recurrent_r, z, r,
        reset_h, recurrent_h, hh,
        keep, one_minus_z, take
    };
    for(size_t k=0; k < sizeof temporaries / sizeof temporaries[0]; k++){
        free(temporaries[k]);
    }
    return h;
}
/*
 * Runs the model over a fixed 32-step sequence of 1.0 inputs, starting from
 * an all-zero hidden state. After each step the 16 GRU outputs are reduced by
 * the Dense layer (out = sum_j y[j]*dense_kernel[j] + dense_bias). Only the
 * last step's value is printed. gru_single_cell() also prints every hidden
 * value at each step.
 */
int main(void){
    /* GRU hidden state; gru_single_cell() updates it in place each step. */
    double h_tm1[16] = {0};
    double inputs[32] = {1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1};

    printf("Recursive_input=%f\n", h_tm1[0]);
    for (int i=0; i<32; i++){
        double* y = gru_single_cell(h_tm1, inputs[i]);
        double out = 0;
        for(int j=0; j<16; j++){
            out = out + y[j]*dense_kernel[j];
        }
        out = out + dense_bias;
        /* gru_single_cell() returns a fresh allocation every step. */
        free(y);
        if(i==31) printf("Out_%d = %f\n", i, out);
    }
    return 0;
}