From d8ebcd4a6e8e638716bb830ad28ac870e58d1ebc Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 2 Feb 2023 19:19:06 +0100 Subject: [PATCH] Throw error if `ordered` unsafe to use (#241) * Throw error if transform not supported * Document requirements of ordered * Add ordered test * Update Project.toml --------- Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> --- Project.toml | 3 ++- src/bijectors/ordered.jl | 9 ++++++++- test/bijectors/ordered.jl | 22 +++++++++++++++++++++- 3 files changed, 31 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index ec3b8d34..03ef97e8 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,7 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.11.0" +version = "0.11.1" + [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" diff --git a/src/bijectors/ordered.jl b/src/bijectors/ordered.jl index 2a43a661..5d55fd82 100644 --- a/src/bijectors/ordered.jl +++ b/src/bijectors/ordered.jl @@ -13,8 +13,15 @@ struct OrderedBijector <: Bijector end ordered(d::Distribution) Return a `Distribution` whose support are ordered vectors, i.e., vectors with increasingly ordered elements. + +This transformation is currently only supported for otherwise unconstrained distributions. """ -ordered(d::ContinuousMultivariateDistribution) = Bijectors.transformed(d, OrderedBijector()) +function ordered(d::ContinuousMultivariateDistribution) + if !isa(bijector(d), Identity) + throw(ArgumentError("ordered transform is currently only supported for unconstrained distributions.")) + end + return Bijectors.transformed(d, OrderedBijector()) +end with_logabsdet_jacobian(b::OrderedBijector, x) = transform(b, x), logabsdetjac(b, x) diff --git a/test/bijectors/ordered.jl b/test/bijectors/ordered.jl index 058ee77d..35826b33 100644 --- a/test/bijectors/ordered.jl +++ b/test/bijectors/ordered.jl @@ -1,4 +1,5 @@ -import Bijectors: OrderedBijector +import Bijectors: OrderedBijector, ordered +using LinearAlgebra @testset "OrderedBijector" begin b = OrderedBijector() @@ -14,3 +15,22 @@ import Bijectors: OrderedBijector y = b(x) @test sort(y) == y end + +@testset "ordered" begin + d = MvNormal(1:5, Diagonal(6:10)) + d_ordered = ordered(d) + @test d_ordered isa Bijectors.TransformedDistribution + @test d_ordered.dist === d + @test d_ordered.transform isa OrderedBijector + y = randn(5) + x = inv(bijector(d_ordered))(y) + @test issorted(x) + + d = Product(fill(Normal(), 5)) + # currently errors because `bijector(Product(fill(Normal(), 5)))` is not an `Identity` + @test_broken ordered(d) isa Bijectors.TransformedDistribution + + # non-Identity bijector is not supported + d = Dirichlet(ones(5)) + @test_throws ArgumentError ordered(d) +end