diff --git a/Project.toml b/Project.toml index 554aff55..6a424db1 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "PSIS" uuid = "ce719bf2-d5d0-4fb9-925d-10a81b42ad04" authors = ["Seth Axen and contributors"] -version = "0.2.5" +version = "0.2.6" [deps] Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" diff --git a/src/core.jl b/src/core.jl index 04e42fd0..7546aa6e 100644 --- a/src/core.jl +++ b/src/core.jl @@ -170,8 +170,6 @@ While `psis` computes smoothed log weights out-of-place, `psis!` smooths them in # Keywords - - `sorted=issorted(vec(log_ratios))`: whether `log_ratios` are already sorted. Only - accepted if `nparams==1`. - `improved=false`: If `true`, use the adaptive empirical prior of [^Zhang2010]. If `false`, use the simpler prior of [^ZhangStephens2009], which is also used in [^VehtariSimpson2021]. @@ -207,7 +205,7 @@ end function psis!( logw::AbstractVector, reff=1; - sorted::Bool=issorted(logw), + sorted::Bool=false, # deprecated improved::Bool=false, warn::Bool=true, ) @@ -219,11 +217,11 @@ function psis!( @warn "$M tail draws is insufficient to fit the generalized Pareto distribution. $MISSING_SHAPE_SUMMARY" return PSISResult(logw, LogExpFunctions.logsumexp(logw), reff_val, M, missing) end - perm = sorted ? collect(eachindex(logw)) : sortperm(logw) - icut = S - M - tail_range = (icut + 1):S - @inbounds logw_tail = @views logw[perm[tail_range]] - @inbounds logu = logw[perm[icut]] + perm = partialsortperm(logw, (S - M):S) + cutoff_ind = perm[1] + tail_inds = @view perm[2:(M + 1)] + logu = logw[cutoff_ind] + logw_tail = @views logw[tail_inds] _, tail_dist = psis_tail!(logw_tail, logu, M, improved) warn && check_pareto_shape(tail_dist) return PSISResult(logw, LogExpFunctions.logsumexp(logw), reff_val, M, tail_dist)