Skip to content

Commit

Permalink
Sort only upper tail waits (#22)
Browse files Browse the repository at this point in the history
* Deprecate sorted keyword

* Use partialsortperm

* Increment version number

* Update src/core.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
sethaxen and github-actions[bot] authored Dec 30, 2021
1 parent ae98bab commit 1af98ee
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 9 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PSIS"
uuid = "ce719bf2-d5d0-4fb9-925d-10a81b42ad04"
authors = ["Seth Axen <[email protected]> and contributors"]
version = "0.2.5"
version = "0.2.6"

[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand Down
14 changes: 6 additions & 8 deletions src/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,6 @@ While `psis` computes smoothed log weights out-of-place, `psis!` smooths them in
# Keywords
- `sorted=issorted(vec(log_ratios))`: whether `log_ratios` are already sorted. Only
accepted if `nparams==1`.
- `improved=false`: If `true`, use the adaptive empirical prior of [^Zhang2010].
If `false`, use the simpler prior of [^ZhangStephens2009], which is also used in
[^VehtariSimpson2021].
Expand Down Expand Up @@ -207,7 +205,7 @@ end
function psis!(
logw::AbstractVector,
reff=1;
sorted::Bool=issorted(logw),
sorted::Bool=false, # deprecated
improved::Bool=false,
warn::Bool=true,
)
Expand All @@ -219,11 +217,11 @@ function psis!(
@warn "$M tail draws is insufficient to fit the generalized Pareto distribution. $MISSING_SHAPE_SUMMARY"
return PSISResult(logw, LogExpFunctions.logsumexp(logw), reff_val, M, missing)
end
perm = sorted ? collect(eachindex(logw)) : sortperm(logw)
icut = S - M
tail_range = (icut + 1):S
@inbounds logw_tail = @views logw[perm[tail_range]]
@inbounds logu = logw[perm[icut]]
perm = partialsortperm(logw, (S - M):S)
cutoff_ind = perm[1]
tail_inds = @view perm[2:(M + 1)]
logu = logw[cutoff_ind]
logw_tail = @views logw[tail_inds]
_, tail_dist = psis_tail!(logw_tail, logu, M, improved)
warn && check_pareto_shape(tail_dist)
return PSISResult(logw, LogExpFunctions.logsumexp(logw), reff_val, M, tail_dist)
Expand Down

2 comments on commit 1af98ee

@sethaxen
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/51435

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.6 -m "<description of version>" 1af98ee699820bf56f3a051136a76aecd961e303
git push origin v0.2.6

Please sign in to comment.