From 197542ff7632e5a12b99aada8719be0785a348f6 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Mon, 30 Sep 2024 13:38:18 -0400 Subject: [PATCH] using nth_element --- src/stan/mcmc/chainset.hpp | 29 +++++++++++++---------------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/src/stan/mcmc/chainset.hpp b/src/stan/mcmc/chainset.hpp index d9a019685c..d7d4a463e7 100644 --- a/src/stan/mcmc/chainset.hpp +++ b/src/stan/mcmc/chainset.hpp @@ -9,8 +9,7 @@ #include #include #include -#include -#include +#include #include #include #include @@ -249,8 +248,10 @@ class chainset { * @param index parameter index * @return median */ - double median(const int index) const { return quantile(index, 0.5); } - + double median(const int index) const { + return (quantile(index, 0.5)); + } + /** * Compute median value of specified parameter across all chains. * @@ -275,9 +276,9 @@ class chainset { auto center = median(index); Eigen::MatrixXd abs_dev = (draws.array() - center).abs(); size_t idx = static_cast(0.5 * (abs_dev.size() - 1)); - std::vector sorted(abs_dev.data(), abs_dev.data() + abs_dev.size()); - std::nth_element(sorted.begin(), sorted.begin() + idx, sorted.end()); - return 1.4826 * sorted[idx]; + std::vector copy(abs_dev.data(), abs_dev.data() + abs_dev.size()); + std::nth_element(copy.begin(), copy.begin() + idx, copy.end()); + return 1.4826 * copy[idx]; } /** @@ -311,10 +312,10 @@ class chainset { throw std::out_of_range("Probability must be between 0 and 1."); } Eigen::MatrixXd draws = samples(index); - size_t idx = static_cast(prob * (draws.size() - 1)); - std::vector sorted(draws.data(), draws.data() + draws.size()); - std::nth_element(sorted.begin(), sorted.begin() + idx, sorted.end()); - return sorted[idx]; + std::vector copy(draws.data(), draws.data() + draws.size()); + int idx = static_cast(prob * (copy.size() - 1)); + std::nth_element(copy.begin(), copy.begin() + idx, copy.end()); + return copy[idx]; } /** @@ -349,12 +350,8 @@ class chainset { if (probs.minCoeff() < 0.0 || probs.maxCoeff() > 1.0) { throw std::out_of_range("Probabilities must be between 0 and 1."); } - Eigen::MatrixXd draws = samples(index); - std::vector sorted(draws.data(), draws.data() + draws.size()); - std::sort(sorted.begin(), sorted.end()); for (size_t i = 0; i < probs.size(); ++i) { - size_t idx = static_cast(probs[i] * (sorted.size() - 1)); - quantiles[i] = sorted[idx]; + quantiles[i] = quantile(index, probs[i]); } return quantiles; }