Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for distributions with monotonically increasing bijector #297

Merged
merged 51 commits into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
459082a
add support for `ordered` when bijector is monotonically increasing
torfjelde Dec 3, 2023
15ec103
bump patch version
torfjelde Dec 3, 2023
011e41b
Update src/interface.jl
torfjelde Dec 3, 2023
5ffa2ea
formatting
torfjelde Dec 3, 2023
03fc4b4
Update src/interface.jl
torfjelde Dec 10, 2023
d37da88
added more impls of is_monotonically_increasing
torfjelde Dec 10, 2023
94bc3b9
added `is_monotonically_increasing` for `Shift`
torfjelde Dec 10, 2023
ce26a29
reverted is_monotonicall_increasing impl for Scale but added for
torfjelde Dec 10, 2023
c303eaa
added impl of `is_monotonically_decreasing` and corrected impls for c…
torfjelde Dec 10, 2023
84bd3ec
added monotonic impls for `Scale`
torfjelde Dec 10, 2023
e8504f6
added monotonic impls for `TruncatedBijector`
torfjelde Dec 10, 2023
bab389d
`ordered` now also supports monotonically decreasing transformations
torfjelde Dec 10, 2023
3d404f1
added `inverse` impl for `SignFlip`
torfjelde Dec 10, 2023
505b92e
Merge remote-tracking branch 'origin/master' into torfjelde/ordered-f…
torfjelde Dec 10, 2023
9144887
fixed `output_size` for `SignFlip`
torfjelde Dec 10, 2023
064de1e
formatting
torfjelde Dec 10, 2023
790f62c
another test case
torfjelde Dec 10, 2023
565aaf6
updated a comment
torfjelde Dec 10, 2023
1cec356
Merge branch 'master' into torfjelde/ordered-for-monotonic
torfjelde Apr 15, 2024
e8cb6f9
added some additional comments
torfjelde Apr 19, 2024
2e89ce2
Merge branch 'master' into torfjelde/ordered-for-monotonic
torfjelde Apr 25, 2024
003c488
Apply suggestions from code review
torfjelde May 8, 2024
1929040
Update test/bijectors/ordered.jl
torfjelde May 8, 2024
98f993c
added `OrderedDistribution` to address bugs in current `ordered`
torfjelde May 12, 2024
173be81
return `OrderedDistribution` from `ordered`
torfjelde May 12, 2024
85285f0
move the `ordered` definition be near `OrderedDistribution`
torfjelde May 12, 2024
cb24aa3
initial work on adding tests
torfjelde May 12, 2024
fbc6ec3
added currently failing correctness tests
torfjelde May 12, 2024
bdc547c
Merge remote-tracking branch 'origin/torfjelde/ordered-for-monotonic'…
torfjelde May 12, 2024
f87a87b
fixed `rand` for `OrderedDistribution`
torfjelde May 13, 2024
d496221
more extensive correctness testing of `ordered`
torfjelde May 13, 2024
35e7d31
test ordered for higher dims
torfjelde May 17, 2024
1d6d5a6
Apply suggestions from code review
torfjelde May 17, 2024
6bc4a17
don't use `InverseGamma` as target due to heavy tails
torfjelde May 17, 2024
baf15c8
Merge remote-tracking branch 'origin/torfjelde/ordered-for-monotonic'…
torfjelde May 17, 2024
826032d
Update src/bijectors/ordered.jl
torfjelde May 17, 2024
8b9761f
fixed syntax error
torfjelde May 17, 2024
deece0f
Merge branch 'master' into torfjelde/ordered-for-monotonic
torfjelde Jun 4, 2024
d287a96
Merge branch 'master' into torfjelde/ordered-for-monotonic
torfjelde Jun 5, 2024
08ac70a
fixed OrderedBijector + added some docs for it
torfjelde Jun 5, 2024
3ebbb12
forgot to uncomment tests in previous commit + fixed them
torfjelde Jun 5, 2024
688ecb7
more test uncommented
torfjelde Jun 5, 2024
4577c00
fixed failing tests for LKJ
torfjelde Jun 5, 2024
c610ec9
Update test/ad/chainrules.jl
torfjelde Jun 5, 2024
c837965
Update test/ad/chainrules.jl
torfjelde Jun 5, 2024
fa60b5d
Update test/ad/chainrules.jl
torfjelde Jun 5, 2024
6873d96
better initialization for ordered chains
torfjelde Jun 5, 2024
9f2aa92
Merge remote-tracking branch 'origin/torfjelde/ordered-for-monotonic'…
torfjelde Jun 5, 2024
c9304af
added the description of the un-normalized `oredered` issue
torfjelde Jun 26, 2024
7b67cbf
fixed docstring of OrderedBijeector
torfjelde Jun 26, 2024
79aa3ec
Merge branch 'master' into torfjelde/ordered-for-monotonic
sethaxen Jun 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
Expand Down Expand Up @@ -45,8 +46,9 @@ ChainRulesCore = "0.10.11, 1"
ChangesOfVariables = "0.1"
Compat = "3.46, 4.2"
Distributions = "0.25.33"
ForwardDiff = "0.10"
DistributionsAD = "0.6"
DocStringExtensions = "0.9"
ForwardDiff = "0.10"
Functors = "0.1, 0.2, 0.3, 0.4"
InverseFunctions = "0.1"
IrrationalConstants = "0.1, 0.2"
Expand Down
1 change: 1 addition & 0 deletions src/Bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ using IrrationalConstants: IrrationalConstants
using LogExpFunctions: LogExpFunctions
using Roots: Roots
using Compat: Compat
using DocStringExtensions: TYPEDFIELDS

export TransformDistribution,
PositiveDistribution,
Expand Down
3 changes: 3 additions & 0 deletions src/bijectors/exp_log.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@ logabsdetjac(b::Elementwise{typeof(exp)}, x) = sum(x)

logabsdetjac(b::typeof(log), x::Real) = -log(x)
logabsdetjac(b::Elementwise{typeof(log)}, x) = -sum(log, x)

is_monotonically_increasing(::typeof(exp)) = true
is_monotonically_increasing(::typeof(log)) = true
2 changes: 2 additions & 0 deletions src/bijectors/leaky_relu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,5 @@ function with_logabsdet_jacobian(b::LeakyReLU, x::AbstractArray)
J = mask .* b.α .+ (!).(mask)
return J .* x, sum(log.(abs.(J)))
end

is_monotonically_increasing(::LeakyReLU) = true
2 changes: 2 additions & 0 deletions src/bijectors/logit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,5 @@ logabsdetjac(b::Logit, x) = sum(logit_logabsdetjac.(x, b.a, b.b))
function with_logabsdet_jacobian(b::Logit, x)
return _logit.(x, b.a, b.b), sum(logit_logabsdetjac.(x, b.a, b.b))
end

is_monotonically_increasing(::Logit) = true
116 changes: 97 additions & 19 deletions src/bijectors/ordered.jl
Original file line number Diff line number Diff line change
@@ -1,32 +1,22 @@
struct SignFlip <: Bijector end

with_logabsdet_jacobian(::SignFlip, x) = -x, zero(eltype(x))
inverse(::SignFlip) = SignFlip()
output_size(::SignFlip, dim) = dim
is_monotonically_increasing(::SignFlip) = false
is_monotonically_decreasing(::SignFlip) = true

"""
OrderedBijector()

A bijector mapping ordered vectors in ℝᵈ to unordered vectors in ℝᵈ.
A bijector mapping unordered vectors in ℝᵈ to ordered vectors in ℝᵈ.

## See also
- [Stan's documentation](https://mc-stan.org/docs/2_27/reference-manual/ordered-vector.html)
- Note that this transformation and its inverse are the _opposite_ of in this reference.
"""
struct OrderedBijector <: Bijector end

"""
sethaxen marked this conversation as resolved.
Show resolved Hide resolved
ordered(d::Distribution)

Return a `Distribution` whose support are ordered vectors, i.e., vectors with increasingly ordered elements.

This transformation is currently only supported for otherwise unconstrained distributions.
"""
function ordered(d::ContinuousMultivariateDistribution)
if bijector(d) !== identity
throw(
ArgumentError(
"ordered transform is currently only supported for unconstrained distributions.",
),
)
end
return transformed(d, OrderedBijector())
end

with_logabsdet_jacobian(b::OrderedBijector, x) = transform(b, x), logabsdetjac(b, x)

transform(b::OrderedBijector, y::AbstractVecOrMat) = _transform_ordered(y)
Expand Down Expand Up @@ -88,3 +78,91 @@ end

logabsdetjac(b::OrderedBijector, x::AbstractVector) = sum(@view(x[2:end]))
logabsdetjac(b::OrderedBijector, x::AbstractMatrix) = vec(sum(@view(x[2:end, :]); dims=1))

# Need a custom distribution type to handle this properly.
"""
OrderedDistribution

Wraps a distribution to restrict its support to the subspace of ordered vectors.

# Fields
$(TYPEDFIELDS)
"""
struct OrderedDistribution{D<:ContinuousMultivariateDistribution,B} <:
ContinuousMultivariateDistribution
"distribution transformed to have ordered support"
dist::D
"transformation from constrained space to ordered unconstrained space"
transform::B
end

"""
ordered(d::Distribution)

Return a `Distribution` whose support are ordered vectors, i.e., vectors with increasingly ordered elements.

Specifically, `d` is restricted to the subspace of its domain containing only ordered elements.

!!! warning
`rand` is implemented using rejection sampling, which can be slow for high-dimensional distributions.
In such cases, consider using MCMC methods to sample from the distribution instead.

!!! warning
The resulting ordered distribution is un-normalized, which can cause issues in some contexts, e.g. in
hierarchical models where the parameters of the ordered distribution are themselves sampled.
See the notes below for a more detailed discussion.

## Notes on `ordered` being un-normalized

The resulting ordered distribution is un-normalized. This is not a problem if used in a context where the
normalizing factor is irrelevant, but if the value of the normalizing factor impacts the resulting computation,
the results may be inaccurate.

For example, if the distribution is used in sampling a posterior distribution with MCMC and the parameters
of the ordered distribution are themselves sampled, then the normalizing factor would in general be needed
for accurate sampling, and `ordered` should not be used. However, if the parameters are fixed, then since
MCMC does not require distributions be normalized, `ordered` may be used without problems.

A common case is where the distribution being ordered is a joint distribution of `n` identical univariate
distributions. In this case the normalization factor works out to be the constant `n!`, and `ordered` can
again be used without problems even if the parameters of the univariate distribution are sampled.
"""
function ordered(d::ContinuousMultivariateDistribution)
# We're good if the map from unconstrained (in which we apply the ordered bijector)
# to constrained is monotonically increasing, i.e. order-preserving. In that case,
# we can form the ordered transformation as `binv ∘ OrderedBijector() ∘ b`.
# Similarly, if we're working with monotonically decreasing maps, we can do the same
# but with the addition of a sign flip before and after the ordered bijector.
b = bijector(d)
binv = inverse(b)
ordered_b = if is_monotonically_decreasing(binv)
SignFlip() ∘ inverse(OrderedBijector()) ∘ SignFlip() ∘ b
elseif is_monotonically_increasing(binv)
inverse(OrderedBijector()) ∘ b
else
throw(ArgumentError("ordered transform is currently not supported for $d."))
end

return OrderedDistribution(d, ordered_b)
end

bijector(d::OrderedDistribution) = d.transform

torfjelde marked this conversation as resolved.
Show resolved Hide resolved
Base.eltype(::Type{<:OrderedDistribution{D}}) where {D} = eltype(D)
Base.eltype(d::OrderedDistribution) = eltype(d.dist)
function Distributions._logpdf(d::OrderedDistribution, x::AbstractVector{<:Real})
lp = Distributions.logpdf(d.dist, x)
issorted(x) && return lp
return oftype(lp, -Inf)
end
Base.length(d::OrderedDistribution) = length(d.dist)

function Distributions._rand!(
rng::AbstractRNG, d::OrderedDistribution, x::AbstractVector{<:Real}
)
# Rejection sampling.
while true
Distributions.rand!(rng, d.dist, x)
issorted(x) && return x
end
end
3 changes: 3 additions & 0 deletions src/bijectors/scale.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,6 @@ _logabsdetjac_scale(a::AbstractVector, x::AbstractMatrix, ::Val{2}) = sum(log
# Matrix: single input.
_logabsdetjac_scale(a::AbstractMatrix, x::AbstractVector, ::Val{1}) = logabsdet(a)[1]
_logabsdetjac_scale(a::AbstractMatrix, x::AbstractMatrix, ::Val{2}) = logabsdet(a)[1]

is_monotonically_increasing(a::Scale) = all(Base.Fix1(>, 0), a.a)
is_monotonically_decreasing(a::Scale) = all(Base.Fix1(<, 0), a.a)
3 changes: 3 additions & 0 deletions src/bijectors/shift.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,6 @@ _logabsdetjac_shift(a, x) = zero(eltype(x))
_logabsdetjac_shift_array_batch(a, x) = zeros(eltype(x), size(x, ndims(x)))

with_logabsdet_jacobian(b::Shift, x) = transform(b, x), logabsdetjac(b, x)

is_monotonically_increasing(::Shift) = true
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this also need implementation of is_monotonically_decreasing ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep yep!

torfjelde marked this conversation as resolved.
Show resolved Hide resolved
is_monotonically_decreasing(::Shift) = true
33 changes: 33 additions & 0 deletions src/bijectors/truncated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,36 @@ function truncated_logabsdetjac(x, a, b)
end

with_logabsdet_jacobian(b::TruncatedBijector, x) = transform(b, x), logabsdetjac(b, x)

# It's only monotonically decreasing if it's only upper-bounded.
# In the multivariate case, we can only say something reasonable if entries are monotonic.
function is_monotonically_increasing(b::TruncatedBijector)
lowerbounded, upperbounded = all(isfinite, b.lb), all(isfinite, b.ub)
return if lowerbounded
true
elseif upperbounded
# => decreasing
false
elseif all(!isfinite, b.lb) && all(!isfinite, b.ub)
# => all are unbounded so we have the identity
true
else
# => some are unbounded and some are bounded
false
end
end
function is_monotonically_decreasing(b::TruncatedBijector)
lowerbounded, upperbounded = all(isfinite, b.lb), all(isfinite, b.ub)
return if lowerbounded
false
elseif upperbounded
# => decreasing
true
elseif all(!isfinite, b.lb) && all(!isfinite, b.ub)
# => all are unbounded so we have the identity
false
else
# => some are unbounded and some are bounded
true
end
end
63 changes: 63 additions & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,69 @@ transform!(::typeof(identity), x, y) = copy!(y, x)
logabsdetjac(::typeof(identity), x) = zero(eltype(x))
logabsdetjac!(::typeof(identity), x, logjac) = logjac

###################
# Other utilities #
###################
"""
is_monotonically_increasing(f)

Returns `true` if `f` is monotonically increasing.
"""
is_monotonically_increasing(f) = false
is_monotonically_increasing(::typeof(identity)) = true
is_monotonically_increasing(binv::Inverse) = is_monotonically_increasing(inverse(binv))
is_monotonically_increasing(ef::Elementwise) = is_monotonically_increasing(ef.x)
function is_monotonically_increasing(cf::ComposedFunction)
# Here we have a few different cases:
#
# inner \ outer | inc | dec | other
# --------------+-----+-----+------
# inc | inc | dec | NA
# dec | dec | inc | NA
# other | NA | NA | NA
# --------------+-----+-----+------
#
# where `inc` means monotonically increasing, `dec` means monotonically decreasing,
# and `NA` means not applicable, i.e. we should return `false`.
return if is_monotonically_increasing(cf.inner)
is_monotonically_increasing(cf.outer)
elseif is_monotonically_decreasing(cf.inner)
is_monotonically_decreasing(cf.outer)
else
false
end
end

"""
is_monotonically_decreasing(f)

Returns `true` if `f` is monotonically decreasing.
"""
is_monotonically_decreasing(f) = false
is_monotonically_decreasing(::typeof(identity)) = false
is_monotonically_decreasing(binv::Inverse) = is_monotonically_decreasing(inverse(binv))
is_monotonically_decreasing(ef::Elementwise) = is_monotonically_decreasing(ef.x)
function is_monotonically_decreasing(cf::ComposedFunction)
# Here we have a few different cases:
#
# inner \ outer | inc | dec | other
# --------------+-----+-----+------
# inc | inc | dec | NA
# dec | dec | inc | NA
# other | NA | NA | NA
# --------------+-----+-----+------
#
# where `inc` means monotonically increasing, `dec` means monotonically decreasing,
# and `NA` means not applicable, i.e. we should return `false`.
return if is_monotonically_increasing(cf.inner)
is_monotonically_decreasing(cf.outer)
elseif is_monotonically_decreasing(cf.inner)
is_monotonically_increasing(cf.outer)
else
false
end
end

torfjelde marked this conversation as resolved.
Show resolved Hide resolved
######################
# Bijectors includes #
######################
Expand Down
8 changes: 8 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
Expand All @@ -11,14 +13,18 @@ Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
AbstractMCMC = "5"
AdvancedHMC = "0.6"
ChainRulesTestUtils = "0.7, 1"
ChangesOfVariables = "0.1"
Combinatorics = "1.0.2"
Expand All @@ -30,7 +36,9 @@ ForwardDiff = "0.10.12"
Functors = "0.1, 0.2, 0.3, 0.4"
InverseFunctions = "0.1"
LazyArrays = "1, 2"
LogDensityProblems = "2"
LogExpFunctions = "0.3.1"
MCMCDiagnosticTools = "0.3"
ReverseDiff = "1.4.2"
Tracker = "0.2.11"
Zygote = "0.6.63"
Expand Down
52 changes: 46 additions & 6 deletions test/ad/chainrules.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,25 @@
using ChainRulesTestUtils: ChainRulesCore

# HACK: This is a workaround to test `Bijectors._inv_link_chol_lkj` which produces an
# upper-triangular `Matrix`, leading to `test_rrule` comaring the _full_ `Matrix`,
# including the lower-triangular part which potentially contains `undef` entries.
# Here we simply wrap the rrule we want to test to also convert to PD form, thus
# avoiding any issues with the lower-triangular part.
function _inv_link_chol_lkj_wrapper(y)
W, logJ = Bijectors._inv_link_chol_lkj(y)
return Bijectors.pd_from_upper(W), logJ
end
function ChainRulesCore.rrule(::typeof(_inv_link_chol_lkj_wrapper), y::AbstractVector)
(W, logJ), back = ChainRulesCore.rrule(Bijectors._inv_link_chol_lkj, y)
X, back_X = ChainRulesCore.rrule(Bijectors.pd_from_upper, W)
function pullback_inv_link_chol_lkj_wrapper((ΔX, ΔlogJ))
(_, ΔW) = back_X(ChainRulesCore.unthunk(ΔX))
(_, Δy) = back((ΔW, ΔlogJ))
return (ChainRulesCore.NoTangent(), Δy)
end
return (X, logJ), pullback_inv_link_chol_lkj_wrapper
end

@testset "chainrules" begin
x = randn()
y = expm1(randn())
Expand All @@ -22,11 +44,29 @@

# LKJ and LKJCholesky bijector
dist = LKJCholesky(3, 4)
x = rand(dist)
test_rrule(Bijectors._link_chol_lkj_from_upper, x.U)
test_rrule(Bijectors._link_chol_lkj_from_lower, x.L)
# Run multiple tests because we're working with `undef` entries, and so we
# want to make sure that we hit cases where the `undef` entries have different values.
# It's also just useful to test numerical stability for different realizations of `dist`.
for i in 1:30
x = rand(dist)
test_rrule(
Bijectors._link_chol_lkj_from_upper,
x.U;
testset_name="_link_chol_lkj_from_upper on $(typeof(x)) [$i]",
)
test_rrule(
Bijectors._link_chol_lkj_from_lower,
x.L;
testset_name="_link_chol_lkj_from_lower on $(typeof(x)) [$i]",
)

b = bijector(dist)
y = b(x)

b = bijector(dist)
y = b(x)
test_rrule(Bijectors._inv_link_chol_lkj, y)
test_rrule(
_inv_link_chol_lkj_wrapper,
y;
testset_name="_inv_link_chol_lkj on $(typeof(x)) [$i]",
)
end
end
Loading
Loading