diff --git a/Project.toml b/Project.toml index a1af283e..93db8e41 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.11.1" - +version = "0.12.0" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" diff --git a/README.md b/README.md index 34abd42b..036efc8f 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ The following table lists mathematical operations for a bijector and the corresp In this table, `b` denotes a `Bijector`, `J(b, x)` denotes the Jacobian of `b` evaluated at `x`, `b_*` denotes the [push-forward](https://www.wikiwand.com/en/Pushforward_measure) of `p` by `b`, and `x ∼ p` denotes `x` sampled from the distribution with density `p`. -The "Automatic" column in the table refers to whether or not you are required to implement the feature for a custom `Bijector`. "AD" refers to the fact that it can be implemented "automatically" using automatic differentiation, i.e. `ADBijector`. +The "Automatic" column in the table refers to whether or not you are required to implement the feature for a custom `Bijector`. "AD" refers to the fact that this can be implemented "automatically" using automatic differentiation, e.g. ForwardDiff.jl. ## Functions diff --git a/src/Bijectors.jl b/src/Bijectors.jl index 0bb967fb..9121ea70 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -64,11 +64,9 @@ export TransformDistribution, logabsdetjac!, logabsdetjacinv, Bijector, - ADBijector, Inverse, Stacked, stack, - Identity, bijector, transformed, UnivariateTransformed, diff --git a/src/bijectors/adbijector.jl b/src/bijectors/adbijector.jl deleted file mode 100644 index 7c62c4ac..00000000 --- a/src/bijectors/adbijector.jl +++ /dev/null @@ -1,29 +0,0 @@ -""" -Abstract type for a `Bijector` making use of auto-differentation (AD) to -implement `jacobian` and, by impliciation, `logabsdetjac`. -""" -abstract type ADBijector{AD} <: Bijector end - -struct SingularJacobianException{B<:Bijector} <: Exception - b::B -end -function Base.showerror(io::IO, e::SingularJacobianException) - print(io, "jacobian of $(e.b) is singular") -end - -# concrete implementations with optional dependencies ForwardDiff and Tracker -function jacobian end - -# TODO: allow batch-computation, especially for univariate case? -"Computes the absolute determinant of the Jacobian of the inverse-transformation." -function logabsdetjac(b::ADBijector, x::Real) - res = log(abs(jacobian(b, x))) - return isfinite(res) ? res : throw(SingularJacobianException(b)) -end - -function logabsdetjac(b::ADBijector, x::AbstractVector{<:Real}) - fact = lu(jacobian(b, x), check=false) - return issuccess(fact) ? logabsdet(fact)[1] : throw(SingularJacobianException(b)) -end - -with_logabsdet_jacobian(b::ADBijector, x) = (b(x), logabsdetjac(b, x)) diff --git a/src/bijectors/ordered.jl b/src/bijectors/ordered.jl index 5d55fd82..9bc172b2 100644 --- a/src/bijectors/ordered.jl +++ b/src/bijectors/ordered.jl @@ -17,10 +17,10 @@ Return a `Distribution` whose support are ordered vectors, i.e., vectors with in This transformation is currently only supported for otherwise unconstrained distributions. """ function ordered(d::ContinuousMultivariateDistribution) - if !isa(bijector(d), Identity) + if bijector(d) !== identity throw(ArgumentError("ordered transform is currently only supported for unconstrained distributions.")) end - return Bijectors.transformed(d, OrderedBijector()) + return transformed(d, OrderedBijector()) end with_logabsdet_jacobian(b::OrderedBijector, x) = transform(b, x), logabsdetjac(b, x) diff --git a/src/bijectors/stacked.jl b/src/bijectors/stacked.jl index eec82eab..f8fdada9 100644 --- a/src/bijectors/stacked.jl +++ b/src/bijectors/stacked.jl @@ -16,7 +16,7 @@ where `bs[i]::Bijector` is applied to `x[ranges[i]]::UnitRange{Int}`. # Examples ``` b1 = Logit(0.0, 1.0) -b2 = Identity() +b2 = identity b = stack(b1, b2) b([0.0, 1.0]) == [b1(0.0), 1.0] # => true ``` diff --git a/src/compat/distributionsad.jl b/src/compat/distributionsad.jl index 8abc5b36..080097f3 100644 --- a/src/compat/distributionsad.jl +++ b/src/compat/distributionsad.jl @@ -9,22 +9,22 @@ using Distributions: AbstractMvLogNormal bijector(::TuringDirichlet) = SimplexBijector() bijector(::TuringWishart) = PDBijector() bijector(::TuringInverseWishart) = PDBijector() -bijector(::TuringScalMvNormal) = Identity() -bijector(::TuringDiagMvNormal) = Identity() -bijector(::TuringDenseMvNormal) = Identity() +bijector(::TuringScalMvNormal) = identity +bijector(::TuringDiagMvNormal) = identity +bijector(::TuringDenseMvNormal) = identity bijector(d::FillVectorOfUnivariate{Continuous}) = bijector(d.v.value) bijector(d::FillMatrixOfUnivariate{Continuous}) = up1(bijector(d.dists.value)) -bijector(d::MatrixOfUnivariate{Discrete}) = Identity() +bijector(d::MatrixOfUnivariate{Discrete}) = identity bijector(d::MatrixOfUnivariate{Continuous}) = TruncatedBijector(_minmax(d.dists)...) -bijector(d::VectorOfMultivariate{Discrete}) = Identity() +bijector(d::VectorOfMultivariate{Discrete}) = identity for T in (:VectorOfMultivariate, :FillVectorOfMultivariate) @eval begin - bijector(d::$T{Continuous, <:MvNormal}) = Identity() - bijector(d::$T{Continuous, <:TuringScalMvNormal}) = Identity() - bijector(d::$T{Continuous, <:TuringDiagMvNormal}) = Identity() - bijector(d::$T{Continuous, <:TuringDenseMvNormal}) = Identity() - bijector(d::$T{Continuous, <:MvNormalCanon}) = Identity() + bijector(d::$T{Continuous, <:MvNormal}) = identity + bijector(d::$T{Continuous, <:TuringScalMvNormal}) = identity + bijector(d::$T{Continuous, <:TuringDiagMvNormal}) = identity + bijector(d::$T{Continuous, <:TuringDenseMvNormal}) = identity + bijector(d::$T{Continuous, <:MvNormalCanon}) = identity bijector(d::$T{Continuous, <:AbstractMvLogNormal}) = Log() bijector(d::$T{Continuous, <:SimplexDistribution}) = SimplexBijector() bijector(d::$T{Continuous, <:TuringDirichlet}) = SimplexBijector() diff --git a/src/compat/forwarddiff.jl b/src/compat/forwarddiff.jl index ddb58f34..a947cf18 100644 --- a/src/compat/forwarddiff.jl +++ b/src/compat/forwarddiff.jl @@ -3,20 +3,6 @@ import .ForwardDiff _eps(::Type{<:ForwardDiff.Dual{<:Any, Real}}) = _eps(Real) _eps(::Type{<:ForwardDiff.Dual{<:Any, <:Integer}}) = _eps(Real) -# AD implementations -function jacobian( - b::Union{<:ADBijector{<:ForwardDiffAD}, Inverse{<:ADBijector{<:ForwardDiffAD}}}, - x::Real -) - return ForwardDiff.derivative(b, x) -end -function jacobian( - b::Union{<:ADBijector{<:ForwardDiffAD}, Inverse{<:ADBijector{<:ForwardDiffAD}}}, - x::AbstractVector{<:Real} -) - return ForwardDiff.jacobian(b, x) -end - # Define forward-mode rule for ForwardDiff and don't trust support for ForwardDiff in Roots # https://github.com/JuliaMath/Roots.jl/issues/314 function find_alpha( diff --git a/src/compat/reversediff.jl b/src/compat/reversediff.jl index 86130418..c498205a 100644 --- a/src/compat/reversediff.jl +++ b/src/compat/reversediff.jl @@ -5,8 +5,7 @@ using ..ReverseDiff: ReverseDiff, @grad, value, track, TrackedReal, TrackedVecto using Requires, LinearAlgebra using ..Bijectors: Elementwise, SimplexBijector, maphcat, simplex_link_jacobian, - simplex_invlink_jacobian, simplex_logabsdetjac_gradient, ADBijector, - ReverseDiffAD, Inverse + simplex_invlink_jacobian, simplex_logabsdetjac_gradient, Inverse import ..Bijectors: _eps, logabsdetjac, _logabsdetjac_scale, _simplex_bijector, _simplex_inv_bijector, replace_diag, jacobian, getpd, lower, _inv_link_chol_lkj, _link_chol_lkj, _transform_ordered, _transform_inverse_ordered, @@ -17,20 +16,6 @@ import ChainRulesCore using Compat: eachcol using Distributions: LocationScale -# AD implementations -function jacobian( - b::Union{<:ADBijector{<:ReverseDiffAD}, Inverse{<:ADBijector{<:ReverseDiffAD}}}, - x::Real -) - return ReverseDiff.gradient(x -> b(x[1]), [x])[1] -end -function jacobian( - b::Union{<:ADBijector{<:ReverseDiffAD}, Inverse{<:ADBijector{<:ReverseDiffAD}}}, - x::AbstractVector{<:Real} -) - return ReverseDiff.jacobian(b, x) -end - _eps(::Type{<:TrackedReal{T}}) where {T} = _eps(T) function Base.minimum(d::LocationScale{<:TrackedReal}) m = minimum(d.ρ) diff --git a/src/compat/tracker.jl b/src/compat/tracker.jl index 53727813..1166a29e 100644 --- a/src/compat/tracker.jl +++ b/src/compat/tracker.jl @@ -12,8 +12,7 @@ using ..Tracker: Tracker, param import ..Bijectors -using ..Bijectors: Elementwise, SimplexBijector, ADBijector, - TrackerAD, Inverse, Stacked +using ..Bijectors: Elementwise, SimplexBijector, Inverse, Stacked import ChainRulesCore import LogExpFunctions @@ -49,21 +48,6 @@ function Base.maximum(d::LocationScale{<:TrackedReal}) end end -# AD implementations -function Bijectors.jacobian( - b::Union{<:ADBijector{<:TrackerAD}, Inverse{<:ADBijector{<:TrackerAD}}}, - x::Real -) - return data(Tracker.gradient(b, x)[1]) -end -function Bijectors.jacobian( - b::Union{<:ADBijector{<:TrackerAD}, Inverse{<:ADBijector{<:TrackerAD}}}, - x::AbstractVector{<:Real} -) - # We extract `data` so that we don't return a `Tracked` type - return data(Tracker.jacobian(b, x)) -end - # implementations for Shift bijector function Bijectors._logabsdetjac_shift(a::TrackedReal, x::Real, ::Val{0}) return tracker_shift_logabsdetjac(a, x, Val(0)) diff --git a/src/compat/zygote.jl b/src/compat/zygote.jl index 0b19467b..eedf4b3d 100644 --- a/src/compat/zygote.jl +++ b/src/compat/zygote.jl @@ -32,18 +32,6 @@ end end # AD implementations -function jacobian( - b::Union{<:ADBijector{<:ZygoteAD}, Inverse{<:ADBijector{<:ZygoteAD}}}, - x::Real -) - return Zygote.gradient(b, x)[1] -end -function jacobian( - b::Union{<:ADBijector{<:ZygoteAD}, Inverse{<:ADBijector{<:ZygoteAD}}}, - x::AbstractVector{<:Real} -) - return Zygote.jacobian(b, x) -end @adjoint function _logabsdetjac_scale(a::Real, x::Real, ::Val{0}) return _logabsdetjac_scale(a, x, Val(0)), Δ -> (inv(a) .* Δ, nothing, nothing) end diff --git a/src/interface.jl b/src/interface.jl index 5487d01a..ea10d2f0 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -18,31 +18,6 @@ elementwise(f) = Base.Fix1(broadcast, f) # the way to go. elementwise(f::ComposedFunction) = ComposedFunction(elementwise(f.outer), elementwise(f.inner)) -####################################### -# AD stuff "extracted" from Turing.jl # -####################################### - -abstract type ADBackend end -struct ForwardDiffAD <: ADBackend end -struct ReverseDiffAD <: ADBackend end -struct TrackerAD <: ADBackend end -struct ZygoteAD <: ADBackend end - -const ADBACKEND = Ref(:forwarddiff) -setadbackend(backend_sym::Symbol) = setadbackend(Val(backend_sym)) -setadbackend(::Val{:forwarddiff}) = ADBACKEND[] = :forwarddiff -setadbackend(::Val{:reversediff}) = ADBACKEND[] = :reversediff -setadbackend(::Val{:tracker}) = ADBACKEND[] = :tracker -setadbackend(::Val{:zygote}) = ADBACKEND[] = :zygote - -ADBackend() = ADBackend(ADBACKEND[]) -ADBackend(T::Symbol) = ADBackend(Val(T)) -ADBackend(::Val{:forwarddiff}) = ForwardDiffAD -ADBackend(::Val{:reversediff}) = ReverseDiffAD -ADBackend(::Val{:tracker}) = TrackerAD -ADBackend(::Val{:zygote}) = ZygoteAD -ADBackend(::Val) = error("The requested AD backend is not available. Make sure to load all required packages.") - ###################### # Bijector interface # ###################### @@ -197,12 +172,8 @@ Just an alias for `logabsdetjac(inverse(b), y)`. logabsdetjacinv(b, y) = logabsdetjac(inverse(b), y) ############################## -# Example bijector: Identity # +# Example bijector: identity # ############################## -Identity() = identity - -# Here we don't need to separate between batched version and non-batched, and so -# we can just overload `transform`, etc. directly. transform(::typeof(identity), x) = copy(x) transform!(::typeof(identity), x, y) = copy!(y, x) @@ -213,7 +184,6 @@ logabsdetjac!(::typeof(identity), x, logjac) = logjac # Bijectors includes # ###################### # General -include("bijectors/adbijector.jl") include("bijectors/composed.jl") include("bijectors/stacked.jl") diff --git a/src/transformed_distribution.jl b/src/transformed_distribution.jl index e515d65d..3b21bf4e 100644 --- a/src/transformed_distribution.jl +++ b/src/transformed_distribution.jl @@ -35,11 +35,14 @@ transformed(d) = transformed(d, bijector(d)) Returns the constrained-to-unconstrained bijector for distribution `d`. """ -bijector(td::TransformedDistribution) = bijector(td.dist) ∘ inverse(td.transform) -bijector(d::DiscreteUnivariateDistribution) = Identity() -bijector(d::DiscreteMultivariateDistribution) = Identity() +function bijector(td::TransformedDistribution) + b = bijector(td.dist) + return b === identity ? inverse(td.transform) : b ∘ inverse(td.transform) +end +bijector(d::DiscreteUnivariateDistribution) = identity +bijector(d::DiscreteMultivariateDistribution) = identity bijector(d::ContinuousUnivariateDistribution) = TruncatedBijector(minimum(d), maximum(d)) -bijector(d::Product{Discrete}) = Identity() +bijector(d::Product{Discrete}) = identity function bijector(d::Product{Continuous}) return TruncatedBijector(_minmax(d.v)...) end @@ -52,8 +55,8 @@ end end end -bijector(d::Normal) = Identity() -bijector(d::Distributions.AbstractMvNormal) = Identity() +bijector(d::Normal) = identity +bijector(d::Distributions.AbstractMvNormal) = identity bijector(d::Distributions.AbstractMvLogNormal) = elementwise(log) bijector(d::PositiveDistribution) = elementwise(log) bijector(d::SimplexDistribution) = SimplexBijector() diff --git a/test/bijectors/ordered.jl b/test/bijectors/ordered.jl index 35826b33..0f9802e7 100644 --- a/test/bijectors/ordered.jl +++ b/test/bijectors/ordered.jl @@ -23,7 +23,7 @@ end @test d_ordered.dist === d @test d_ordered.transform isa OrderedBijector y = randn(5) - x = inv(bijector(d_ordered))(y) + x = inverse(bijector(d_ordered))(y) @test issorted(x) d = Product(fill(Normal(), 5)) diff --git a/test/interface.jl b/test/interface.jl index e975e486..ac5d3b94 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -10,34 +10,14 @@ using Tracker using DistributionsAD using Bijectors -using Bijectors: Shift, Scale, Logit, SimplexBijector, PDBijector, Permute, PlanarLayer, RadialLayer, Stacked, TruncatedBijector, ADBijector, RationalQuadraticSpline, LeakyReLU +using Bijectors: Shift, Scale, Logit, SimplexBijector, PDBijector, Permute, PlanarLayer, RadialLayer, Stacked, TruncatedBijector, RationalQuadraticSpline, LeakyReLU Random.seed!(123) -struct MyADBijector{AD,B} <: ADBijector{AD} - b::B -end -MyADBijector(d::Distribution) = MyADBijector{Bijectors.ADBackend()}(d) -MyADBijector{AD}(d::Distribution) where {AD} = MyADBijector{AD}(bijector(d)) -MyADBijector{AD}(b) where {AD} = MyADBijector{AD, typeof(b)}(b) -(b::MyADBijector)(x) = b.b(x) -Bijectors.transform(b::MyADBijector, x) = b.b(x) -Bijectors.transform(b::Inverse{<:MyADBijector}, x) = inverse(b.orig.b)(x) - -struct NonInvertibleBijector{AD} <: ADBijector{AD} end - contains(predicate::Function, b::Bijector) = predicate(b) contains(predicate::Function, b::ComposedFunction) = any(contains.(predicate, b.ts)) contains(predicate::Function, b::Stacked) = any(contains.(predicate, b.bs)) -# Scalar tests -@testset "<: ADBijector{AD}" begin - (b::NonInvertibleBijector)(x) = clamp.(x, 0, 1) - - b = NonInvertibleBijector{Bijectors.ADBackend()}() - @test_throws Bijectors.SingularJacobianException logabsdetjac(b, [1.0, 10.0]) -end - @testset "Univariate" begin # Tests with scalar-valued distributions. uni_dists = [ @@ -104,32 +84,6 @@ end @test log(abs(ForwardDiff.derivative(b, x))) ≈ logabsdetjac(b, x) atol=1e-6 @test log(abs(ForwardDiff.derivative(inverse(b), y))) ≈ logabsdetjac(inverse(b), y) atol=1e-6 end - - @testset "$dist: ForwardDiff AD" begin - x = rand(dist) - b = MyADBijector{Bijectors.ADBackend(:forwarddiff)}(dist) - - @test abs(det(Bijectors.jacobian(b, x))) > 0 - @test logabsdetjac(b, x) ≠ Inf - - y = b(x) - b⁻¹ = inverse(b) - @test abs(det(Bijectors.jacobian(b⁻¹, y))) > 0 - @test logabsdetjac(b⁻¹, y) ≠ Inf - end - - @testset "$dist: Tracker AD" begin - x = rand(dist) - b = MyADBijector{Bijectors.ADBackend(:reversediff)}(dist) - - @test abs(det(Bijectors.jacobian(b, x))) > 0 - @test logabsdetjac(b, x) ≠ Inf - - y = b(x) - b⁻¹ = inverse(b) - @test abs(det(Bijectors.jacobian(b⁻¹, y))) > 0 - @test logabsdetjac(b⁻¹, y) ≠ Inf - end end end @@ -266,26 +220,6 @@ end @test logabsdetjac(sb2, [x, x, y, y]) ≈ 0.0 atol=1e-12 @test res2[2] ≈ 0.0 atol=1e-12 - # `logabsdetjac` with AD - b = MyADBijector(d) - y = b(x) - - sb1 = stack(b, b, inverse(b), inverse(b)) # <= Tuple - res1 = with_logabsdet_jacobian(sb1, [x, x, y, y]) - @test sb1(param([x, x, y, y])) isa TrackedArray - - @test sb1([x, x, y, y]) == res1[1] - @test logabsdetjac(sb1, [x, x, y, y]) ≈ 0 atol=1e-12 - @test res1[2] ≈ 0.0 atol=1e-12 - - sb2 = Stacked([b, b, inverse(b), inverse(b)]) # <= Array - res2 = with_logabsdet_jacobian(sb2, [x, x, y, y]) - @test sb2(param([x, x, y, y])) isa TrackedArray - - @test sb2([x, x, y, y]) == res2[1] - @test logabsdetjac(sb2, [x, x, y, y]) ≈ 0.0 atol=1e-12 - @test res2[2] ≈ 0.0 atol=1e-12 - # value-test x = ones(3) sb = @inferred stack(elementwise(exp), elementwise(log), Shift(5.0)) @@ -479,7 +413,7 @@ end @testset "Equality" begin bs = [ - Identity(), + identity, elementwise(exp), elementwise(log), Scale(2.0), diff --git a/test/runtests.jl b/test/runtests.jl index 7fcd6dfc..3185bf4c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -14,7 +14,7 @@ using Zygote using Random, LinearAlgebra, Test using Bijectors: Shift, Scale, Logit, SimplexBijector, PDBijector, Permute, - PlanarLayer, RadialLayer, Stacked, TruncatedBijector, ADBijector + PlanarLayer, RadialLayer, Stacked, TruncatedBijector using ChangesOfVariables: ChangesOfVariables using InverseFunctions: InverseFunctions