
Add support for distributions with monotonically increasing bijector #297

Merged Jun 27, 2024 (51 commits)

Commits
459082a  add support for `ordered` when bijector is monotonically increasing (torfjelde, Dec 3, 2023)
15ec103  bump patch version (torfjelde, Dec 3, 2023)
011e41b  Update src/interface.jl (torfjelde, Dec 3, 2023)
5ffa2ea  formatting (torfjelde, Dec 3, 2023)
03fc4b4  Update src/interface.jl (torfjelde, Dec 10, 2023)
d37da88  added more impls of is_monotonically_increasing (torfjelde, Dec 10, 2023)
94bc3b9  added `is_monotonically_increasing` for `Shift` (torfjelde, Dec 10, 2023)
ce26a29  reverted is_monotonically_increasing impl for Scale but added for (torfjelde, Dec 10, 2023)
c303eaa  added impl of `is_monotonically_decreasing` and corrected impls for c… (torfjelde, Dec 10, 2023)
84bd3ec  added monotonic impls for `Scale` (torfjelde, Dec 10, 2023)
e8504f6  added monotonic impls for `TruncatedBijector` (torfjelde, Dec 10, 2023)
bab389d  `ordered` now also supports monotonically decreasing transformations (torfjelde, Dec 10, 2023)
3d404f1  added `inverse` impl for `SignFlip` (torfjelde, Dec 10, 2023)
505b92e  Merge remote-tracking branch 'origin/master' into torfjelde/ordered-f… (torfjelde, Dec 10, 2023)
9144887  fixed `output_size` for `SignFlip` (torfjelde, Dec 10, 2023)
064de1e  formatting (torfjelde, Dec 10, 2023)
790f62c  another test case (torfjelde, Dec 10, 2023)
565aaf6  updated a comment (torfjelde, Dec 10, 2023)
1cec356  Merge branch 'master' into torfjelde/ordered-for-monotonic (torfjelde, Apr 15, 2024)
e8cb6f9  added some additional comments (torfjelde, Apr 19, 2024)
2e89ce2  Merge branch 'master' into torfjelde/ordered-for-monotonic (torfjelde, Apr 25, 2024)
003c488  Apply suggestions from code review (torfjelde, May 8, 2024)
1929040  Update test/bijectors/ordered.jl (torfjelde, May 8, 2024)
98f993c  added `OrderedDistribution` to address bugs in current `ordered` (torfjelde, May 12, 2024)
173be81  return `OrderedDistribution` from `ordered` (torfjelde, May 12, 2024)
85285f0  move the `ordered` definition to be near `OrderedDistribution` (torfjelde, May 12, 2024)
cb24aa3  initial work on adding tests (torfjelde, May 12, 2024)
fbc6ec3  added currently failing correctness tests (torfjelde, May 12, 2024)
bdc547c  Merge remote-tracking branch 'origin/torfjelde/ordered-for-monotonic'… (torfjelde, May 12, 2024)
f87a87b  fixed `rand` for `OrderedDistribution` (torfjelde, May 13, 2024)
d496221  more extensive correctness testing of `ordered` (torfjelde, May 13, 2024)
35e7d31  test ordered for higher dims (torfjelde, May 17, 2024)
1d6d5a6  Apply suggestions from code review (torfjelde, May 17, 2024)
6bc4a17  don't use `InverseGamma` as target due to heavy tails (torfjelde, May 17, 2024)
baf15c8  Merge remote-tracking branch 'origin/torfjelde/ordered-for-monotonic'… (torfjelde, May 17, 2024)
826032d  Update src/bijectors/ordered.jl (torfjelde, May 17, 2024)
8b9761f  fixed syntax error (torfjelde, May 17, 2024)
deece0f  Merge branch 'master' into torfjelde/ordered-for-monotonic (torfjelde, Jun 4, 2024)
d287a96  Merge branch 'master' into torfjelde/ordered-for-monotonic (torfjelde, Jun 5, 2024)
08ac70a  fixed OrderedBijector + added some docs for it (torfjelde, Jun 5, 2024)
3ebbb12  forgot to uncomment tests in previous commit + fixed them (torfjelde, Jun 5, 2024)
688ecb7  more tests uncommented (torfjelde, Jun 5, 2024)
4577c00  fixed failing tests for LKJ (torfjelde, Jun 5, 2024)
c610ec9  Update test/ad/chainrules.jl (torfjelde, Jun 5, 2024)
c837965  Update test/ad/chainrules.jl (torfjelde, Jun 5, 2024)
fa60b5d  Update test/ad/chainrules.jl (torfjelde, Jun 5, 2024)
6873d96  better initialization for ordered chains (torfjelde, Jun 5, 2024)
9f2aa92  Merge remote-tracking branch 'origin/torfjelde/ordered-for-monotonic'… (torfjelde, Jun 5, 2024)
c9304af  added the description of the un-normalized `ordered` issue (torfjelde, Jun 26, 2024)
7b67cbf  fixed docstring of OrderedBijector (torfjelde, Jun 26, 2024)
79aa3ec  Merge branch 'master' into torfjelde/ordered-for-monotonic (sethaxen, Jun 27, 2024)
Project.toml (2 changes: 1 addition & 1 deletion)
@@ -1,6 +1,6 @@
 name = "Bijectors"
 uuid = "76274a88-744f-5084-9051-94815aaf08c4"
-version = "0.13.7"
+version = "0.13.8"

 [deps]
 ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
src/bijectors/ordered.jl (13 changes: 10 additions & 3 deletions)
@@ -17,14 +17,21 @@ Return a `Distribution` whose support are ordered vectors, i.e., vectors with in…
 This transformation is currently only supported for otherwise unconstrained distributions.
 """
 function ordered(d::ContinuousMultivariateDistribution)
-    if bijector(d) !== identity
+    # We're good if the map from unconstrained (in which we apply the ordered bijector)
Member: Would it be useful to add to the docstring a comment about when this is safe to use? E.g., in general, if `d` has variable parameters, then this is unsafe, because it implicitly renormalizes the distribution, but the log of the new normalization constant is never computed and so is never included in the log density or gradient, though it would need to be when the parameters are variable.

Member Author: When you say "variable parameters", do you mean that the size / length can vary?

Member: No, I mean parameters that are being sampled or optimized in the model.

Member Author: So you're saying

    sigma ~ InverseGamma(2, 3)
    x ~ ordered(MvNormal(zeros(2), sigma^2 * I))

will be incorrect?

Member Author (torfjelde, May 9, 2024):

> because it implicitly renormalizes the distribution, but the log of the new normalization constant is never computed so is never included in the log density or gradient

Okay, so I understand what you mean by this now, but I'm struggling a bit to understand why this is an issue here but not for other "typical" diffeomorphisms. The Jacobian correction is exactly so that we don't have to worry about normalization, no?

Member: Well, another way to put it is that the function is not in fact a diffeomorphism, because it's injective, not bijective. The original distribution is defined on some support, but ordered(...) has a support that is a subset of the original distribution's. So we're not just applying a bijective map and worrying about the change of measure; we're additionally restricting the support, which the Jacobian cannot capture. If we had a magic function that could compute normalization factors, then we would use it to compute the normalized logpdf after restricting the support, and then ordered would indeed be a bijection.
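To make this point concrete (a sketch in generic notation, not taken from the package): restricting a density $p(x; \theta)$ on $\mathbb{R}^k$ to the ordered subset gives

```latex
p_{\text{ordered}}(x; \theta)
  = \frac{p(x; \theta)\,\mathbf{1}[x_1 < x_2 < \cdots < x_k]}{Z(\theta)},
\qquad
Z(\theta) = \int_{\{x \,:\, x_1 < \cdots < x_k\}} p(x; \theta)\,\mathrm{d}x .
```

A Jacobian term only accounts for a bijective change of variables; it never produces $-\log Z(\theta)$. When $\theta$ is fixed, dropping $\log Z(\theta)$ merely rescales the density, but when $\theta$ is itself being sampled, the omitted $\log Z(\theta)$ varies with $\theta$, so the posterior over $\theta$ comes out wrong.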

Member Author: So this I can get behind for `ordered`, but I don't get this for `exp`?

Member: Same deal: you're taking a density whose support is the whole real line and using a change of variables from the real line to the positive reals (not the whole real line). So you're changing the support if you use `Exp` for this distribution.

Member Author: Nah, then I still don't get it. I'm with you that you're changing the domain, but $\exp: \mathbb{R} \to (0, \infty)$ is bijective, and those are the domains we are interested in. Why does it matter that $\exp$, treated as a map from $\mathbb{R} \to \mathbb{R}$, is not bijective? o.O

Could you maybe write down exactly what you mean mathematically, or alternatively point me to a resource where I can read about this? I don't think I fully understand exactly what issue you're pointing to here, but it really feels like something I should know, so I'm very keen on learning more.

Member Author (torfjelde, May 10, 2024): Btw, just implemented sampling from the following model:

using Turing
using Bijectors: ordered
using LinearAlgebra
using Random: Random

function marginal(d::Distributions.AbstractMvNormal, i::Int)
    μ = mean(d)[i]
    σ² = var(d)[i]
    return Normal(μ, sqrt(σ²))
end

@model function demo_ordered(μ)
    k = length(μ)
    σ² ~ filldist(truncated(Normal(), lower=0), k)
    x ~ ordered(MvNormal(μ, Diagonal(σ²)))
    return (; σ², x)
end

k = 2
num_samples = 10_000

# Sample using NUTS.
μ = zeros(k)
σ² = ones(k)

model = fix(demo_ordered(μ); σ²)
chain = sample(model, NUTS(), num_samples)
xs_chain = permutedims(Array(chain))

# Rejection sampling.
d = MvNormal(μ, Diagonal(σ²))
xs_exact = mapreduce(hcat, 1:num_samples) do _
    xs = [rand(marginal(d, 1))]
    for i = 2:k
        x_prev = xs[end]
        while true
            x = rand(marginal(d, i))
            if x >= x_prev
                push!(xs, x)
                break
            end
        end
    end
    return xs
end

display(chain)

qts = [0.05, 0.25, 0.5, 0.75, 0.95]
qs_chain = mapslices(xs_chain; dims=2) do xs
    quantile(xs, qts)
end
qs_exact = mapslices(xs_exact; dims=2) do xs
    quantile(xs, qts)
end

and even that fails to produce consistent results 😕

Are these two really supposed to be the same?

Here's the output of the above:

┌ Info: Found initial step size
└   ϵ = 3.2
Chains MCMC chain (10000×14×1 Array{Float64, 3}):

Iterations        = 1001:1:11000
Number of chains  = 1
Samples per chain = 10000
Wall duration     = 1.75 seconds
Compute duration  = 1.75 seconds
parameters        = x[1], x[2]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse     ess_bulk    ess_tail      rhat   ess_per_sec 
      Symbol   Float64   Float64   Float64      Float64     Float64   Float64       Float64 

        x[1]   -0.0000    1.0099    0.0100   10294.6856   8010.6755    0.9999     5872.6101
        x[2]    1.6619    2.3489    0.0266    9314.9650   7138.7240    1.0000     5313.7279

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

        x[1]   -2.0060   -0.6943   -0.0090    0.7007    1.9483
        x[2]   -1.2217    0.3319    1.2620    2.3638    7.3225

┌ Info: quantiles
│   qs_chain = 2×5 Matrix{Float64}:
│    -1.64868   -0.694321  -0.00904604  0.700658  1.64419
│    -0.841051   0.331897   1.26199     2.36376   5.42821
│   qs_exact = 2×5 Matrix{Float64}:
│    -1.62889   -0.648943   0.00702831  0.664035  1.63491
│    -0.502251   0.297564   0.874958    1.48053   2.38214

As you can see, the quantiles are different.

+    # to constrained is monotonically increasing, i.e. order-preserving. In that case,
+    # we can form the ordered transformation as `binv ∘ OrderedBijector() ∘ b`.
+    # TODO: Add support for monotonically _decreasing_ transformations. This will be
+    # the same as above, but the ordering will be reversed by `binv`, so we need to handle this.
+    b = bijector(d)
+    binv = inverse(b)
+    if !is_monotonically_increasing(binv)
         throw(
             ArgumentError(
-                "ordered transform is currently only supported for unconstrained distributions.",
+                "ordered transform is not supported for $d.",
             ),
         )
     end
-    return transformed(d, OrderedBijector())
+    return transformed(d, binv ∘ OrderedBijector() ∘ b)
 end

with_logabsdet_jacobian(b::OrderedBijector, x) = transform(b, x), logabsdetjac(b, x)
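The construction returned above, `binv ∘ OrderedBijector() ∘ b`, can be sketched outside Julia to see why monotonicity of `binv` matters. A minimal Python sketch (illustrative only — `b`/`binv` stand in for the bijector of a positive-support distribution, and none of these names are the package's API): map a draw to unconstrained space, make it increasing (first entry free, subsequent increments strictly positive via `exp`), then map back; an increasing `binv` preserves the ordering on the return trip.

```python
import math

def ordered_transform(y):
    # OrderedBijector-style map: x[0] = y[0], x[i] = x[i-1] + exp(y[i]),
    # so the output is strictly increasing for any real-valued input.
    x = [y[0]]
    for yi in y[1:]:
        x.append(x[-1] + math.exp(yi))
    return x

def b(x):
    # Stand-in for `bijector(d)` of a positive-support distribution:
    # constrained (0, inf) -> unconstrained reals.
    return [math.log(v) for v in x]

def binv(y):
    # Stand-in inverse bijector: unconstrained -> constrained.
    # Crucially, it is monotonically increasing, so it preserves ordering.
    return [math.exp(v) for v in y]

def ordered_map(x):
    # binv ∘ OrderedBijector ∘ b, applied to a draw x from the base distribution.
    return binv(ordered_transform(b(x)))

xo = ordered_map([0.5, 2.0, 1.2])  # an unordered draw with positive entries
assert all(a < c for a, c in zip(xo, xo[1:]))  # output is strictly increasing
```

Note this sketch only covers the transformation itself; as the discussion above shows, the log-density additionally needs the normalization constant of the restricted support when the distribution's parameters are random.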
src/interface.jl (17 changes: 17 additions & 0 deletions)
@@ -230,6 +230,23 @@
 transform!(::typeof(identity), x, y) = copy!(y, x)
logabsdetjac(::typeof(identity), x) = zero(eltype(x))
logabsdetjac!(::typeof(identity), x, logjac) = logjac

###################
# Other utilities #
###################
"""
is_monotonically_increasing(f)

Returns `true` if `f` is monotonically increasing.
"""
is_monotonically_increasing(f) = false
is_monotonically_increasing(::typeof(identity)) = true
function is_monotonically_increasing(cf::ComposedFunction)
return is_monotonically_increasing(cf.inner) && is_monotonically_increasing(cf.outer)
end
is_monotonically_increasing(::typeof(exp)) = true
Member (suggested change):

    is_monotonically_increasing(::typeof(exp)) = true
    is_monotonically_increasing(::typeof(log)) = true
    is_monotonically_increasing(binv::Inverse) = is_monotonically_increasing(inverse(binv))
Member: Since this is an interface function, would it be better to place these methods in the files where the corresponding bijectors are implemented? Also, I think we can mark this as `true` for `Logit`, `LeakyReLU`, `Scale` (when the scale is positive), `Shift`, and `TruncatedBijector`.

Member Author: Agreed :)

is_monotonically_increasing(ef::Elementwise) = is_monotonically_increasing(ef.x)


######################
# Bijectors includes #
######################
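The dispatch scheme added in this hunk — a conservative `false` fallback plus a recursive rule for composition — can be mimicked outside Julia to see the logic at work. A Python sketch (names mirror the Julia functions but are not the package's API):

```python
import math

class Composed:
    """outer ∘ inner, mirroring Julia's ComposedFunction."""
    def __init__(self, outer, inner):
        self.outer, self.inner = outer, inner
    def __call__(self, x):
        return self.outer(self.inner(x))

def is_monotonically_increasing(f):
    # Known increasing primitives; anything unknown is conservatively False,
    # matching the Julia fallback `is_monotonically_increasing(f) = false`.
    if f in (math.exp, math.log):
        return True
    if isinstance(f, Composed):
        # increasing ∘ increasing is increasing. (decreasing ∘ decreasing is
        # also increasing, but, as in the PR, that case needs the separate
        # `is_monotonically_decreasing` check added in a later commit.)
        return is_monotonically_increasing(f.inner) and is_monotonically_increasing(f.outer)
    return False

assert is_monotonically_increasing(math.exp)
assert is_monotonically_increasing(Composed(math.exp, math.log))
assert not is_monotonically_increasing(math.sin)
```

The conservative fallback means a false negative only costs an unnecessary `ArgumentError` from `ordered`, whereas a false positive would silently produce wrong densities — hence the asymmetry.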
test/bijectors/ordered.jl (4 changes: 3 additions & 1 deletion)
@@ -22,11 +22,13 @@
         MvTDist(1, collect(1.0:5), Matrix(I(5))),
         product_distribution(fill(Normal(), 5)),
         product_distribution(fill(TDist(1), 5)),
+        product_distribution(fill(LogNormal(), 5)),
+        product_distribution(fill(InverseGamma(2, 3), 5)),
     ]
     d_ordered = ordered(d)
     @test d_ordered isa Bijectors.TransformedDistribution
     @test d_ordered.dist === d
-    @test d_ordered.transform isa OrderedBijector
+    # @test d_ordered.transform isa OrderedBijector
     y = randn(5)
     x = inverse(bijector(d_ordered))(y)
     @test issorted(x)