forked from hxim/paq8px
-
Notifications
You must be signed in to change notification settings - Fork 0
/
LMS.hpp
110 lines (100 loc) · 2.25 KB
/
LMS.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
#pragma once
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <math.h>
#include "cstdint"
#include "SystemDefines.hpp"
/**
* Least Mean Squares predictor
* @tparam F
* @tparam T
*/
template<typename F, typename T>
class LMS {
private:
F *weights;
F *eg;
F *buffer;
F rates[2];
F rho;
F complement;
F eps;
F prediction;
int s;
int d;
public:
/**
*
* @param s
* @param d
* @param lRate
* @param rRate
* @param rho
* @param eps
*/
LMS(const int s, const int d, const F lRate, const F rRate, const F rho = (F) 0.95, const F eps = (F) 1e-3) : rates {lRate, rRate},
rho(rho), complement(1.0f - rho), eps(eps), prediction(0.0f), s(s), d(d) {
assert(s > 0 && d > 0);
weights = new F[s + d], eg = new F[s + d], buffer = new F[s + d];
reset();
}
/**
*
*/
~LMS() {
delete weights, delete eg, delete buffer;
}
float rsqrt(const float x) {
#if defined(ARM_NEON_AVAILABLE)
float r = vgetq_lane_f32(vrsqrteq_f32(vdupq_n_f32(x)), 0); //NEON
#elif defined(X64_SIMD_AVAILABLE)
float r = _mm_cvtss_f32(_mm_rsqrt_ss(_mm_set_ss(x))); //SSE
#else
float r = 1.0f / sqrt(x);
#endif
return (0.5F * (r + 1.0F / (x * r)));
}
/**
*
* @param sample
* @return
*/
F predict(const T sample) {
memmove(&buffer[s + 1], &buffer[s], (d - 1) * sizeof(F));
buffer[s] = sample;
prediction = 0.;
for( int i = 0; i < s + d; i++ ) {
prediction += weights[i] * buffer[i];
}
return prediction;
}
/**
*
* @param sample
*/
void update(const T sample) {
const F error = sample - prediction;
int i = 0;
for( ; i < s; i++ ) {
const F gradient = error * buffer[i];
eg[i] = rho * eg[i] + complement * (gradient * gradient);
weights[i] += (rates[0] * gradient * rsqrt(eg[i] + eps));
}
for( ; i < s + d; i++ ) {
const F gradient = error * buffer[i];
eg[i] = rho * eg[i] + complement * (gradient * gradient);
weights[i] += (rates[1] * gradient * rsqrt(eg[i] + eps));
}
memmove(&buffer[1], &buffer[0], (s - 1) * sizeof(F));
buffer[0] = sample;
}
/**
*
*/
void reset() {
for( int i = 0; i < s + d; i++ ) {
weights[i] = eg[i] = buffer[i] = 0.;
}
}
};