Add support for distributions with monotonically increasing bijector
#297
Conversation
@sethaxen maybe you want to have a look at this

Looks good to me, but maybe we want to wait for @sethaxen
The approach looks right! I suggest we add overloads for more bijectors. Is there a reason this PR doesn't also add `is_monotonically_decreasing`?
All univariate bijections are strictly monotonic, so we could define `is_monotonically_decreasing(b) = !is_monotonically_increasing(b)` if we documented that this function is only expected to give the correct answer when a univariate bijector is passed. But this could cause problems: a multivariate (or non-monotonic) bijector is neither increasing nor decreasing, so the negation would wrongly report it as decreasing. Do we have any way to statically detect if a bijector is univariate?
The PR needs tests to cover each of the additions.
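For instance, a minimal sketch of such tests (my illustration, assuming the `exp`, `log`, and `ComposedFunction` overloads discussed in this PR, with the function living unexported in `Bijectors`):

```julia
using Test, Bijectors

@testset "is_monotonically_increasing" begin
    @test Bijectors.is_monotonically_increasing(exp)
    @test Bijectors.is_monotonically_increasing(log)
    # A composition of increasing maps is increasing.
    @test Bijectors.is_monotonically_increasing(log ∘ exp)
end
```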
src/interface.jl (outdated)
```julia
function is_monotonically_increasing(cf::ComposedFunction)
    return is_monotonically_increasing(cf.inner) && is_monotonically_increasing(cf.outer)
end
is_monotonically_increasing(::typeof(exp)) = true
```
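For reference: Julia's `∘` builds a `ComposedFunction` with `outer` and `inner` fields, so with the definition above a call like `is_monotonically_increasing(exp ∘ exp)` recurses into both parts and returns `true`.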
Suggested change:

```julia
is_monotonically_increasing(::typeof(exp)) = true
is_monotonically_increasing(::typeof(log)) = true
is_monotonically_increasing(binv::Inverse) = is_monotonically_increasing(inverse(binv))
```
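(The `Inverse` case works because a strictly increasing bijection has a strictly increasing inverse; `inverse(binv)` recovers the wrapped bijector, so the trait simply carries over from it.)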
Since this is an interface function, would it be better to place these methods in the files where the corresponding bijectors are implemented? Also, I think we can mark this as true for `Logit`, `LeakyReLu`, `Scale` (when the scale is positive), `Shift`, and `TruncatedBijector`.
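A sketch of what those overloads could look like, placed next to each bijector's implementation (field names, e.g. `Scale`'s `a`, are my assumption; treat this as illustrative rather than the PR's actual code):

```julia
is_monotonically_increasing(::Logit) = true
is_monotonically_increasing(::Shift) = true
# Scale is increasing only for a positive scale factor.
is_monotonically_increasing(b::Scale) = all(b.a .> 0)
```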
Agreed:)
I did consider this, but AFAIK we only have one monotonically decreasing bijector right now. But it's probably worth it, so I'll add that too 👍
Not at the moment, no.

And because we don't, I'd prefer to make it all explicit, so that unsupported bijectors end up with a method error / always return `false`.
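My reading of the explicit approach described here, as a sketch (not necessarily the final code):

```julia
# Blanket fallbacks: an unknown bijector is conservatively neither
# increasing nor decreasing, instead of deriving one trait from the other.
is_monotonically_increasing(f) = false
is_monotonically_decreasing(f) = false

# Supported bijectors then opt in explicitly, e.g.:
is_monotonically_increasing(::typeof(exp)) = true
```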
Ah okay, so now I remember another reason why I was holding back on this. The condition `is_monotonically_decreasing(f.inner) && is_monotonically_decreasing(f.outer)` won't be correct, e.g. the composition of two monotonically decreasing functions is monotonically increasing, not decreasing.

EDIT: Though this is of course also an issue for `is_monotonically_increasing` (two decreasing parts compose to an increasing function, but the `&&` check would return `false`).

EDIT 2: Nvm, it all just boils down to
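For concreteness, here is a sketch of the composition rules that do hold, assuming explicit `is_monotonically_increasing` / `is_monotonically_decreasing` methods with `false` fallbacks (an illustration, not necessarily the code the PR ends up with):

```julia
# outer ∘ inner is increasing if both parts are increasing, or both decreasing.
function is_monotonically_increasing(cf::ComposedFunction)
    return (is_monotonically_increasing(cf.inner) && is_monotonically_increasing(cf.outer)) ||
        (is_monotonically_decreasing(cf.inner) && is_monotonically_decreasing(cf.outer))
end

# outer ∘ inner is decreasing if exactly one part is decreasing and the other increasing.
function is_monotonically_decreasing(cf::ComposedFunction)
    return (is_monotonically_increasing(cf.inner) && is_monotonically_decreasing(cf.outer)) ||
        (is_monotonically_decreasing(cf.inner) && is_monotonically_increasing(cf.outer))
end
```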
I don't understand the table, but I believe it amounts to first checking that all bijectors are (elementwise) univariate with
My table is conveying the same idea, just on a per-composition basis (since we're defining the method for `ComposedFunction`).

But I've now added support for monotonically decreasing functions too + tests :)
Aaaalrighty! Finally got this thing working :) Issue was that we added one too many transformations; @sethaxen: should just be the one. BUT one final thing: what should we put as a warning regarding usage of `ordered`?
Damn. Seems like we missed something in #313

Note that there doesn't seem to be anything incorrect with the impl, but it's failing because it's trying to compare elements which aren't part of the triangular part.
Pfft, well that was painful. Added comments regarding what the issue is + fixed it by introducing a wrapper to avoid comparing the entries outside the triangular part. Would you have a quick look at some point @sethaxen? 🙏 Think we're there now after we've addressed the following :)
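For illustration, a hypothetical shape for such a wrapper (the type name and comparison scope are my assumptions, not the PR's actual code): restrict `==` to the triangular entries the bijector actually produces, so the untouched entries can't fail the comparison.

```julia
# Compare only the lower-triangular part; entries above the diagonal are ignored.
struct CompareLowerTriangular{M<:AbstractMatrix}
    mat::M
end

function Base.:(==)(a::CompareLowerTriangular, b::CompareLowerTriangular)
    n = size(a.mat, 1)
    return all(a.mat[i, j] == b.mat[i, j] for j in 1:n for i in j:n)
end
```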
Cool, will try to review this evening.

Sounds weird, I'll check the test.
I was talking about the example that we were discussing in one of the other comments; specifically #297 (comment)
Ah, that's expected though. I assume you un-fixed the variance parameter and randomly sampled it within the rejection sampling inner loop? The issue here is that when the mean is the same for both components, the variance actually has no impact on whether they are ordered. I think you should see a difference if you make the mean a reverse-ordered vector. The further apart the two mean components are, the more pronounced the difference and the harder it is to rejection sample.
The example I'm talking about is not related to rejection sampling; I'm referring to the example you ran with
I'm confused which example you're referring to then. The one in the comment you linked to compares NUTS with rejection sampling, but it does so with fixed variance, so it would not manifest the issue I'm talking about. Here's an example that does:

```julia
using Turing
using Bijectors: ordered
using LinearAlgebra
using Random: Random
using PosteriorStats

@model function demo_ordered(μ)
    k = length(μ)
    σ² ~ filldist(truncated(Normal(), lower=0), k)
    x ~ ordered(MvNormal(μ, Diagonal(σ²)))
    return (; σ², x)
end

k = 2
num_samples = 1_000_000
num_chains = 8

# Sample using NUTS.
μ = [3, 0]  # note: reverse-ordered, most draws will be rejected
model = demo_ordered(μ)
Random.seed!(0)
chain = sample(model, NUTS(), MCMCThreads(), num_samples ÷ num_chains, num_chains)
xs_chain = permutedims(Array(chain))
σ²_chain = cat(only(get(chain, :σ²))...; dims=3)

# Rejection sampling.
σ²_exact = mapreduce(hcat, 1:num_samples) do _
    while true
        σ² = rand(filldist(truncated(Normal(), lower=0), k))
        d = MvNormal(μ, Diagonal(σ²))
        xs = rand(d)
        issorted(xs) && return σ²
    end
end

qts = [0.05, 0.25, 0.5, 0.75, 0.95]
qt_names = map(q -> Symbol("q$(Int(100 * q))"), qts)
stats_with_mcses = (
    Tuple(qt_names) => Base.Fix2(quantile, qts),
    (Symbol("$(qn)_mcse") => (x -> mcse(x; kind=Base.Fix2(quantile, q))) for (q, qn) in zip(qts, qt_names))...,
)
```

```julia
julia> PosteriorStats.summarize(σ²_chain, stats_with_mcses...; var_names=["σ²[1]", "σ²[2]"])
SummaryStats
           q5    q25    q50    q75    q95  q5_mcse  q25_mcse  q50_mcse  q75_mcse  q95_mcse
 σ²[1]  0.239  0.769  1.249   1.77  2.531   0.0031    0.0024    0.0030    0.0053    0.0031
 σ²[2]   0.17   0.74  1.220  1.740  2.533    0.012    0.0052    0.0031    0.0025    0.0026

julia> PosteriorStats.summarize(reshape(σ²_exact', :, 1, 2), stats_with_mcses...; var_names=["σ²[1]", "σ²[2]"])
SummaryStats
           q5    q25    q50    q75    q95  q5_mcse  q25_mcse  q50_mcse  q75_mcse  q95_mcse
 σ²[1]  0.226  0.758  1.234  1.747  2.538  0.00087   0.00091   0.00092    0.0010    0.0017
 σ²[2]  0.227  0.759  1.235  1.749  2.535  0.00081   0.00096   0.00094    0.0010    0.0020
```

Note that the rejection sampling approach makes sense: the quantiles of the two variances should be about the same, since to get an ordered draw with a well-separated reverse-ordered mean, one needs to increase the variance, but it doesn't matter which variance is increased. But if we look at the HMC draws, we see that there's an asymmetry between the variances. This is due to the missing normalization factor. If we had a closed-form expression for it, we could test that, but I don't know one.
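As a side note that goes slightly beyond the thread: for k = 2 the per-draw ordering probability P(x[1] ≤ x[2] | σ²) does have a closed form, since the difference of two independent normals is normal. The sketch below (my addition) uses it only to show that this factor genuinely varies with σ²:

```julia
# For independent x1 ~ Normal(μ[1], σ²[1]) and x2 ~ Normal(μ[2], σ²[2]),
# x1 - x2 ~ Normal(μ[1] - μ[2], σ²[1] + σ²[2]), hence
# P(x1 ≤ x2) = Φ((μ[2] - μ[1]) / sqrt(σ²[1] + σ²[2])).
using Distributions

μ = [3, 0]
p_sorted(σ²) = cdf(Normal(), (μ[2] - μ[1]) / sqrt(σ²[1] + σ²[2]))

p_sorted([0.5, 0.5])  # ≈ 0.0013: with small variances, ordered draws are rare
p_sorted([4.0, 4.0])  # ≈ 0.144: larger variances make ordering much more likely
```

Since this factor depends on σ², dropping it tilts the implied prior on σ², which is my understanding of the asymmetry above.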
TBH I'm not certain if the above examples are even correct. The place I expect this to manifest is when conditioning, which is implicitly what the rejection-sampling approach is doing (conditioning on the draws being ordered).
Completely missed the fact that we were fixing the variance 🤦

Gotcha, gotcha; understand better now 👍 Soooo, how do we summarize all this into a simple warning for the end-user? 👀
This PR is just waiting for the following:
Think it's worth waiting until @sethaxen is back to let him have a final say before we merge 👍
Maybe an admonition saying something like:
Not in love with it; feels too wordy.
Added it, but did break it up a bit + added a shorter warning to the initial parts of the docstring :) Thanks @sethaxen!
LGTM! Thanks @torfjelde for taking on this fix!
Related: #220 and #295