Skip to content

Commit

Permalink
Fixed a bug in rstats_central_moment_finalize where 1st order moments…
Browse files Browse the repository at this point in the history
… would always be provided regardless of p.
  • Loading branch information
borchehq committed Aug 16, 2024
1 parent fdfc9af commit db12a71
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 18 deletions.
26 changes: 14 additions & 12 deletions src/rstats/include/rstats.h
Original file line number Diff line number Diff line change
Expand Up @@ -314,8 +314,8 @@ inline void rstats_central_moment(double x, double w, double *buffer, uint64_t p
*
* @param results A pointer to an array of doubles where the final mean and
* central moments will be stored:
* - `results[0]` will store the final mean value.
* - `results[1]` to `results[p+1]` will store the final central
* - `results[p + 1]` will store the final mean value.
* - `results[0]` to `results[p]` will store the final central
* moments from the 0th to the p-th order.
* results must point to an array of length p + 2.
* @param buffer A pointer to an array of doubles used in
Expand All @@ -328,20 +328,22 @@ inline void rstats_central_moment(double x, double w, double *buffer, uint64_t p
*/
inline void rstats_central_moment_finalize(double *results, double *buffer,
uint64_t p, bool standardize) {
results[0] = buffer[1]; // Mean.
results[1] = 1.0; // 0th central moment is always 1.
results[2] = 0.0; // 1st central moment is always 0.
for(uint64_t i = 3; i < p + 2; i++) {
results[i] = buffer[i - 1] / buffer[0];
for(uint64_t i = 2; i < p + 1; i++) {
results[i] = buffer[i] / buffer[0];
}
if(standardize) {
results[1] = 1.0; // 0th standardized central moment is always 1.
results[2] = 0.0; // 1st standardized central moment is always 0.
for(uint64_t i = 4; i < p + 2; i++) {
results[i] = results[i] / rstats_pow(sqrt(results[3]), i - 1);
for(uint64_t i = 3; i < p + 1; i++) {
results[i] = results[i] / rstats_pow(sqrt(results[2]), i);
}
if(p >= 2) {
results[2] = 1.0; // 2nd standardized central moment is always 1.
}
results[3] = 1.0; // 2nd standardized central moment is always 1.
}
results[0] = 1.0; // 0th standardized central moment is always 1.
if(p >= 1) {
results[1] = 0.0; // 1st standardized central moment is always 0.
}
results[p + 1] = buffer[1]; // Mean.
}

/**
Expand Down
32 changes: 26 additions & 6 deletions src/rstats/test/test_rstats.c
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ void test_central_moment() {
double results_2[4] = {0.0}, buffer_2[5] = {0.0};
double sum_weights = 0.0, sum_weights_comp = 0.0;
double tmp = 0.0;
double results_3[3], buffer_3[3];

for(size_t i = 0; i < 10; i++) {
rstats_central_moment(x[i], weights[i], buffer, p);
Expand All @@ -126,14 +127,14 @@ void test_central_moment() {
mass += weights[i] * x[i];
sum_weights_comp += weights[i];
mean_cmp = mass / sum_weights_comp;
//printf("%f, %f\n", results[0], wmoments_comp[0] / sum_weights_comp);
assert(fabs(results[0] - mean_cmp) < 1e-7);
//printf("%f, %f\n", results[16], mean_cmp);
assert(fabs(results[16] - mean_cmp) < 1e-7);
for(size_t k = 0; k < p; k++) {
for(size_t j = 0; j < i + 1; j++) {
wmoments_comp[k] += weights[j] * rstats_pow(x[j] - mean_cmp, k);
}
//printf("%f, %f\n", results[k + 1], wmoments_comp[k] / sum_weights_comp);
assert(fabs(results[k + 1] - wmoments_comp[k] / sum_weights_comp) < 1e-2);
//printf("%f, %f\n", results[k], wmoments_comp[k] / sum_weights_comp);
assert(fabs(results[k] - wmoments_comp[k] / sum_weights_comp) < 1e-2);
wmoments_comp[k] = 0.0;
}
}
Expand All @@ -148,8 +149,27 @@ void test_central_moment() {
rstats_kurtosis(x[i], weights[i], buffer_2);
rstats_kurtosis_finalize(results_2, buffer_2);
if(i > 0) {
assert(fabs(results[4] - results_2[2]) < 1e-2);
assert(fabs(results[5] - results_2[3]) < 1e-2);
assert(fabs(results[3] - results_2[2]) < 1e-2);
assert(fabs(results[4] - results_2[3]) < 1e-2);
}
}

for(size_t k = 0; k < 3; k++) {
for(size_t i = 0; i < 10; i++) {
rstats_central_moment(x[i], weights[i], buffer_3, k);
rstats_central_moment_finalize(results_3, buffer_3, k, false);
assert(results_3[0] == 1.0);
if(k >= 1) {
assert(results_3[1] == 0.0);
}
rstats_central_moment_finalize(results_3, buffer_3, k, true);
assert(results_3[0] == 1.0);
if(k >= 1) {
assert(results_3[1] == 0.0);
}
if(k >= 2) {
assert(results_3[2] == 1.0);
}
}
}
}
Expand Down

0 comments on commit db12a71

Please sign in to comment.