Add support for distributions with monotonically increasing bijector #297

Merged Jun 27, 2024 · 51 commits

Changes from 14 commits

Commits:
459082a
add support for `ordered` when bijector is monotonically increasing
torfjelde Dec 3, 2023
15ec103
bump patch version
torfjelde Dec 3, 2023
011e41b
Update src/interface.jl
torfjelde Dec 3, 2023
5ffa2ea
formatting
torfjelde Dec 3, 2023
03fc4b4
Update src/interface.jl
torfjelde Dec 10, 2023
d37da88
added more impls of is_monotonically_increasing
torfjelde Dec 10, 2023
94bc3b9
added `is_monotonically_increasing` for `Shift`
torfjelde Dec 10, 2023
ce26a29
reverted is_monotonicall_increasing impl for Scale but added for
torfjelde Dec 10, 2023
c303eaa
added impl of `is_monotonically_decreasing` and corrected impls for c…
torfjelde Dec 10, 2023
84bd3ec
added monotonic impls for `Scale`
torfjelde Dec 10, 2023
e8504f6
added monotonic impls for `TruncatedBijector`
torfjelde Dec 10, 2023
bab389d
`ordered` now also supports monotonically decreasing transformations
torfjelde Dec 10, 2023
3d404f1
added `inverse` impl for `SignFlip`
torfjelde Dec 10, 2023
505b92e
Merge remote-tracking branch 'origin/master' into torfjelde/ordered-f…
torfjelde Dec 10, 2023
9144887
fixed `output_size` for `SignFlip`
torfjelde Dec 10, 2023
064de1e
formatting
torfjelde Dec 10, 2023
790f62c
another test case
torfjelde Dec 10, 2023
565aaf6
updated a comment
torfjelde Dec 10, 2023
1cec356
Merge branch 'master' into torfjelde/ordered-for-monotonic
torfjelde Apr 15, 2024
e8cb6f9
added some additional comments
torfjelde Apr 19, 2024
2e89ce2
Merge branch 'master' into torfjelde/ordered-for-monotonic
torfjelde Apr 25, 2024
003c488
Apply suggestions from code review
torfjelde May 8, 2024
1929040
Update test/bijectors/ordered.jl
torfjelde May 8, 2024
98f993c
added `OrderedDistribution` to address bugs in current `ordered`
torfjelde May 12, 2024
173be81
return `OrderedDistribution` from `ordered`
torfjelde May 12, 2024
85285f0
move the `ordered` definition be near `OrderedDistribution`
torfjelde May 12, 2024
cb24aa3
initial work on adding tests
torfjelde May 12, 2024
fbc6ec3
added currently failing correctness tests
torfjelde May 12, 2024
bdc547c
Merge remote-tracking branch 'origin/torfjelde/ordered-for-monotonic'…
torfjelde May 12, 2024
f87a87b
fixed `rand` for `OrderedDistribution`
torfjelde May 13, 2024
d496221
more extensive correctness testing of `ordered`
torfjelde May 13, 2024
35e7d31
test ordered for higher dims
torfjelde May 17, 2024
1d6d5a6
Apply suggestions from code review
torfjelde May 17, 2024
6bc4a17
don't use `InverseGamma` as target due to heavy tails
torfjelde May 17, 2024
baf15c8
Merge remote-tracking branch 'origin/torfjelde/ordered-for-monotonic'…
torfjelde May 17, 2024
826032d
Update src/bijectors/ordered.jl
torfjelde May 17, 2024
8b9761f
fixed syntax error
torfjelde May 17, 2024
deece0f
Merge branch 'master' into torfjelde/ordered-for-monotonic
torfjelde Jun 4, 2024
d287a96
Merge branch 'master' into torfjelde/ordered-for-monotonic
torfjelde Jun 5, 2024
08ac70a
fixed OrderedBijector + added some docs for it
torfjelde Jun 5, 2024
3ebbb12
forgot to uncomment tests in previous commit + fixed them
torfjelde Jun 5, 2024
688ecb7
more test uncommented
torfjelde Jun 5, 2024
4577c00
fixed failing tests for LKJ
torfjelde Jun 5, 2024
c610ec9
Update test/ad/chainrules.jl
torfjelde Jun 5, 2024
c837965
Update test/ad/chainrules.jl
torfjelde Jun 5, 2024
fa60b5d
Update test/ad/chainrules.jl
torfjelde Jun 5, 2024
6873d96
better initialization for ordered chains
torfjelde Jun 5, 2024
9f2aa92
Merge remote-tracking branch 'origin/torfjelde/ordered-for-monotonic'…
torfjelde Jun 5, 2024
c9304af
added the description of the un-normalized `oredered` issue
torfjelde Jun 26, 2024
7b67cbf
fixed docstring of OrderedBijeector
torfjelde Jun 26, 2024
79aa3ec
Merge branch 'master' into torfjelde/ordered-for-monotonic
sethaxen Jun 27, 2024
3 changes: 3 additions & 0 deletions src/bijectors/exp_log.jl
@@ -7,3 +7,6 @@ logabsdetjac(b::Elementwise{typeof(exp)}, x) = sum(x)

logabsdetjac(b::typeof(log), x::Real) = -log(x)
logabsdetjac(b::Elementwise{typeof(log)}, x) = -sum(log, x)

is_monotonically_increasing(::typeof(exp)) = true
is_monotonically_increasing(::typeof(log)) = true
2 changes: 2 additions & 0 deletions src/bijectors/leaky_relu.jl
@@ -27,3 +27,5 @@ function with_logabsdet_jacobian(b::LeakyReLU, x::AbstractArray)
J = mask .* b.α .+ (!).(mask)
return J .* x, sum(log.(abs.(J)))
end

is_monotonically_increasing(::LeakyReLU) = true
3 changes: 3 additions & 0 deletions src/bijectors/logit.jl
@@ -28,3 +28,6 @@ logabsdetjac(b::Logit, x) = sum(logit_logabsdetjac.(x, b.a, b.b))
function with_logabsdet_jacobian(b::Logit, x)
return _logit.(x, b.a, b.b), sum(logit_logabsdetjac.(x, b.a, b.b))
end


torfjelde marked this conversation as resolved.
is_monotonically_increasing(::Logit) = true
30 changes: 23 additions & 7 deletions src/bijectors/ordered.jl
@@ -1,3 +1,11 @@
struct SignFlip <: Bijector end

with_logabsdet_jacobian(::SignFlip, x) = -x, zero(eltype(x))
inverse(::SignFlip) = SignFlip()
output_size(::SignFlip, x) = size(x)
is_monotonically_increasing(::SignFlip) = false
is_monotonically_decreasing(::SignFlip) = false
Member (reviewer):

Shouldn't this be true?

torfjelde (author):

Yes! Good catch:)
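
The reviewer's point is easy to check numerically: negation reverses any ordering, so a sign flip is monotonically decreasing. A minimal sketch, modelling `SignFlip` as plain negation since the type itself is internal to this PR:

```julia
# SignFlip applies x -> -x elementwise; negation reverses any ordering,
# so the transformation is monotonically *decreasing*, not neither.
signflip(x) = -x

xs = [-2.0, 0.5, 3.0]      # strictly increasing input
ys = signflip.(xs)         # [2.0, -0.5, -3.0]

issorted(xs)               # input is sorted ascending
issorted(ys; rev=true)     # output is sorted descending: order-reversing
```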

torfjelde marked this conversation as resolved.

"""
OrderedBijector()

@@ -17,14 +25,22 @@ Return a `Distribution` whose support are ordered vectors, i.e., vectors with in
This transformation is currently only supported for otherwise unconstrained distributions.
"""
function ordered(d::ContinuousMultivariateDistribution)
if bijector(d) !== identity
throw(
ArgumentError(
"ordered transform is currently only supported for unconstrained distributions.",
),
)
# We're good if the map from unconstrained (in which we apply the ordered bijector)
Member (reviewer):

Would it be useful to add to the docstring a comment about when this is safe to use? e.g. in general, if d has variable parameters, then this is unsafe, because it implicitly renormalizes the distribution, but the log of the new normalization constant is never computed so is never included in the log density or gradient, though it would need to be when the parameters are variable.

torfjelde (author):

When you say "variable parameters", do you mean that the size / length can vary?

Member (reviewer):

No, I mean parameters that are being sampled, optimized in the model.

torfjelde (author):

So you're saying

sigma ~ InverseGamma(2, 3)
x ~ ordered(MvNormal(zeros(2), sigma^2 * I))

will be incorrect?

torfjelde (author), May 9, 2024:

because it implicitly renormalizes the distribution, but the log of the new normalization constant is never computed so is never included in the log density or gradient,

Okay so I understand what you mean by this now, but I'm struggling a bit to understand why this is an issue here but not for other "typical" diffeomorphisms. The jacobian correction is exactly so we don't have to worry about normalization, no?

Member (reviewer):

Well, another way to put it is that the function is not in fact a diffeomorphism, because it's injective, not bijective. The original distribution is defined on some support, but ordered(...) has a support that is a subset of the original distribution's. So we're not just applying a bijective map and worrying about the change of measure. We're additionally restricting the support, which the Jacobian cannot capture. If we had a magic function that could compute normalization factors, then we would use that to compute the normalized logpdf after restricting the support, and then Ordered would indeed be a bijection.
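
Spelled out, the objection is about a missing normalization term. Restricting a density $p_\theta$ to the ordered subset gives (notation mine, following the discussion above):

```latex
p_{\mathrm{ordered}}(x \mid \theta)
  = \frac{p_\theta(x)\,\mathbf{1}\{x_1 < x_2 < \cdots < x_k\}}{Z(\theta)},
\qquad
Z(\theta) = \Pr_{x \sim p_\theta}\left(x_1 < x_2 < \cdots < x_k\right).
```

The Jacobian correction accounts for the change of variables but not for $Z(\theta)$; when $\theta$ is fixed, $\log Z(\theta)$ is a constant and can be dropped, but when $\theta$ is itself sampled or optimized, omitting it biases the log density and its gradient.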

torfjelde (author):

So this I can get behind for the ordered, but I don't get this for exp?

Member (reviewer):

Same deal. Because you're taking a density whose support is the whole real line and using a change of variables from the real line to the positive reals (not the whole real line). So you're changing the support if you use Exp for this distribution.

torfjelde (author):

Nah, then I still don't get it. I'm with you that you're changing the domain, but $\exp: \mathbb{R} \to (0, \infty)$ is bijective, which are the domains we are interested in. Why does it matter that $\exp$ treated as a map from $\mathbb{R} \to \mathbb{R}$ is not bijective? o.O

Could you maybe write down exactly what you mean mathematically, or alternatively point me to a resource where I can read about this? I don't think I fully understand exactly what issue you're pointing to here, but really feel like this is something I should know so very keen on learning more.

torfjelde (author), May 10, 2024:

Btw just implemented sampling from the following model:

using Turing
using Bijectors: ordered
using LinearAlgebra
using Random: Random

function marginal(d::Distributions.AbstractMvNormal, i::Int)
    μ = mean(d)[i]
    σ² = var(d)[i]
    return Normal(μ, sqrt(σ²))
end

@model function demo_ordered(μ)
    k = length(μ)
    σ² ~ filldist(truncated(Normal(), lower=0), k)
    x ~ ordered(MvNormal(μ, Diagonal(σ²)))
    return (; σ², x)
end

k = 2
num_samples = 10_000

# Sample using NUTS.
μ = zeros(k)
σ² = ones(k)

model = fix(demo_ordered(μ); σ²)
chain = sample(model, NUTS(), num_samples)
xs_chain = permutedims(Array(chain))

# Rejection sampling.
d = MvNormal(μ, Diagonal(σ²))
xs_exact = mapreduce(hcat, 1:num_samples) do _
    xs = [rand(marginal(d, 1))]
    for i = 2:k
        x_prev = xs[end]
        while true
            x = rand(marginal(d, i))
            if x >= x_prev
                push!(xs, x)
                break
            end
        end
    end
    return xs
end

display(chain)

qts = [0.05, 0.25, 0.5, 0.75, 0.95]
qs_chain = mapslices(xs_chain; dims=2) do xs
    quantile(xs, qts)
end
qs_exact = mapslices(xs_exact; dims=2) do xs
    quantile(xs, qts)
end

and even that fails to produce consistent results 😕

Are these two really supposed to be the same?

Here's the output of the above:

┌ Info: Found initial step size
└   ϵ = 3.2
Chains MCMC chain (10000×14×1 Array{Float64, 3}):

Iterations        = 1001:1:11000
Number of chains  = 1
Samples per chain = 10000
Wall duration     = 1.75 seconds
Compute duration  = 1.75 seconds
parameters        = x[1], x[2]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse     ess_bulk    ess_tail      rhat   ess_per_sec 
      Symbol   Float64   Float64   Float64      Float64     Float64   Float64       Float64 

        x[1]   -0.0000    1.0099    0.0100   10294.6856   8010.6755    0.9999     5872.6101
        x[2]    1.6619    2.3489    0.0266    9314.9650   7138.7240    1.0000     5313.7279

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

        x[1]   -2.0060   -0.6943   -0.0090    0.7007    1.9483
        x[2]   -1.2217    0.3319    1.2620    2.3638    7.3225

┌ Info: quantiles
│   qs_chain = 2×5 Matrix{Float64}:
│     -1.64868   -0.694321  -0.00904604  0.700658  1.64419
│     -0.841051   0.331897   1.26199     2.36376   5.42821
│   qs_exact = 2×5 Matrix{Float64}:
│     -1.62889   -0.648943  0.00702831  0.664035  1.63491
│     -0.502251   0.297564  0.874958    1.48053   2.38214

As you can see, the quantiles are different.

# to constrained is monotonically increasing, i.e. order-preserving. In that case,
# we can form the ordered transformation as `binv ∘ OrderedBijector() ∘ b`.
# TODO: Add support for monotonically _decreasing_ transformations. This will be the
# the same as above, but the ordering will be reversed by `binv` so we need to handle this.
b = bijector(d)
binv = inverse(b)
if is_monotonically_decreasing(binv)
ordered_b = binv ∘ SignFlip() ∘ OrderedBijector() ∘ SignFlip() ∘ b
elseif is_monotonically_increasing(binv)
ordered_b = binv ∘ OrderedBijector() ∘ b
else
throw(ArgumentError("ordered transform is currently not supported for $d."))
end
return transformed(d, OrderedBijector())

return transformed(d, ordered_b)
Member (reviewer):

Hang on, is transformed even doing the right thing here? e.g. if d is an IID normal, then, ordered_b is an OrderedBijector that maps from unconstrained y to constrained x. So the result is an unconstrained distribution one samples from by first sampling from the IID normal and then constraining so the result is sorted. But this is not what we want. This is not the same thing as constraining x to be distributed according to d but restricted to the ordered subset of the support. Instead, what we want here is something that wraps d but assigns it to have the bijector ordered_b. Random generation is not possible, but log-density to within a normalization factor can be computed.

e.g. something like

using Distributions: Distribution, VariateForm, ValueSupport, ArrayLikeVariate

struct WithBijector{B,F<:VariateForm,S<:ValueSupport,D<:Distribution{F,S}} <: Distribution{F,S}
    dist::D
    bijector::B
end
with_bijector(d::Distribution, b) = WithBijector(d, b)
Bijectors.bijector(d::WithBijector) = d.bijector
Distributions.logpdf(d::WithBijector, x) = Distributions.logpdf(d.dist, x)
function Distributions.logpdf(d::WithBijector{<:Any, ArrayLikeVariate{N}}, x::AbstractArray{<:Real, M}) where {N, M}
    return Distributions.logpdf(d.dist, x)
end
Base.length(d::WithBijector) = length(d.dist)

If you then use return with_bijector(d, inverse(ordered_b)) in this line, your example in https://github.com/TuringLang/Bijectors.jl/pull/297/files#r1597099187 seems to work.

torfjelde (author):

This is not the same thing as constraining x to be distributed according to d but restricted to the ordered subset of the support.

Ooookay so now I'm with you! I originally thought that this was just a transformation that moved us to an ordered distribution, but didn't necessarily have any relation to the "base" distribution (until your previous comment on the docstring clarifying what it was supposed to do), hence I never considered this to be an issue.

But yeah this is a pretty significant bug then, and has been around for a while 😬

Member (reviewer):

Yeah, let's definitely get some tests in the CI that would have caught this bug.

torfjelde (author):

Indeed. Working on a fix + tests as we speak.

sethaxen (Member), May 12, 2024:

Btw, something like with_bijector would be a convenient addition to the API for power users to manually select the bijector for a distribution. Applications include cases where the target distribution has high probability mass near a singularity of the default bijector, so one wants a custom bijector that moves the probability mass in the unconstrained space away from the singularity. Another application is programmatic benchmarking of alternative bijectors.

torfjelde (author):

something like with_bijector would be a convenient addition to the API

Agreed! But I think this is such a special case it warrants its own distribution (because we can implement rand using rejection sampling for several cases used in practice). But we should also add a more general with_bijector method 👍

But I tried implementing this, and I'm still seeing fairly drastic discrepancies in the quantiles 😕 Feeling a bit out of it today though so might just have messed up the impl or tests somewhere; let me know if you see anything obvious! I pushed the changes.
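
The rejection-sampling idea for `rand` mentioned above can be sketched as follows. This is a hypothetical helper, not the package API; the name `rand_ordered` and the assumption that sorted draws are not too rare are mine:

```julia
using Random

# Draw from the base "distribution" until the sample happens to be sorted.
# `d` is anything accepted by `rand(rng, d)` that yields a vector.
function rand_ordered(rng::AbstractRNG, d)
    while true
        x = rand(rng, d)
        issorted(x) && return x
    end
end
```

For an IID base distribution in k dimensions the acceptance probability is 1/k!, so this is viable for the small k used in the tests here but degrades quickly with dimension.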

Member (reviewer):

But I tried implementing this, and I'm still seeing fairly drastic discrepancies in the quantiles 😕 Feeling a bit out of it today though so might just have messed up the impl or tests somewhere; let me know if you see anything obvious! I pushed the changes.

Ah, really? The example you posted earlier looks pretty good to me:

using Turing
using Bijectors: ordered
using LinearAlgebra
using Random: Random
using Turing.DynamicPPL: fix
using PosteriorStats
using Random

@model function demo_ordered(μ)
    k = length(μ)
    σ² ~ filldist(truncated(Normal(), lower=0), k)
    x ~ ordered(MvNormal(μ, Diagonal(σ²)))
    return (; σ², x)
end

k = 2
num_samples = 1_000_000
num_chains = 8

# Sample using NUTS.
μ = zeros(k)
σ² = ones(k)

model = fix(demo_ordered(μ); σ²)

Random.seed!(0)
chain = sample(model, NUTS(), MCMCThreads(), num_samples ÷ num_chains, num_chains)
xs_chain = permutedims(Array(chain))

# Rejection sampling.
d = MvNormal(μ, Diagonal(σ²))
xs_exact = mapreduce(hcat, 1:num_samples) do _
    while true
        xs = rand(d)
        issorted(xs) && return xs
    end
end

qts = [0.05, 0.25, 0.5, 0.75, 0.95]
qt_names = map(q -> Symbol("q$(Int(100 * q))"), qts)
stats_with_mcses = (
    Tuple(qt_names) => Base.Fix2(quantile, qts),
    (Symbol("$(qn)_mcse") => (x -> mcse(x; kind=Base.Fix2(quantile, q))) for (q, qn) in zip(qts, qt_names))...,
)

Comparing the quantiles, they're generally within 3 MCSEs of each other. Not really sure what the cut-off should be, but given that the MCSEs are themselves estimated using a method that uses an asymptotic approximation, I'd guess this is good enough:

julia> PosteriorStats.summarize(reshape(xs_chain', :, 1, 2), stats_with_mcses...)
SummaryStats
        q5     q25     q50      q75    q95  q5_mcse  q25_mcse  q50_mcse  q75_mcse  q95_mcse 
 1  -1.961  -1.110  -0.546  -1.e-04  0.757   0.0042    0.0022    0.0017    0.0015    0.0021
 2  -0.761  -0.001   0.544   1.107   1.955   0.0020    0.0014    0.0013    0.0015    0.0027

julia> PosteriorStats.summarize(reshape(xs_exact', :, 1, 2), stats_with_mcses...)
SummaryStats
        q5     q25     q50      q75    q95  q5_mcse  q25_mcse  q50_mcse  q75_mcse  q95_mcse 
 1  -1.955  -1.109  -0.546  -0.0003  0.763   0.0021    0.0012    0.0010    0.0011    0.0017
 2  -0.761  -0.001   0.544   1.107   1.954   0.0017    0.0011    0.0010    0.0012    0.0017

torfjelde (author):

My target was incorrect 🤦‍♂️ I implemented sequential rejection sampling, i.e. sample x[1] and then, conditionally on it, sample the rest of the ordered vector using the marginals. This of course leads to a different distribution.

Using the correct rejection sampling approach now:)

torfjelde (author):

Further progress: it now works for most things but getting correctness issues for product dist with negative support (see tests).

torfjelde (author), May 13, 2024:

Oh I just saw all your comments below 🤦‍♂️ Will address those soon, hopefully (seem to have caught something so a bit under the weather atm)

end

with_logabsdet_jacobian(b::OrderedBijector, x) = transform(b, x), logabsdetjac(b, x)
3 changes: 3 additions & 0 deletions src/bijectors/scale.jl
@@ -34,3 +34,6 @@ _logabsdetjac_scale(a::AbstractVector, x::AbstractMatrix, ::Val{2}) = sum(log
# Matrix: single input.
_logabsdetjac_scale(a::AbstractMatrix, x::AbstractVector, ::Val{1}) = logabsdet(a)[1]
_logabsdetjac_scale(a::AbstractMatrix, x::AbstractMatrix, ::Val{2}) = logabsdet(a)[1]

is_monotonically_increasing(a::Scale) = all(Base.Fix2(>, 0), a.a)
is_monotonically_decreasing(a::Scale) = all(Base.Fix2(<, 0), a.a)
2 changes: 2 additions & 0 deletions src/bijectors/shift.jl
@@ -22,3 +22,5 @@ _logabsdetjac_shift(a, x) = zero(eltype(x))
_logabsdetjac_shift_array_batch(a, x) = zeros(eltype(x), size(x, ndims(x)))

with_logabsdet_jacobian(b::Shift, x) = transform(b, x), logabsdetjac(b, x)

is_monotonically_increasing(::Shift) = true
Member (reviewer):

Does this also need implementation of is_monotonically_decreasing ?

torfjelde (author):

Yep yep!

torfjelde marked this conversation as resolved.
33 changes: 33 additions & 0 deletions src/bijectors/truncated.jl
@@ -67,3 +67,36 @@ function truncated_logabsdetjac(x, a, b)
end

with_logabsdet_jacobian(b::TruncatedBijector, x) = transform(b, x), logabsdetjac(b, x)

# It's only monotonically decreasing if it's only upper-bounded.
# In the multivariate case, we can only say something reasonable if entries are monotonic.
function is_monotonically_increasing(b::TruncatedBijector)
lowerbounded, upperbounded = all(isfinite, b.lb), all(isfinite, b.ub)
return if lowerbounded
true
elseif upperbounded
# => decreasing
false
elseif all(!isfinite, b.lb) && all(!isfinite, b.ub)
# => all are unbounded so we have the identity
true
else
# => some are unbounded and some are bounded
false
end
end
function is_monotonically_decreasing(b::TruncatedBijector)
lowerbounded, upperbounded = all(isfinite, b.lb), all(isfinite, b.ub)
return if lowerbounded
false
elseif upperbounded
# => decreasing
true
elseif all(!isfinite, b.lb) && all(!isfinite, b.ub)
# => all are unbounded so we have the identity
false
else
# => some are unbounded and some are bounded
true
end
end
57 changes: 57 additions & 0 deletions src/interface.jl
@@ -231,6 +231,63 @@ transform!(::typeof(identity), x, y) = copy!(y, x)
logabsdetjac(::typeof(identity), x) = zero(eltype(x))
logabsdetjac!(::typeof(identity), x, logjac) = logjac

###################
# Other utilities #
###################
"""
is_monotonically_increasing(f)

Returns `true` if `f` is monotonically increasing.
"""
is_monotonically_increasing(f) = false
is_monotonically_increasing(::typeof(identity)) = true
is_monotonically_increasing(binv::Inverse) = is_monotonically_increasing(inverse(binv))
is_monotonically_increasing(ef::Elementwise) = is_monotonically_increasing(ef.x)
function is_monotonically_increasing(cf::ComposedFunction)
# Here we have a few different cases:
#
# inner \ outer | inc | dec | other
# --------------+-----+-----+------
# inc | inc | dec | NA
# dec | dec | inc | NA
# other | NA | NA | NA
# --------------+-----+-----+------
return if is_monotonically_increasing(cf.inner)
is_monotonically_increasing(cf.outer)
elseif is_monotonically_decreasing(cf.inner)
is_monotonically_decreasing(cf.outer)
else
false
end
end

"""
is_monotonically_decreasing(f)

Returns `true` if `f` is monotonically decreasing.
"""
is_monotonically_decreasing(f) = false
is_monotonically_decreasing(::typeof(identity)) = false
is_monotonically_decreasing(binv::Inverse) = is_monotonically_decreasing(inverse(binv))
is_monotonically_decreasing(ef::Elementwise) = is_monotonically_decreasing(ef.x)
function is_monotonically_decreasing(cf::ComposedFunction)
# Here we have a few different cases:
#
# inner \ outer | inc | dec | other
# --------------+-----+-----+------
# inc | inc | dec | NA
# dec | dec | inc | NA
# other | NA | NA | NA
# --------------+-----+-----+------
return if is_monotonically_increasing(cf.inner)
is_monotonically_decreasing(cf.outer)
elseif is_monotonically_decreasing(cf.inner)
is_monotonically_increasing(cf.outer)
else
false
end
end

torfjelde marked this conversation as resolved.
######################
# Bijectors includes #
######################
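
The composition table in the `ComposedFunction` methods above can be exercised with a standalone sign-tracking sketch. This uses toy symbols rather than the package's bijector types, and the function name is mine:

```julia
# Monotonicity of outer ∘ inner, mirroring the table in the diff:
# inc∘inc = inc, dec∘dec = inc (two order reversals cancel),
# inc∘dec = dec∘inc = dec, and anything involving :other is :other.
function compose_monotonicity(outer::Symbol, inner::Symbol)
    if inner === :inc
        outer === :inc ? :inc : outer === :dec ? :dec : :other
    elseif inner === :dec
        outer === :inc ? :dec : outer === :dec ? :inc : :other
    else
        :other
    end
end

compose_monotonicity(:inc, :inc)  # e.g. exp ∘ exp
compose_monotonicity(:dec, :dec)  # e.g. (x -> -x) ∘ (x -> -x), which is increasing
compose_monotonicity(:inc, :dec)  # decreasing
```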
11 changes: 10 additions & 1 deletion test/bijectors/ordered.jl
@@ -22,11 +22,20 @@ end
MvTDist(1, collect(1.0:5), Matrix(I(5))),
product_distribution(fill(Normal(), 5)),
product_distribution(fill(TDist(1), 5)),
# positive supports
product_distribution(fill(LogNormal(), 5)),
product_distribution(fill(InverseGamma(2, 3), 5)),
# negative supports
product_distribution(fill(-1 * InverseGamma(2, 3), 5)),
# bounded supports
product_distribution(fill(Uniform(1, 2), 5)),
# different transformations
product_distribution(fill(transformed(InverseGamma(2, 3), Bijectors.Scale(3)), 5)),
product_distribution(fill(transformed(InverseGamma(2, 3), Bijectors.Scale(-3)), 5)),
torfjelde marked this conversation as resolved.
]
d_ordered = ordered(d)
@test d_ordered isa Bijectors.TransformedDistribution
@test d_ordered.dist === d
@test d_ordered.transform isa OrderedBijector
y = randn(5)
x = inverse(bijector(d_ordered))(y)
@test issorted(x)