Skip to content

Commit

Permalink
Merge branch 'feature/3299-diagnostics-chainset' of https://github.co…
Browse files Browse the repository at this point in the history
…m/stan-dev/stan into feature/3299-diagnostics-chainset
  • Loading branch information
mitzimorris committed Sep 30, 2024
2 parents 76d8984 + 021ac04 commit 23a2022
Showing 1 changed file with 10 additions and 17 deletions.
27 changes: 10 additions & 17 deletions src/stan/mcmc/chainset.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
#include <boost/accumulators/accumulators.hpp>
#include <boost/accumulators/statistics/stats.hpp>
#include <boost/accumulators/statistics/mean.hpp>
#include <boost/accumulators/statistics/tail_quantile.hpp>
#include <boost/accumulators/statistics/p_square_quantile.hpp>
#include <boost/accumulators/statistics/median.hpp>
#include <boost/accumulators/statistics/variance.hpp>
#include <boost/accumulators/statistics/covariance.hpp>
#include <boost/accumulators/statistics/variates/covariate.hpp>
Expand Down Expand Up @@ -249,9 +248,7 @@ class chainset {
* @param index parameter index
* @return median
*/
double median(const int index) const {
return quantile(index, 0.5);
}
double median(const int index) const { return (quantile(index, 0.5)); }

/**
* Compute median value of specified parameter across all chains.
Expand All @@ -277,9 +274,9 @@ class chainset {
auto center = median(index);
Eigen::MatrixXd abs_dev = (draws.array() - center).abs();
size_t idx = static_cast<size_t>(0.5 * (abs_dev.size() - 1));
std::vector<double> sorted(abs_dev.data(), abs_dev.data() + abs_dev.size());
std::nth_element(sorted.begin(), sorted.begin() + idx, sorted.end());
return 1.4826 * sorted[idx];
std::vector<double> copy(abs_dev.data(), abs_dev.data() + abs_dev.size());
std::nth_element(copy.begin(), copy.begin() + idx, copy.end());
return 1.4826 * copy[idx];
}

/**
Expand Down Expand Up @@ -313,10 +310,10 @@ class chainset {
throw std::out_of_range("Probability must be between 0 and 1.");
}
Eigen::MatrixXd draws = samples(index);
size_t idx = static_cast<size_t>(prob * (draws.size() - 1));
std::vector<double> sorted(draws.data(), draws.data() + draws.size());
std::nth_element(sorted.begin(), sorted.begin() + idx, sorted.end());
return sorted[idx];
std::vector<double> copy(draws.data(), draws.data() + draws.size());
int idx = static_cast<int>(prob * (copy.size() - 1));
std::nth_element(copy.begin(), copy.begin() + idx, copy.end());
return copy[idx];
}

/**
Expand Down Expand Up @@ -351,12 +348,8 @@ class chainset {
if (probs.minCoeff() < 0.0 || probs.maxCoeff() > 1.0) {
throw std::out_of_range("Probabilities must be between 0 and 1.");
}
Eigen::MatrixXd draws = samples(index);
std::vector<double> sorted(draws.data(), draws.data() + draws.size());
std::sort(sorted.begin(), sorted.end());
for (size_t i = 0; i < probs.size(); ++i) {
size_t idx = static_cast<size_t>(probs[i] * (sorted.size() - 1));
quantiles[i] = sorted[idx];
quantiles[i] = quantile(index, probs[i]);
}
return quantiles;
}
Expand Down

0 comments on commit 23a2022

Please sign in to comment.