Skip to content

Commit

Permalink
using nth_element
Browse files Browse the repository at this point in the history
  • Loading branch information
mitzimorris committed Sep 30, 2024
1 parent d33c4c4 commit 197542f
Showing 1 changed file with 13 additions and 16 deletions.
29 changes: 13 additions & 16 deletions src/stan/mcmc/chainset.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
#include <boost/accumulators/accumulators.hpp>
#include <boost/accumulators/statistics/stats.hpp>
#include <boost/accumulators/statistics/mean.hpp>
#include <boost/accumulators/statistics/tail_quantile.hpp>
#include <boost/accumulators/statistics/p_square_quantile.hpp>
#include <boost/accumulators/statistics/median.hpp>
#include <boost/accumulators/statistics/variance.hpp>
#include <boost/accumulators/statistics/covariance.hpp>
#include <boost/accumulators/statistics/variates/covariate.hpp>
Expand Down Expand Up @@ -249,8 +248,10 @@ class chainset {
* @param index parameter index
* @return median
*/
double median(const int index) const { return quantile(index, 0.5); }

double median(const int index) const {
return (quantile(index, 0.5));
}

/**
* Compute median value of specified parameter across all chains.
*
Expand All @@ -275,9 +276,9 @@ class chainset {
auto center = median(index);
Eigen::MatrixXd abs_dev = (draws.array() - center).abs();
size_t idx = static_cast<size_t>(0.5 * (abs_dev.size() - 1));
std::vector<double> sorted(abs_dev.data(), abs_dev.data() + abs_dev.size());
std::nth_element(sorted.begin(), sorted.begin() + idx, sorted.end());
return 1.4826 * sorted[idx];
std::vector<double> copy(abs_dev.data(), abs_dev.data() + abs_dev.size());
std::nth_element(copy.begin(), copy.begin() + idx, copy.end());
return 1.4826 * copy[idx];
}

/**
Expand Down Expand Up @@ -311,10 +312,10 @@ class chainset {
throw std::out_of_range("Probability must be between 0 and 1.");
}
Eigen::MatrixXd draws = samples(index);
size_t idx = static_cast<size_t>(prob * (draws.size() - 1));
std::vector<double> sorted(draws.data(), draws.data() + draws.size());
std::nth_element(sorted.begin(), sorted.begin() + idx, sorted.end());
return sorted[idx];
std::vector<double> copy(draws.data(), draws.data() + draws.size());
int idx = static_cast<int>(prob * (copy.size() - 1));
std::nth_element(copy.begin(), copy.begin() + idx, copy.end());
return copy[idx];
}

/**
Expand Down Expand Up @@ -349,12 +350,8 @@ class chainset {
if (probs.minCoeff() < 0.0 || probs.maxCoeff() > 1.0) {
throw std::out_of_range("Probabilities must be between 0 and 1.");
}
Eigen::MatrixXd draws = samples(index);
std::vector<double> sorted(draws.data(), draws.data() + draws.size());
std::sort(sorted.begin(), sorted.end());
for (size_t i = 0; i < probs.size(); ++i) {
size_t idx = static_cast<size_t>(probs[i] * (sorted.size() - 1));
quantiles[i] = sorted[idx];
quantiles[i] = quantile(index, probs[i]);
}
return quantiles;
}
Expand Down

0 comments on commit 197542f

Please sign in to comment.