From 8b924d0f091377bc190ed50359213043b14b4d37 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 1 Feb 2023 23:47:50 +0100 Subject: [PATCH] Widening the scope of the package and dropping support for batching (#214) * renamed rv to result in forward * added abstrac type Transform and removed dimensionality from Bijector * updated Composed to new interface * updated Exp and Log to new interface * updated Logit to new interface * removed something that shouldnt be there * removed false statement in docstring of Transform * fixed a typo in implementation of logabsdetjac_batch * added types for representing batches * make it possible to use broadcasting for working with batches * updated SimplexBijector to new interface, I think * updated PDBijector to new interface * use transform_batch rather than broadcasting * added default implementations for batches * updated ADBijector to new interface * updated CorrBijector to new interface * updated Coupling to new interface * updated LeakyReLU to new interface * updated NamedBijector to new interface * updated BatchNormalisation to new interface * updated Permute to new interface * updated PlanarLayer to new interface * updated RadialLayer to new interface * updated RationalQuadraticSpline to new interface * updated Scale to new interface * updated Shift to new interface * updated Stacked to new interface * updated TruncatedBijector to new interface * added ConstructionBase as dependency * fixed a bunch of small typos and errors from previous commits * forgot to wrap some in Batch * allow inverses of non-bijectors * relax definition of VectorBatch so Vector{<:Real} is covered * just perform invertibility check in Inverse rather than inv * moved some code arround * added docstrings and default impls for mutating batched methods * add elementype to VectorBatch * simplify Shift bijector * added rrules for logabsdetjac_shift * use type-stable implementation of eachslice * initial work on adding proper testing * make Batch compatible with Zygote * updated OrderedBijector * temporary stuff * added docs * removed all batch related functionality * move bijectors over to with_logabsdet_jacobian and drop official batch support * updated compat * updated tests * updated docs * removed reundndat dep * remove batch * remove redundant defs of transform * removed unnecessary impls of with_logabsdet_jacobian * remove usage of Exp and Log in tests * fixed docs * added bijectors with docs to docs * small change to docs * fixed bug in computation of logabsdetjac of truncated * bump minor version * run GH actions on Julia 1.6, which is the new LTS, instead of 1.3 * added Github actions for making docs, etc. * removed left-overs from batch impls * removed redundant comment * dont return NamedTuple from with_logabsdet_jacobian * remove unnused methods * remove old deprecation warnings * fix exports * updated tests for deprecations * completed some random TODOs * fix SimplexBijector tests * removed whitespace * made some docstrings into doctests * removed unnused method * improved show for scale and shift * converted example for Coupling into doctest * added reference to Coupling bijector for NamedCoupling * fixed docstring * fixed documentation setup * nvm, now I fixed documentation setup * removed references to dimensionality in code * fixed typo * add impl of invertible for Elementwise * added transforms and distributions as separate pages in docs * removed all the unnecessary stuff in README * added examples to docs * added some show methods for certain bijectors * added compat entries to docs * updated docstring for RationalQuadraticSpline * removed commented code * remove reference to logpdf_forward * remove enforcement of type of input and output being the same in tests * make logpdf_with_trans compatible with logpdf when it comes to handling batches * Apply suggestions from code review Co-authored-by: David Widmann * remove usage of invertible, etc. and use InverseFunctions.NoInverse instead * specialze transform on Function * removed unnecessary show and deprecation warnings * remove references to Log and Exp --------- Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> Co-authored-by: David Widmann --- .github/workflows/DocsPreviewCleanup.yml | 26 ++ Project.toml | 2 +- README.md | 252 +----------- docs/make.jl | 2 +- docs/src/transforms.md | 103 +++++ src/Bijectors.jl | 37 +- src/bijectors/adbijector.jl | 8 +- src/bijectors/composed.jl | 242 +----------- src/bijectors/corr.jl | 20 +- src/bijectors/coupling.jl | 58 ++- src/bijectors/exp_log.jl | 52 +-- src/bijectors/leaky_relu.jl | 96 +---- src/bijectors/logit.jl | 39 +- src/bijectors/named_bijector.jl | 199 +++------- src/bijectors/normalise.jl | 21 +- src/bijectors/ordered.jl | 9 +- src/bijectors/pd.jl | 12 +- src/bijectors/permute.jl | 11 +- src/bijectors/planar_layer.jl | 10 +- src/bijectors/radial_layer.jl | 14 +- src/bijectors/rational_quadratic_spline.jl | 77 ++-- src/bijectors/scale.jl | 63 +-- src/bijectors/shift.jl | 37 +- src/bijectors/simplex.jl | 115 ++---- src/bijectors/stacked.jl | 41 +- src/bijectors/truncated.jl | 108 +----- src/compat/distributionsad.jl | 34 +- src/compat/reversediff.jl | 39 +- src/compat/tracker.jl | 75 +--- src/compat/zygote.jl | 23 +- src/interface.jl | 204 ++++++---- src/transformed_distribution.jl | 171 +-------- test/ad/flows.jl | 12 +- test/ad/utils.jl | 7 +- test/bijectors/coupling.jl | 10 +- test/bijectors/leaky_relu.jl | 96 ++--- test/bijectors/named_bijector.jl | 22 +- test/bijectors/ordered.jl | 10 +- test/bijectors/rational_quadratic_spline.jl | 12 +- test/bijectors/utils.jl | 242 ++++-------- test/interface.jl | 405 ++------------------ test/norm_flows.jl | 2 +- test/runtests.jl | 6 +- test/transform.jl | 62 +-- 44 files changed, 851 insertions(+), 2235 deletions(-) create mode 100644 .github/workflows/DocsPreviewCleanup.yml create mode 100644 docs/src/transforms.md diff --git a/.github/workflows/DocsPreviewCleanup.yml b/.github/workflows/DocsPreviewCleanup.yml new file mode 100644 index 00000000..4f57bc46 --- /dev/null +++ b/.github/workflows/DocsPreviewCleanup.yml @@ -0,0 +1,26 @@ +name: DocsPreviewCleanup + +on: + pull_request: + types: [closed] + +jobs: + cleanup: + runs-on: ubuntu-latest + steps: + - name: Checkout gh-pages branch + uses: actions/checkout@v2 + with: + ref: gh-pages + - name: Delete preview and history + push changes + run: | + if [ -d "previews/PR$PRNUM" ]; then + git config user.name "Documenter.jl" + git config user.email "documenter@juliadocs.github.io" + git rm -rf "previews/PR$PRNUM" + git commit -m "delete preview" + git branch gh-pages-new $(echo "delete history" | git commit-tree HEAD^{tree}) + git push --force origin gh-pages-new:gh-pages + fi + env: + PRNUM: ${{ github.event.number }} diff --git a/Project.toml b/Project.toml index 938d4735..ec3b8d34 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.10.6" +version = "0.11.0" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" diff --git a/README.md b/README.md index 14fba13a..34abd42b 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,6 @@ # Bijectors.jl +[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://turinglang.github.io/Bijectors.jl/stable) [![Interface tests](https://github.com/TuringLang/Bijectors.jl/workflows/Interface%20tests/badge.svg?branch=master)](https://github.com/TuringLang/Bijectors.jl/actions?query=workflow%3A%22Interface+tests%22+branch%3Amaster) [![AD tests](https://github.com/TuringLang/Bijectors.jl/workflows/AD%20tests/badge.svg?branch=master)](https://github.com/TuringLang/Bijectors.jl/actions?query=workflow%3A%22AD+tests%22+branch%3Amaster) @@ -135,19 +136,6 @@ true Pretty neat, huh? `Inverse{Logit}` is also a `Bijector` where we've defined `(ib::Inverse{<:Logit})(y)` as the inverse transformation of `(b::Logit)(x)`. Note that it's not always the case that `inverse(b) isa Inverse`, e.g. the inverse of `Exp` is simply `Log` so `inverse(Exp()) isa Log` is true. -#### Dimensionality -One more thing. See the `0` in `Inverse{Logit{Float64}, 0}`? It represents the *dimensionality* of the bijector, in the same sense as for an `AbstractArray` with the exception of `0` which means it expects 0-dim input and output, i.e. `<:Real`. This can also be accessed through `dimension(b)`: - -```julia -julia> Bijectors.dimension(b) -0 - -julia> Bijectors.dimension(Exp{1}()) -1 -``` - -In most cases specification of the dimensionality is unnecessary as a `Bijector{N}` is usually only defined for a particular value of `N`, e.g. `Logit isa Bijector{0}` since it only makes sense to apply `Logit` to a real number (or a vector of reals if you're doing batch-computation). As a user, you'll rarely have to deal with this dimensionality specification. Unfortunately there are exceptions, e.g. `Exp` which can be applied to both real numbers and a vector of real numbers, in both cases treating it as a single input. This means that when `Exp` receives a vector input `x` as input, it's ambiguous whether or not to treat `x` as a *batch* of 0-dim inputs or as a single 1-dim input. As a result, to support batch-computation it is necessary to know the expected dimensionality of the input and output. Notice that we assume the dimensionality of the input and output to be the *same*. This is a reasonable assumption considering we're working with *bijections*. - #### Composition Also, we can _compose_ bijectors: @@ -491,244 +479,6 @@ julia> x, y, logjac, logpdf_y = forward(flow) # sample + transform and returns a This method is for example useful when computing quantities such as the _expected lower bound (ELBO)_ between this transformed distribution and some other joint density. If no analytical expression is available, we have to approximate the ELBO by a Monte Carlo estimate. But one term in the ELBO is the entropy of the base density, which we _do_ know analytically in this case. Using the analytical expression for the entropy and then using a monte carlo estimate for the rest of the terms in the ELBO gives an estimate with lower variance than if we used the monte carlo estimate for the entire expectation. -### Normalizing flows with bounded support - - -## Implementing your own `Bijector` -There's mainly two ways you can implement your own `Bijector`, and which way you choose mainly depends on the following question: are you bothered enough to manually implement `logabsdetjac`? If the answer is "Yup!", then you subtype from `Bijector`, if "Naaaah" then you subtype `ADBijector`. - -### `<:Bijector` -Here's a simple example taken from the source code, the `Identity`: - -```julia -import Bijectors: logabsdetjac - -struct Identity{N} <: Bijector{N} end -(::Identity)(x) = x # transform itself, "forward" -(::Inverse{<: Identity})(y) = y # inverse tramsform, "backward" - -# see the proper implementation for `logabsdetjac` in general -logabsdetjac(::Identity{0}, y::Real) = zero(eltype(y)) # ∂ₓid(x) = ∂ₓ x = 1 → log(abs(1)) = log(1) = 0 -``` - -A slightly more complex example is `Logit`: - -```julia -using LogExpFunctions: logit, logistic - -struct Logit{T<:Real} <: Bijector{0} - a::T - b::T -end - -(b::Logit)(x::Real) = logit((x - b.a) / (b.b - b.a)) -(b::Logit)(x) = map(b, x) -# `orig` contains the `Bijector` which was inverted -(ib::Inverse{<:Logit})(y::Real) = (ib.orig.b - ib.orig.a) * logistic(y) + ib.orig.a -(ib::Inverse{<:Logit})(y) = map(ib, y) - -logabsdetjac(b::Logit, x::Real) = - log((x - b.a) * (b.b - x) / (b.b - b.a)) -logabsdetjac(b::Logit, x) = map(logabsdetjac, x) -``` - -(Batch computation is not fully supported by all bijectors yet (see issue #35), but is actively worked on. In the particular case of `Logit` there's only one thing that makes sense, which is elementwise application. Therefore we've added `@.` to the implementation above, thus this works for any `AbstractArray{<:Real}`.) - -Then - -```julia -julia> b = Logit(0.0, 1.0) -Logit{Float64}(0.0, 1.0) - -julia> b(0.6) -0.4054651081081642 - -julia> inverse(b)(y) -Tracked 2-element Array{Float64,1}: - 0.3078149833748082 - 0.72380041667891 - -julia> logabsdetjac(b, 0.6) -1.4271163556401458 - -julia> logabsdetjac(inverse(b), y) # defaults to `- logabsdetjac(b, inverse(b)(x))` -Tracked 2-element Array{Float64,1}: - -1.546158373866469 - -1.6098711387913573 - -julia> with_logabsdet_jacobian(b, 0.6) # defaults to `(b(x), logabsdetjac(b, x))` -(0.4054651081081642, 1.4271163556401458) -``` - -For further efficiency, one could manually implement `with_logabsdet_jacobian(b::Logit, x)`: - -```julia -julia> using Bijectors: Logit - -julia> import Bijectors: with_logabsdet_jacobian - -julia> function with_logabsdet_jacobian(b::Logit{<:Real}, x) - totally_worth_saving = @. (x - b.a) / (b.b - b.a) # spoiler: it's probably not - y = logit.(totally_worth_saving) - logjac = @. - log((b.b - x) * totally_worth_saving) - return (y, logjac) - end -forward (generic function with 16 methods) - -julia> with_logabsdet_jacobian(b, 0.6) -(0.4054651081081642, 1.4271163556401458) - -julia> @which with_logabsdet_jacobian(b, 0.6) -with_logabsdet_jacobian(b::Logit{#s4} where #s4<:Real, x) in Main at REPL[43]:2 -``` - -As you can see it's a very contrived example, but you get the idea. - -### `<:ADBijector` - -We could also have implemented `Logit` as an `ADBijector`: - -```julia -using LogExpFunctions: logit, logistic -using Bijectors: ADBackend - -struct ADLogit{T, AD} <: ADBijector{AD, 0} - a::T - b::T -end - -# ADBackend() returns ForwardDiffAD, which means we use ForwardDiff.jl for AD -ADLogit(a::T, b::T) where {T<:Real} = ADLogit{T, ADBackend()}(a, b) - -(b::ADLogit)(x) = @. logit((x - b.a) / (b.b - b.a)) -(ib::Inverse{<:ADLogit{<:Real}})(y) = @. (ib.orig.b - ib.orig.a) * logistic(y) + ib.orig.a -``` - -No implementation of `logabsdetjac`, but: - -```julia -julia> b_ad = ADLogit(0.0, 1.0) -ADLogit{Float64,Bijectors.ForwardDiffAD}(0.0, 1.0) - -julia> logabsdetjac(b_ad, 0.6) -1.4271163556401458 - -julia> y = b_ad(0.6) -0.4054651081081642 - -julia> inverse(b_ad)(y) -0.6 - -julia> logabsdetjac(inverse(b_ad), y) --1.4271163556401458 -``` - -Neat! And just to verify that everything works: - -```julia -julia> b = Logit(0.0, 1.0) -Logit{Float64}(0.0, 1.0) - -julia> logabsdetjac(b, 0.6) -1.4271163556401458 - -julia> logabsdetjac(b_ad, 0.6) ≈ logabsdetjac(b, 0.6) -true -``` - -We can also use Tracker.jl for the AD, rather than ForwardDiff.jl: - -```julia -julia> Bijectors.setadbackend(:reversediff) -:reversediff - -julia> b_ad = ADLogit(0.0, 1.0) -ADLogit{Float64,Bijectors.TrackerAD}(0.0, 1.0) - -julia> logabsdetjac(b_ad, 0.6) -1.4271163556401458 -``` - - -### Reference -Most of the methods and types mention below will have docstrings with more elaborate explanation and examples, e.g. -```julia -help?> Bijectors.Composed - Composed(ts::A) - - ∘(b1::Bijector{N}, b2::Bijector{N})::Composed{<:Tuple} - composel(ts::Bijector{N}...)::Composed{<:Tuple} - composer(ts::Bijector{N}...)::Composed{<:Tuple} - - where A refers to either - - • Tuple{Vararg{<:Bijector{N}}}: a tuple of bijectors of dimensionality N - - • AbstractArray{<:Bijector{N}}: an array of bijectors of dimensionality N - - A Bijector representing composition of bijectors. composel and composer results in a Composed for which application occurs from left-to-right and right-to-left, respectively. - - Note that all the alternative ways of constructing a Composed returns a Tuple of bijectors. This ensures type-stability of implementations of all relating methods, e.g. inverse. - - If you want to use an Array as the container instead you can do - - Composed([b1, b2, ...]) - - In general this is not advised since you lose type-stability, but there might be cases where this is desired, e.g. if you have a insanely large number of bijectors to compose. - - Examples - ≡≡≡≡≡≡≡≡≡≡ - - It's important to note that ∘ does what is expected mathematically, which means that the bijectors are applied to the input right-to-left, e.g. first applying b2 and then b1: - - (b1 ∘ b2)(x) == b1(b2(x)) # => true - - But in the Composed struct itself, we store the bijectors left-to-right, so that - - cb1 = b1 ∘ b2 # => Composed.ts == (b2, b1) - cb2 = composel(b2, b1) # => Composed.ts == (b2, b1) - cb1(x) == cb2(x) == b1(b2(x)) # => true -``` -If anything is lacking or not clear in docstrings, feel free to open an issue or PR. - -#### Types -The following are the bijectors available: -- Abstract: - - `Bijector`: super-type of all bijectors. - - `ADBijector{AD} <: Bijector`: subtypes of this only require the user to implement `(b::UserBijector)(x)` and `(ib::Inverse{<:UserBijector})(y)`. Automatic differentation will be used to compute the `jacobian(b, x)` and thus `logabsdetjac(b, x). -- Concrete: - - `Composed`: represents a composition of bijectors. - - `Stacked`: stacks univariate and multivariate bijectors - - `Identity`: does what it says, i.e. nothing. - - `Logit` - - `Exp` - - `Log` - - `Scale`: scaling by scalar value, though at the moment only well-defined `logabsdetjac` for univariate. - - `Shift`: shifts by a scalar value. - - `Permute`: permutes the input array using matrix multiplication - - `SimplexBijector`: mostly used as the constrained-to-unconstrained bijector for `SimplexDistribution`, e.g. `Dirichlet`. - - `PlanarLayer`: §4.1 Eq. (10) in [1] - - `RadialLayer`: §4.1 Eq. (14) in [1] - -The distribution interface consists of: -- `TransformedDistribution <: Distribution`: implements the `Distribution` interface from Distributions.jl. This means `rand` and `logpdf` are provided at the moment. - -#### Methods -The following methods are implemented by all subtypes of `Bijector`, this also includes bijectors such as `Composed`. -- `(b::Bijector)(x)`: implements the transform of the `Bijector` -- `inverse(b::Bijector)`: returns the inverse of `b`, i.e. `ib::Bijector` s.t. `(ib ∘ b)(x) ≈ x`. In most cases this is `Inverse{<:Bijector}`. -- `logabsdetjac(b::Bijector, x)`: computes log(abs(det(jacobian(b, x)))). -- `with_logabsdet_jacobian(b::Bijector, x)`: returns the tuple `(b(x), logabsdetjac(b, x))` in the most efficient manner. -- `∘`, `composel`, `composer`: convenient and type-safe constructors for `Composed`. `composel(bs...)` composes s.t. the resulting composition is evaluated left-to-right, while `composer(bs...)` is evaluated right-to-left. `∘` is right-to-left, as excepted from standard mathematical notation. -- `jacobian(b::Bijector, x)` [OPTIONAL]: returns the Jacobian of the transformation. In some cases the analytical Jacobian has been implemented for efficiency. -- `dimension(b::Bijector)`: returns the dimensionality of `b`. -- `isclosedform(b::Bijector)`: returns `true` or `false` depending on whether or not `b(x)` has a closed-form implementation. - -For `TransformedDistribution`, together with default implementations for `Distribution`, we have the following methods: -- `bijector(d::Distribution)`: returns the default constrained-to-unconstrained bijector for `d` -- `transformed(d::Distribution)`, `transformed(d::Distribution, b::Bijector)`: constructs a `TransformedDistribution` from `d` and `b`. -- `logpdf_forward(d::Distribution, x)`, `logpdf_forward(d::Distribution, x, logjac)`: computes the `logpdf(td, td.transform(x))` using the forward pass, which is potentially faster depending on the transform at hand. -- `forward(d::Distribution)`: returns `(x = rand(dist), y = b(x), logabsdetjac = logabsdetjac(b, x), logpdf = logpdf_forward(td, x))` where `b = td.transform`. This combines sampling from base distribution and transforming into one function. The intention is that this entire process should be performed in the most efficient manner, e.g. the `logabsdetjac(b, x)` call might instead be implemented as `- logabsdetjac(inverse(b), b(x))` depending on which is most efficient. - # Bibliography 1. Rezende, D. J., & Mohamed, S. (2015). Variational Inference With Normalizing Flows. [arXiv:1505.05770](https://arxiv.org/abs/1505.05770v6). 2. Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2016). Automatic Differentiation Variational Inference. [arXiv:1603.00788](https://arxiv.org/abs/1603.00788v1). diff --git a/docs/make.jl b/docs/make.jl index 22908cf1..e1138577 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -8,7 +8,7 @@ makedocs( sitename = "Bijectors", format = Documenter.HTML(), modules = [Bijectors], - pages = ["Home" => "index.md", "Distributions.jl integration" => "distributions.md", "Examples" => "examples.md"], + pages = ["Home" => "index.md", "Transforms" => "transforms.md", "Distributions.jl integration" => "distributions.md", "Examples" => "examples.md"], strict=false, checkdocs=:exports, ) diff --git a/docs/src/transforms.md b/docs/src/transforms.md new file mode 100644 index 00000000..cf9223aa --- /dev/null +++ b/docs/src/transforms.md @@ -0,0 +1,103 @@ +## Usage + +A very simple example of a "bijector"/diffeomorphism, i.e. a differentiable transformation with a differentiable inverse, is the `exp` function: +- The inverse of `exp` is `log`. +- The derivative of `exp` at an input `x` is simply `exp(x)`, hence `logabsdetjac` is simply `x`. + +```@repl usage +using Bijectors +transform(exp, 1.0) +logabsdetjac(exp, 1.0) +with_logabsdet_jacobian(exp, 1.0) +``` + +Some transformations are well-defined for different types of inputs, e.g. `exp` can also act elementwise on an `N`-dimensional `Array{<:Real,N}`. +To specify that a transformation should act elementwise, we use the [`elementwise`](@ref) method: + +```@repl usage +x = ones(2, 2) +transform(elementwise(exp), x) +logabsdetjac(elementwise(exp), x) +with_logabsdet_jacobian(elementwise(exp), x) +``` + +These methods also work nicely for compositions of transformations: + +```@repl usage +transform(elementwise(log ∘ exp), x) +``` + +Unlike `exp`, some transformations have parameters affecting the resulting transformation they represent, e.g. `Logit` has two parameters `a` and `b` representing the lower- and upper-bound, respectively, of its domain: + +```@repl usage +using Bijectors: Logit + +f = Logit(0.0, 1.0) +f(rand()) # takes us from `(0, 1)` to `(-∞, ∞)` +``` + +## User-facing methods + +Without mutation: + +```@docs +transform +logabsdetjac +``` + +```julia +with_logabsdet_jacobian +``` + +With mutation: + +```@docs +transform! +logabsdetjac! +with_logabsdet_jacobian! +``` + +## Implementing a transformation + +Any callable can be made into a bijector by providing an implementation of `ChangeOfVariables.with_logabsdet_jacobian(b, x)`. + +You can also optionally implement [`transform`](@ref) and [`logabsdetjac`](@ref) to avoid redundant computations. This is usually only worth it if you expect `transform` or `logabsdetjac` to be used heavily without the other. + +Similarly with the mutable versions [`with_logabsdet_jacobian!`](@ref), [`transform!`](@ref), and [`logabsdetjac!`](@ref). + +## Working with Distributions.jl + +```@docs +Bijectors.bijector +Bijectors.transformed(d::Distribution, b::Bijector) +``` + +## Utilities + +```@docs +Bijectors.elementwise +Bijectors.isinvertible +Bijectors.isclosedform(t::Bijectors.Transform) +``` + +## API + +```@docs +Bijectors.Transform +Bijectors.Bijector +Bijectors.Inverse +``` + +## Bijectors + +```@docs +Bijectors.CorrBijector +Bijectors.LeakyReLU +Bijectors.Stacked +Bijectors.RationalQuadraticSpline +Bijectors.Coupling +Bijectors.OrderedBijector +Bijectors.NamedTransform +Bijectors.NamedCoupling +``` + diff --git a/src/Bijectors.jl b/src/Bijectors.jl index 193327fa..0bb967fb 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -35,6 +35,8 @@ using MappedArrays using Base.Iterators: drop using LinearAlgebra: AbstractTriangular +using InverseFunctions: InverseFunctions + import ChangesOfVariables: with_logabsdet_jacobian import InverseFunctions: inverse @@ -54,16 +56,16 @@ export TransformDistribution, logpdf_with_trans, isclosedform, transform, + transform!, with_logabsdet_jacobian, + with_logabsdet_jacobian!, inverse, - forward, logabsdetjac, + logabsdetjac!, logabsdetjacinv, Bijector, ADBijector, Inverse, - Composed, - compose, Stacked, stack, Identity, @@ -71,12 +73,11 @@ export TransformDistribution, transformed, UnivariateTransformed, MultivariateTransformed, - logpdf_with_jac, - logpdf_forward, PlanarLayer, RadialLayer, - CouplingLayer, - InvertibleBatchNorm + Coupling, + InvertibleBatchNorm, + elementwise if VERSION < v"1.1" using Compat: eachcol @@ -127,6 +128,19 @@ end link(d::Distribution, x) = bijector(d)(x) invlink(d::Distribution, y) = inverse(bijector(d))(y) + +# To still allow `logpdf_with_trans` to work with "batches" in a similar way +# as `logpdf` can. +_logabsdetjac_dist(d::UnivariateDistribution, x::Real) = logabsdetjac(bijector(d), x) +_logabsdetjac_dist(d::UnivariateDistribution, x::AbstractArray) = logabsdetjac.((bijector(d),), x) + +_logabsdetjac_dist(d::MultivariateDistribution, x::AbstractVector) = logabsdetjac(bijector(d), x) +_logabsdetjac_dist(d::MultivariateDistribution, x::AbstractMatrix) = logabsdetjac.((bijector(d),), eachcol(x)) + +_logabsdetjac_dist(d::MatrixDistribution, x::AbstractMatrix) = logabsdetjac(bijector(d), x) +_logabsdetjac_dist(d::MatrixDistribution, x::AbstractVector{<:AbstractMatrix}) = logabsdetjac.((bijector(d),), x) + + function logpdf_with_trans(d::Distribution, x, transform::Bool) if ispd(d) return pd_logpdf_with_trans(d, x, transform) @@ -136,7 +150,7 @@ function logpdf_with_trans(d::Distribution, x, transform::Bool) l = logpdf(d, x) end if transform - return l - logabsdetjac(bijector(d), x) + return l - _logabsdetjac_dist(d, x) else return l end @@ -253,13 +267,6 @@ include("utils.jl") include("interface.jl") include("chainrules.jl") -Base.@deprecate forward(b::AbstractBijector, x) NamedTuple{(:rv,:logabsdetjac)}(with_logabsdet_jacobian(b, x)) - -@noinline function Base.inv(b::AbstractBijector) - Base.depwarn("`Base.inv(b::AbstractBijector)` is deprecated, use `inverse(b)` instead.", :inv) - inverse(b) -end - # Broadcasting here breaks Tracker for some reason maporbroadcast(f, x::AbstractArray{<:Any, N}...) where {N} = map(f, x...) maporbroadcast(f, x::AbstractArray...) = f.(x...) diff --git a/src/bijectors/adbijector.jl b/src/bijectors/adbijector.jl index b7596b1d..7c62c4ac 100644 --- a/src/bijectors/adbijector.jl +++ b/src/bijectors/adbijector.jl @@ -1,8 +1,8 @@ """ -Abstract type for a `Bijector{N}` making use of auto-differentation (AD) to +Abstract type for a `Bijector` making use of auto-differentation (AD) to implement `jacobian` and, by impliciation, `logabsdetjac`. """ -abstract type ADBijector{AD, N} <: Bijector{N} end +abstract type ADBijector{AD} <: Bijector end struct SingularJacobianException{B<:Bijector} <: Exception b::B @@ -24,4 +24,6 @@ end function logabsdetjac(b::ADBijector, x::AbstractVector{<:Real}) fact = lu(jacobian(b, x), check=false) return issuccess(fact) ? logabsdet(fact)[1] : throw(SingularJacobianException(b)) -end \ No newline at end of file +end + +with_logabsdet_jacobian(b::ADBijector, x) = (b(x), logabsdetjac(b, x)) diff --git a/src/bijectors/composed.jl b/src/bijectors/composed.jl index 3221f94b..ee103a0e 100644 --- a/src/bijectors/composed.jl +++ b/src/bijectors/composed.jl @@ -1,235 +1,25 @@ -############### -# Composition # -############### +isinvertible(cb::ComposedFunction) = isinvertible(cb.inner) && isinvertible(cb.outer) +isclosedform(cb::ComposedFunction) = isclosedform(cb.inner) && isclosedform(cb.outer) -""" - Composed(ts::A) +transform(cb::ComposedFunction, x) = transform(cb.outer, transform(cb.inner, x)) - ∘(b1::Bijector{N}, b2::Bijector{N})::Composed{<:Tuple} - composel(ts::Bijector{N}...)::Composed{<:Tuple} - composer(ts::Bijector{N}...)::Composed{<:Tuple} - -where `A` refers to either -- `Tuple{Vararg{<:Bijector{N}}}`: a tuple of bijectors of dimensionality `N` -- `AbstractArray{<:Bijector{N}}`: an array of bijectors of dimensionality `N` - -A `Bijector` representing composition of bijectors. `composel` and `composer` results in a -`Composed` for which application occurs from left-to-right and right-to-left, respectively. - -Note that all the alternative ways of constructing a `Composed` returns a `Tuple` of bijectors. -This ensures type-stability of implementations of all relating methdos, e.g. `inverse`. - -If you want to use an `Array` as the container instead you can do - - Composed([b1, b2, ...]) - -In general this is not advised since you lose type-stability, but there might be cases -where this is desired, e.g. if you have a insanely large number of bijectors to compose. - -# Examples -## Simple example -Let's consider a simple example of `Exp`: -```julia-repl -julia> using Bijectors: Exp - -julia> b = Exp() -Exp{0}() - -julia> b ∘ b -Composed{Tuple{Exp{0},Exp{0}},0}((Exp{0}(), Exp{0}())) - -julia> (b ∘ b)(1.0) == exp(exp(1.0)) # evaluation -true - -julia> inverse(b ∘ b)(exp(exp(1.0))) == 1.0 # inversion -true - -julia> logabsdetjac(b ∘ b, 1.0) # determinant of jacobian -3.718281828459045 -``` - -# Notes -## Order -It's important to note that `∘` does what is expected mathematically, which means that the -bijectors are applied to the input right-to-left, e.g. first applying `b2` and then `b1`: -```julia -(b1 ∘ b2)(x) == b1(b2(x)) # => true -``` -But in the `Composed` struct itself, we store the bijectors left-to-right, so that -```julia -cb1 = b1 ∘ b2 # => Composed.ts == (b2, b1) -cb2 = composel(b2, b1) # => Composed.ts == (b2, b1) -cb1(x) == cb2(x) == b1(b2(x)) # => true -``` - -## Structure -`∘` will result in "flatten" the composition structure while `composel` and -`composer` preserve the compositional structure. This is most easily seen by an example: -```julia-repl -julia> b = Exp() -Exp{0}() - -julia> cb1 = b ∘ b; cb2 = b ∘ b; - -julia> (cb1 ∘ cb2).ts # <= different -(Exp{0}(), Exp{0}(), Exp{0}(), Exp{0}()) - -julia> (cb1 ∘ cb2).ts isa NTuple{4, Exp{0}} -true - -julia> Bijectors.composer(cb1, cb2).ts -(Composed{Tuple{Exp{0},Exp{0}},0}((Exp{0}(), Exp{0}())), Composed{Tuple{Exp{0},Exp{0}},0}((Exp{0}(), Exp{0}()))) - -julia> Bijectors.composer(cb1, cb2).ts isa Tuple{Composed, Composed} -true -``` - -""" -struct Composed{A, N} <: Bijector{N} - ts::A +function transform!(cb::ComposedFunction, x, y) + transform!(cb.inner, x, y) + return transform!(cb.outer, y, y) end -Composed(bs::Tuple{Vararg{<:Bijector{N}}}) where N = Composed{typeof(bs),N}(bs) -Composed(bs::AbstractArray{<:Bijector{N}}) where N = Composed{typeof(bs),N}(bs) - -# field contains nested numerical parameters -Functors.@functor Composed - -isclosedform(b::Composed) = all(isclosedform, b.ts) -up1(b::Composed) = Composed(up1.(b.ts)) -function Base.:(==)(b1::Composed{<:Any, N}, b2::Composed{<:Any, N}) where {N} - ts1, ts2 = b1.ts, b2.ts - return length(ts1) == length(ts2) && all(x == y for (x, y) in zip(ts1, ts2)) +function logabsdetjac(cb::ComposedFunction, x) + y, logjac = with_logabsdet_jacobian(cb.inner, x) + return logabsdetjac(cb.outer, y) + logjac end -""" - composel(ts::Bijector...)::Composed{<:Tuple} - -Constructs `Composed` such that `ts` are applied left-to-right. -""" -composel(ts::Bijector{N}...) where {N} = Composed(ts) - -""" - composer(ts::Bijector...)::Composed{<:Tuple} - -Constructs `Composed` such that `ts` are applied right-to-left. -""" -composer(ts::Bijector{N}...) where {N} = Composed(reverse(ts)) - -# The transformation of `Composed` applies functions left-to-right -# but in mathematics we usually go from right-to-left; this reversal ensures that -# when we use the mathematical composition ∘ we get the expected behavior. -# TODO: change behavior of `transform` of `Composed`? -@generated function ∘(b1::Bijector{N1}, b2::Bijector{N2}) where {N1, N2} - if N1 == N2 - return :(composel(b2, b1)) - else - return :(throw(DimensionMismatch("$(typeof(b1)) expects $(N1)-dim but $(typeof(b2)) expects $(N2)-dim"))) - end +function logabsdetjac!(cb::ComposedFunction, x, logjac) + y = similar(x) + logjac = last(with_logabsdet_jacobian!(cb.inner, x, y, logjac)) + return logabsdetjac!(cb.outer, y, y, logjac) end -# type-stable composition rules -∘(b1::Composed{<:Tuple}, b2::Bijector) = composel(b2, b1.ts...) -∘(b1::Bijector, b2::Composed{<:Tuple}) = composel(b2.ts..., b1) -∘(b1::Composed{<:Tuple}, b2::Composed{<:Tuple}) = composel(b2.ts..., b1.ts...) - -# type-unstable composition rules -∘(b1::Composed{<:AbstractArray}, b2::Bijector) = Composed(pushfirst!(copy(b1.ts), b2)) -∘(b1::Bijector, b2::Composed{<:AbstractArray}) = Composed(push!(copy(b2.ts), b1)) -function ∘(b1::Composed{<:AbstractArray}, b2::Composed{<:AbstractArray}) - return Composed(append!(copy(b2.ts), copy(b1.ts))) -end - -# if combining type-unstable and type-stable, return type-unstable -function ∘(b1::T1, b2::T2) where {T1<:Composed{<:Tuple}, T2<:Composed{<:AbstractArray}} - error("Cannot compose compositions of different container-types; ($T1, $T2)") -end -function ∘(b1::T1, b2::T2) where {T1<:Composed{<:AbstractArray}, T2<:Composed{<:Tuple}} - error("Cannot compose compositions of different container-types; ($T1, $T2)") -end - - -∘(::Identity{N}, ::Identity{N}) where {N} = Identity{N}() -∘(::Identity{N}, b::Bijector{N}) where {N} = b -∘(b::Bijector{N}, ::Identity{N}) where {N} = b - -inverse(ct::Composed) = Composed(reverse(map(inverse, ct.ts))) - -# # TODO: should arrays also be using recursive implementation instead? -function (cb::Composed{<:AbstractArray{<:Bijector}})(x) - @assert length(cb.ts) > 0 - res = cb.ts[1](x) - for b ∈ Base.Iterators.drop(cb.ts, 1) - res = b(res) - end - - return res -end - -@generated function (cb::Composed{T})(x) where {T<:Tuple} - @assert length(T.parameters) > 0 - expr = :(x) - for i in 1:length(T.parameters) - expr = :(cb.ts[$i]($expr)) - end - return expr -end - -function logabsdetjac(cb::Composed, x) - y, logjac = with_logabsdet_jacobian(cb.ts[1], x) - for i = 2:length(cb.ts) - y, res_logjac = with_logabsdet_jacobian(cb.ts[i], y) - logjac += res_logjac - end - - return logjac -end - -@generated function logabsdetjac(cb::Composed{T}, x) where {T<:Tuple} - N = length(T.parameters) - - expr = Expr(:block) - sym_y, sym_ladj, sym_tmp_ladj = gensym(:y), gensym(:lady), gensym(:tmp_lady) - push!(expr.args, :(($sym_y, $sym_ladj) = with_logabsdet_jacobian(cb.ts[1], x))) - sym_last_y, sym_last_ladj = sym_y, sym_ladj - for i = 2:N - 1 - sym_y, sym_ladj, sym_tmp_ladj = gensym(:y), gensym(:lady), gensym(:tmp_lady) - push!(expr.args, :(($sym_y, $sym_tmp_ladj) = with_logabsdet_jacobian(cb.ts[$i], $sym_last_y))) - push!(expr.args, :($sym_ladj = $sym_tmp_ladj + $sym_last_ladj)) - sym_last_y, sym_last_ladj = sym_y, sym_ladj - end - # don't need to evaluate the last bijector, only it's `logabsdetjac` - sym_ladj, sym_tmp_ladj = gensym(:lady), gensym(:tmp_lady) - push!(expr.args, :($sym_tmp_ladj = logabsdetjac(cb.ts[$N], $sym_last_y))) - push!(expr.args, :($sym_ladj = $sym_tmp_ladj + $sym_last_ladj)) - push!(expr.args, :(return $sym_ladj)) - - return expr -end - - -function with_logabsdet_jacobian(cb::Composed, x) - rv, logjac = with_logabsdet_jacobian(cb.ts[1], x) - - for t in cb.ts[2:end] - rv, res_logjac = with_logabsdet_jacobian(t, rv) - logjac += res_logjac - end - return (rv, logjac) -end - -@generated function with_logabsdet_jacobian(cb::Composed{T}, x) where {T<:Tuple} - expr = Expr(:block) - sym_y, sym_ladj, sym_tmp_ladj = gensym(:y), gensym(:lady), gensym(:tmp_lady) - push!(expr.args, :(($sym_y, $sym_ladj) = with_logabsdet_jacobian(cb.ts[1], x))) - sym_last_y, sym_last_ladj = sym_y, sym_ladj - for i = 2:length(T.parameters) - sym_y, sym_ladj, sym_tmp_ladj = gensym(:y), gensym(:lady), gensym(:tmp_lady) - push!(expr.args, :(($sym_y, $sym_tmp_ladj) = with_logabsdet_jacobian(cb.ts[$i], $sym_last_y))) - push!(expr.args, :($sym_ladj = $sym_tmp_ladj + $sym_last_ladj)) - sym_last_y, sym_last_ladj = sym_y, sym_ladj - end - push!(expr.args, :(return ($sym_y, $sym_ladj))) - - return expr +function with_logabsdet_jacobian!(cb::ComposedFunction, x, y, logjac) + logjac = last(with_logabsdet_jacobian!(cb.inner, x, y, logjac)) + return with_logabsdet_jacobian!(cb.outer, y, y, logjac) end diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 5ec999db..252ecc68 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -1,5 +1,5 @@ """ - CorrBijector <: Bijector{2} + CorrBijector <: Bijector A bijector implementation of Stan's parametrization method for Correlation matrix: https://mc-stan.org/docs/2_23/reference-manual/correlation-matrix-transform-section.html @@ -61,9 +61,11 @@ Note: The implementation doesn't follow their "manageable expression" directly, because their equation seems wrong (7/30/2020). Insteadly it follows definition above the "manageable expression" directly, which is also described in above doc. """ -struct CorrBijector <: Bijector{2} end +struct CorrBijector <: Bijector end -function (b::CorrBijector)(x::AbstractMatrix{<:Real}) +with_logabsdet_jacobian(b::CorrBijector, x) = transform(b, x), logabsdetjac(b, x) + +function transform(b::CorrBijector, x::AbstractMatrix{<:Real}) w = cholesky(x).U # keep LowerTriangular until here can avoid some computation r = _link_chol_lkj(w) return r + zero(x) @@ -71,14 +73,10 @@ function (b::CorrBijector)(x::AbstractMatrix{<:Real}) # https://github.com/TuringLang/Bijectors.jl/blob/b0aaa98f90958a167a0b86c8e8eca9b95502c42d/test/transform.jl#L67 end -(b::CorrBijector)(X::AbstractArray{<:AbstractMatrix{<:Real}}) = map(b, X) - -function (ib::Inverse{<:CorrBijector})(y::AbstractMatrix{<:Real}) +function transform(ib::Inverse{CorrBijector}, y::AbstractMatrix{<:Real}) w = _inv_link_chol_lkj(y) return w' * w end -(ib::Inverse{<:CorrBijector})(Y::AbstractArray{<:AbstractMatrix{<:Real}}) = map(ib, Y) - function logabsdetjac(::Inverse{CorrBijector}, y::AbstractMatrix{<:Real}) K = LinearAlgebra.checksquare(y) @@ -102,12 +100,6 @@ function logabsdetjac(b::CorrBijector, X::AbstractMatrix{<:Real}) =# return -logabsdetjac(inverse(b), (b(X))) end -function logabsdetjac(b::CorrBijector, X::AbstractArray{<:AbstractMatrix{<:Real}}) - return mapvcat(X) do x - logabsdetjac(b, x) - end -end - function _inv_link_chol_lkj(y) K = LinearAlgebra.checksquare(y) diff --git a/src/bijectors/coupling.jl b/src/bijectors/coupling.jl index 03b00ba5..9aaaf829 100644 --- a/src/bijectors/coupling.jl +++ b/src/bijectors/coupling.jl @@ -134,25 +134,23 @@ Partitions `x` into 3 disjoint subvectors. Implements a coupling-layer as defined in [1]. # Examples -```julia-repl -julia> m = PartitionMask(3, [1], [2]) # <= going to use x[2] to parameterize transform of x[1] -PartitionMask{SparseArrays.SparseMatrixCSC{Float64,Int64}}( - [1, 1] = 1.0, - [2, 1] = 1.0, - [3, 1] = 1.0) +```jldoctest +julia> using Bijectors: Shift, Coupling, PartitionMask, coupling, couple -julia> cl = Coupling(θ -> Shift(θ[1]), m) # <= will do `y[1:1] = x[1:1] + x[2:2]`; +julia> m = PartitionMask(3, [1], [2]); # <= going to use x[2] to parameterize transform of x[1] + +julia> cl = Coupling(Shift, m); # <= will do `y[1:1] = x[1:1] + x[2:2]`; julia> x = [1., 2., 3.]; julia> cl(x) -3-element Array{Float64,1}: +3-element Vector{Float64}: 3.0 2.0 3.0 julia> inverse(cl)(cl(x)) -3-element Array{Float64,1}: +3-element Vector{Float64}: 1.0 2.0 3.0 @@ -161,13 +159,16 @@ julia> coupling(cl) # get the `Bijector` map `θ -> b(⋅, θ)` Shift julia> couple(cl, x) # get the `Bijector` resulting from `x` -Shift{Array{Float64,1},1}([2.0]) +Shift([2.0]) + +julia> with_logabsdet_jacobian(cl, x) +([3.0, 2.0, 3.0], 0.0) ``` # References [1] Kobyzev, I., Prince, S., & Brubaker, M. A., Normalizing flows: introduction and ideas, CoRR, (), (2019). """ -struct Coupling{F, M} <: Bijector{1} where {F, M <: PartitionMask} +struct Coupling{F, M} <: Bijector where {F, M <: PartitionMask} θ::F mask::M end @@ -195,7 +196,31 @@ function couple(cl::Coupling, x::AbstractVector) return b end -function (cl::Coupling)(x::AbstractVector) +function with_logabsdet_jacobian(cl::Coupling, x) + # partition vector using `cl.mask::PartitionMask` + x_1, x_2, x_3 = partition(cl.mask, x) + + # construct bijector `B` using θ(x₂) + b = cl.θ(x_2) + + y_1, logjac = with_logabsdet_jacobian(b, x_1) + return combine(cl.mask, y_1, x_2, x_3), logjac +end + +function with_logabsdet_jacobian(icl::Inverse{<:Coupling}, y) + cl = icl.orig + + # partition vector using `cl.mask::PartitionMask` + y_1, y_2, y_3 = partition(cl.mask, y) + + # construct bijector `B` using θ(y₂) + b = cl.θ(y_2) + + x_1, logjac = with_logabsdet_jacobian(inverse(b), y_1) + return combine(cl.mask, x_1, y_2, y_3), logjac +end + +function transform(cl::Coupling, x::AbstractVector) # partition vector using `cl.mask::PartitionMask` x_1, x_2, x_3 = partition(cl.mask, x) @@ -205,10 +230,8 @@ function (cl::Coupling)(x::AbstractVector) # recombine the vector again using the `PartitionMask` return combine(cl.mask, b(x_1), x_2, x_3) end -(cl::Coupling)(x::AbstractMatrix) = eachcolmaphcat(cl, x) - -function (icl::Inverse{<:Coupling})(y::AbstractVector) +function transform(icl::Inverse{<:Coupling}, y::AbstractVector) cl = icl.orig y_1, y_2, y_3 = partition(cl.mask, y) @@ -218,7 +241,6 @@ function (icl::Inverse{<:Coupling})(y::AbstractVector) return combine(cl.mask, ib(y_1), y_2, y_3) end -(icl::Inverse{<:Coupling})(y::AbstractMatrix) = eachcolmaphcat(icl, y) function logabsdetjac(cl::Coupling, x::AbstractVector) x_1, x_2, x_3 = partition(cl.mask, x) @@ -228,7 +250,3 @@ function logabsdetjac(cl::Coupling, x::AbstractVector) # therefore we sum to ensure such a thing does not happen return sum(logabsdetjac(b, x_1)) end - -function logabsdetjac(cl::Coupling, x::AbstractMatrix) - return [logabsdetjac(cl, view(x, :, i)) for i in axes(x, 2)] -end diff --git a/src/bijectors/exp_log.jl b/src/bijectors/exp_log.jl index 0f5f4683..7236a74e 100644 --- a/src/bijectors/exp_log.jl +++ b/src/bijectors/exp_log.jl @@ -1,49 +1,7 @@ -############# -# Exp & Log # -############# +transform!(b::Union{Elementwise{typeof(log)}, Elementwise{typeof(exp)}}, x, y) = broadcast!(b.x, y, x) -struct Exp{N} <: Bijector{N} end -struct Log{N} <: Bijector{N} end -up1(::Exp{N}) where {N} = Exp{N + 1}() -up1(::Log{N}) where {N} = Log{N + 1}() +logabsdetjac(b::typeof(exp), x::Real) = x +logabsdetjac(b::Elementwise{typeof(exp)}, x) = sum(x) -Exp() = Exp{0}() -Log() = Log{0}() - -(b::Exp{0})(y::Real) = exp(y) -(b::Log{0})(x::Real) = log(x) - -(b::Exp{0})(y::AbstractArray{<:Real}) = exp.(y) -(b::Log{0})(x::AbstractArray{<:Real}) = log.(x) - -(b::Exp{1})(y::AbstractVector{<:Real}) = exp.(y) -(b::Exp{1})(y::AbstractMatrix{<:Real}) = exp.(y) -(b::Log{1})(x::AbstractVector{<:Real}) = log.(x) -(b::Log{1})(x::AbstractMatrix{<:Real}) = log.(x) - -(b::Exp{2})(y::AbstractMatrix{<:Real}) = exp.(y) -(b::Log{2})(x::AbstractMatrix{<:Real}) = log.(x) - -(b::Exp{2})(y::AbstractArray{<:AbstractMatrix{<:Real}}) = map(b, y) -(b::Log{2})(x::AbstractArray{<:AbstractMatrix{<:Real}}) = map(b, x) - -inverse(b::Exp{N}) where {N} = Log{N}() -inverse(b::Log{N}) where {N} = Exp{N}() - -logabsdetjac(b::Exp{0}, x::Real) = x -logabsdetjac(b::Exp{0}, x::AbstractVector) = x -logabsdetjac(b::Exp{1}, x::AbstractVector) = sum(x) -logabsdetjac(b::Exp{1}, x::AbstractMatrix) = vec(sum(x; dims = 1)) -logabsdetjac(b::Exp{2}, x::AbstractMatrix) = sum(x) -logabsdetjac(b::Exp{2}, x::AbstractArray{<:AbstractMatrix}) = map(x) do x - logabsdetjac(b, x) -end - -logabsdetjac(b::Log{0}, x::Real) = -log(x) -logabsdetjac(b::Log{0}, x::AbstractVector) = .-log.(x) -logabsdetjac(b::Log{1}, x::AbstractVector) = - sum(log, x) -logabsdetjac(b::Log{1}, x::AbstractMatrix) = - vec(sum(log, x; dims = 1)) -logabsdetjac(b::Log{2}, x::AbstractMatrix) = - sum(log, x) -logabsdetjac(b::Log{2}, x::AbstractArray{<:AbstractMatrix}) = map(x) do x - logabsdetjac(b, x) -end +logabsdetjac(b::typeof(log), x::Real) = -log(x) +logabsdetjac(b::Elementwise{typeof(log)}, x) = -sum(log, x) diff --git a/src/bijectors/leaky_relu.jl b/src/bijectors/leaky_relu.jl index c91e0faf..9d76bbb0 100644 --- a/src/bijectors/leaky_relu.jl +++ b/src/bijectors/leaky_relu.jl @@ -1,5 +1,5 @@ """ - LeakyReLU{T, N}(α::T) <: Bijector{N} + LeakyReLU{T}(α::T) <: Bijector Defines the invertible mapping @@ -7,95 +7,23 @@ Defines the invertible mapping where α > 0. """ -struct LeakyReLU{T, N} <: Bijector{N} +struct LeakyReLU{T} <: Bijector α::T end -LeakyReLU(α::T; dim::Val{N} = Val(0)) where {T<:Real, N} = LeakyReLU{T, N}(α) -LeakyReLU(α::T; dim::Val{N} = Val(D)) where {D, T<:AbstractArray{<:Real, D}, N} = LeakyReLU{T, N}(α) +Functors.@functor LeakyReLU -# field is a numerical parameter -function Functors.functor(::Type{LeakyReLU{<:Any,N}}, x) where N - function reconstruct_leakyrelu(xs) - return LeakyReLU{typeof(xs.α),N}(xs.α) - end - return (α = x.α,), reconstruct_leakyrelu -end - -up1(b::LeakyReLU{T, N}) where {T, N} = LeakyReLU{T, N + 1}(b.α) - -# (N=0) Univariate case -function (b::LeakyReLU{<:Any, 0})(x::Real) - mask = x < zero(x) - return mask * b.α * x + !mask * x -end -(b::LeakyReLU{<:Any, 0})(x::AbstractVector{<:Real}) = map(b, x) - -function inverse(b::LeakyReLU{<:Any,N}) where N - invα = inv.(b.α) - return LeakyReLU{typeof(invα),N}(invα) -end - -function logabsdetjac(b::LeakyReLU{<:Any, 0}, x::Real) - mask = x < zero(x) - J = mask * b.α + (1 - mask) * one(x) - return log(abs(J)) -end -logabsdetjac(b::LeakyReLU{<:Real, 0}, x::AbstractVector{<:Real}) = map(x -> logabsdetjac(b, x), x) +inverse(b::LeakyReLU) = LeakyReLU(inv.(b.α)) - -# We implement `with_logabsdet_jacobian` by hand since we can re-use the computation of -# the Jacobian of the transformation. This will lead to faster sampling -# when using `rand` on a `TransformedDistribution` making use of `LeakyReLU`. -function with_logabsdet_jacobian(b::LeakyReLU{<:Any, 0}, x::Real) +function with_logabsdet_jacobian(b::LeakyReLU, x::Real) mask = x < zero(x) - J = mask * b.α + !mask * one(x) - return (J * x, log(abs(J))) -end - -# Batched version -function with_logabsdet_jacobian(b::LeakyReLU{<:Any, 0}, x::AbstractVector) - J = let T = eltype(x), z = zero(T), o = one(T) - @. (x < z) * b.α + (x > z) * o - end - return (J .* x, log.(abs.(J))) -end - -# (N=1) Multivariate case -function (b::LeakyReLU{<:Any, 1})(x::AbstractVecOrMat) - return let z = zero(eltype(x)) - @. (x < z) * b.α * x + (x > z) * x - end + J = mask * b.α + !mask + return J * x, log(abs(J)) end -function logabsdetjac(b::LeakyReLU{<:Any, 1}, x::AbstractVecOrMat) - # Is really diagonal of jacobian - J = let T = eltype(x), z = zero(T), o = one(T) - @. (x < z) * b.α + (x > z) * o - end - - if x isa AbstractVector - return sum(log.(abs.(J))) - elseif x isa AbstractMatrix - return vec(sum(log.(abs.(J)); dims = 1)) # sum along column - end -end - -# We implement `forward` by hand since we can re-use the computation of -# the Jacobian of the transformation. This will lead to faster sampling -# when using `rand` on a `TransformedDistribution` making use of `LeakyReLU`. -function with_logabsdet_jacobian(b::LeakyReLU{<:Any, 1}, x::AbstractVecOrMat) - # Is really diagonal of jacobian - J = let T = eltype(x), z = zero(T), o = one(T) - @. (x < z) * b.α + (x > z) * o - end - - if x isa AbstractVector - logjac = sum(log.(abs.(J))) - elseif x isa AbstractMatrix - logjac = vec(sum(log.(abs.(J)); dims = 1)) # sum along column - end - - y = J .* x - return (y, logjac) +# Array inputs. +function with_logabsdet_jacobian(b::LeakyReLU, x::AbstractArray) + mask = x .< zero(eltype(x)) + J = mask .* b.α .+ (!).(mask) + return J .* x, sum(log.(abs.(J))) end diff --git a/src/bijectors/logit.jl b/src/bijectors/logit.jl index b356c641..1df73514 100644 --- a/src/bijectors/logit.jl +++ b/src/bijectors/logit.jl @@ -1,43 +1,28 @@ ###################### # Logit and Logistic # ###################### -struct Logit{N, T<:Real} <: Bijector{N} +struct Logit{T} <: Bijector a::T b::T end -Logit(a::Real, b::Real) = Logit{0}(a, b) -Logit(a::AbstractArray{<:Real, N}, b::AbstractArray{<:Real, N}) where {N} = Logit{N}(a, b) -function Logit{N}(a, b) where {N} - T = promote_type(typeof(a), typeof(b)) - Logit{N, T}(a, b) -end -# fields are numerical parameters -function Functors.functor(::Type{<:Logit{N}}, x) where N - function reconstruct_logit(xs) - T = promote_type(typeof(xs.a), typeof(xs.b)) - return Logit{N,T}(xs.a, xs.b) - end - return (a = x.a, b = x.b,), reconstruct_logit -end +Functors.@functor Logit -up1(b::Logit{N, T}) where {N, T} = Logit{N + 1, T}(b.a, b.b) # For equality of Logit with Float64 fields to one with Duals Base.:(==)(b1::Logit, b2::Logit) = b1.a == b2.a && b1.b == b2.b -(b::Logit)(x) = _logit.(x, b.a, b.b) -(b::Logit)(x::AbstractArray{<:AbstractArray}) = map(b, x) +# Evaluation _logit(x, a, b) = LogExpFunctions.logit((x - a) / (b - a)) +transform(b::Logit, x) = _logit.(x, b.a, b.b) -(ib::Inverse{<:Logit})(y) = _ilogit.(y, ib.orig.a, ib.orig.b) -(ib::Inverse{<:Logit})(x::AbstractArray{<:AbstractArray}) = map(ib, x) +# Inverse _ilogit(y, a, b) = (b - a) * LogExpFunctions.logistic(y) + a -logabsdetjac(b::Logit{0}, x) = logit_logabsdetjac.(x, b.a, b.b) -logabsdetjac(b::Logit{1}, x::AbstractVector) = sum(logit_logabsdetjac.(x, b.a, b.b)) -logabsdetjac(b::Logit{1}, x::AbstractMatrix) = vec(sum(logit_logabsdetjac.(x, b.a, b.b), dims = 1)) -logabsdetjac(b::Logit{2}, x::AbstractMatrix) = sum(logit_logabsdetjac.(x, b.a, b.b)) -logabsdetjac(b::Logit{2}, x::AbstractArray{<:AbstractMatrix}) = map(x) do x - logabsdetjac(b, x) -end +transform(ib::Inverse{<:Logit}, y) = _ilogit.(y, ib.orig.a, ib.orig.b) + +# `logabsdetjac` logit_logabsdetjac(x, a, b) = -log((x - a) * (b - x) / (b - a)) +logabsdetjac(b::Logit, x) = sum(logit_logabsdetjac.(x, b.a, b.b)) + +# `with_logabsdet_jacobian` +with_logabsdet_jacobian(b::Logit, x) = _logit.(x, b.a, b.b), sum(logit_logabsdetjac.(x, b.a, b.b)) diff --git a/src/bijectors/named_bijector.jl b/src/bijectors/named_bijector.jl index d4fb0557..d147afc6 100644 --- a/src/bijectors/named_bijector.jl +++ b/src/bijectors/named_bijector.jl @@ -1,20 +1,19 @@ -abstract type AbstractNamedBijector <: AbstractBijector end - -with_logabsdet_jacobian(b::AbstractNamedBijector, x) = (b(x), logabsdetjac(b, x)) +abstract type AbstractNamedTransform <: Transform end +abstract type AbstractNamedBijector <: Transform end ####################### -### `NamedBijector` ### +### `NamedTransform` ### ####################### """ - NamedBijector <: AbstractNamedBijector + NamedTransform <: AbstractNamedTransform Wraps a `NamedTuple` of key -> `Bijector` pairs, implementing evaluation, inversion, etc. # Examples -```julia-repl -julia> using Bijectors: NamedBijector, Scale, Exp +```jldoctest +julia> using Bijectors: NamedTransform, Scale -julia> b = NamedBijector((a = Scale(2.0), b = Exp())); +julia> b = NamedTransform((a = Scale(2.0), b = exp)); julia> x = (a = 1., b = 0., c = 42.); @@ -25,21 +24,26 @@ julia> (a = 2 * x.a, b = exp(x.b), c = x.c) (a = 2.0, b = 1.0, c = 42.0) ``` """ -struct NamedBijector{names, Bs<:NamedTuple{names}} <: AbstractNamedBijector +struct NamedTransform{names, Bs<:NamedTuple{names}} <: AbstractNamedTransform bs::Bs end # fields contain nested numerical parameters -function Functors.functor(::Type{<:NamedBijector{names}}, x) where names +function Functors.functor(::Type{<:NamedTransform{names}}, x) where names function reconstruct_namedbijector(xs) - return NamedBijector{names,typeof(xs.bs)}(xs.bs) + return NamedTransform{names,typeof(xs.bs)}(xs.bs) end return (bs = x.bs,), reconstruct_namedbijector end -names_to_bijectors(b::NamedBijector) = b.bs +# TODO: Use recursion instead of `@generated`? +inverse(t::NamedTransform) = NamedTransform(map(inverse, t.bs)) +# NOTE: Need explicit definition, since `inverse(::NamedTransform)` will +# end up wrapping a potential `NoInverse` in `NamedTransform`. +isinvertible(t::NamedTransform) = all(isinvertible, t.bs) -@generated function (b::NamedBijector{names1})( +@generated function transform( + b::NamedTransform{names1}, x::NamedTuple{names2} ) where {names1, names2} exprs = [] @@ -55,146 +59,52 @@ names_to_bijectors(b::NamedBijector) = b.bs return :($(exprs...), ) end -@generated function inverse(b::NamedBijector{names}) where {names} - return :(NamedBijector(($([:($n = inverse(b.bs.$n)) for n in names]...), ))) -end - -@generated function logabsdetjac(b::NamedBijector{names}, x::NamedTuple) where {names} +@generated function logabsdetjac(b::NamedTransform{names}, x::NamedTuple) where {names} exprs = [:(logabsdetjac(b.bs.$n, x.$n)) for n in names] return :(+($(exprs...))) end +@generated function with_logabsdet_jacobian( + b::NamedTransform{names1}, + x::NamedTuple{names2} +) where {names1, names2} + body_exprs = [] + logjac_expr = Expr(:call, :+) + val_expr = Expr(:tuple, ) + for n in names2 + if n in names1 + val_sym = Symbol("y_$n") + logjac_sym = Symbol("logjac_$n") -###################### -### `NamedInverse` ### -###################### -""" - NamedInverse <: AbstractNamedBijector - -Represents the inverse of a `AbstractNamedBijector`, similarily to `Inverse` for `Bijector`. - -See also: [`Inverse`](@ref) -""" -struct NamedInverse{B<:AbstractNamedBijector} <: AbstractNamedBijector - orig::B -end -inverse(nb::AbstractNamedBijector) = NamedInverse(nb) -inverse(ni::NamedInverse) = ni.orig - -logabsdetjac(ni::NamedInverse, y::NamedTuple) = -logabsdetjac(inverse(ni), ni(y)) - -########################## -### `NamedComposition` ### -########################## -""" - NamedComposition <: AbstractNamedBijector - -Wraps a tuple of array of `AbstractNamedBijector` and implements their composition. - -This is very similar to `Composed` for `Bijector`, with the exception that we do not require -the inputs to have the same "dimension", which in this case refers to the *symbols* for the -`NamedTuple` that this takes as input. - -See also: [`Composed`](@ref) -""" -struct NamedComposition{Bs} <: AbstractNamedBijector - bs::Bs -end - -# Essentially just copy-paste from impl of composition for 'standard' bijectors, -# with minor changes here and there. -composel(bs::AbstractNamedBijector...) = NamedComposition(bs) -composer(bs::AbstractNamedBijector...) = NamedComposition(reverse(bs)) -∘(b1::AbstractNamedBijector, b2::AbstractNamedBijector) = composel(b2, b1) - -inverse(ct::NamedComposition) = NamedComposition(reverse(map(inverse, ct.bs))) - -function (cb::NamedComposition{<:AbstractArray{<:AbstractNamedBijector}})(x) - @assert length(cb.bs) > 0 - res = cb.bs[1](x) - for b ∈ Base.Iterators.drop(cb.bs, 1) - res = b(res) - end - - return res -end - -(cb::NamedComposition{<:Tuple})(x) = foldl(|>, cb.bs; init=x) - -function logabsdetjac(cb::NamedComposition, x) - y, logjac = with_logabsdet_jacobian(cb.bs[1], x) - for i = 2:length(cb.bs) - y, res_logjac = with_logabsdet_jacobian(cb.bs[i], y) - logjac += res_logjac - end - - return logjac -end - -@generated function logabsdetjac(cb::NamedComposition{T}, x) where {T<:Tuple} - N = length(T.parameters) - - expr = Expr(:block) - push!(expr.args, :((y, logjac) = with_logabsdet_jacobian(cb.bs[1], x))) - - for i = 2:N - 1 - temp = gensym(:res_logjac) - push!(expr.args, :(y, $temp = with_logabsdet_jacobian(cb.bs[$i], y))) - push!(expr.args, :(logjac += $temp)) - end - # don't need to evaluate the last bijector, only it's `logabsdetjac` - push!(expr.args, :(logjac += logabsdetjac(cb.bs[$N], y))) - - push!(expr.args, :(return logjac)) - - return expr -end - - -function with_logabsdet_jacobian(cb::NamedComposition, x) - rv, logjac = with_logabsdet_jacobian(cb.bs[1], x) - - for t in cb.bs[2:end] - rv, res_logjac = with_logabsdet_jacobian(t, rv) - logjac += res_logjac + push!(body_exprs, :(($val_sym, $logjac_sym) = with_logabsdet_jacobian(b.bs.$n, x.$n))) + push!(logjac_expr.args, logjac_sym) + push!(val_expr.args, :($n = $val_sym)) + else + push!(val_expr.args, :($n = x.$n)) + end end - return (rv, logjac) -end - - -@generated function with_logabsdet_jacobian(cb::NamedComposition{T}, x) where {T<:Tuple} - expr = Expr(:block) - - sym_y, sym_ladj, sym_tmp_ladj = gensym(:y), gensym(:lady), gensym(:tmp_lady) - push!(expr.args, :(($sym_y, $sym_ladj) = with_logabsdet_jacobian(cb.bs[1], x))) - sym_last_y, sym_last_ladj = sym_y, sym_ladj - for i = 2:length(T.parameters) - sym_y, sym_ladj, sym_tmp_ladj = gensym(:y), gensym(:lady), gensym(:tmp_lady) - push!(expr.args, :(($sym_y, $sym_tmp_ladj) = with_logabsdet_jacobian(cb.bs[$i], $sym_last_y))) - push!(expr.args, :($sym_ladj = $sym_tmp_ladj + $sym_last_ladj)) - sym_last_y, sym_last_ladj = sym_y, sym_ladj + return quote + $(body_exprs...) + return NamedTuple{$names2}($val_expr), $logjac_expr end - push!(expr.args, :(return ($sym_y, $sym_ladj))) - - return expr end - ############################ ### `NamedCouplingLayer` ### ############################ # TODO: Add ref to `Coupling` or `CouplingLayer` once that's merged. """ - NamedCoupling{target, deps, F} <: AbstractNamedBijector + NamedCoupling{target, deps, F} <: AbstractNamedTransform Implements a coupling layer for named bijectors. +See also: [`Coupling`](@ref) + # Examples -```julia-repl +```jldoctest julia> using Bijectors: NamedCoupling, Scale -julia> b = NamedCoupling(:b, (:a, :c), (a, c) -> Scale(a + c)) -NamedCoupling{:b,(:a, :c),var"#3#4"}(var"#3#4"()) +julia> b = NamedCoupling(:b, (:a, :c), (a, c) -> Scale(a + c)); julia> x = (a = 1., b = 2., c = 3.); @@ -214,31 +124,26 @@ function NamedCoupling(::Val{target}, ::Val{deps}, f::F) where {target, deps, F} return NamedCoupling{target, deps, F}(f) end +isinvertible(::NamedCoupling) = true + coupling(b::NamedCoupling) = b.f # For some reason trying to use the parameteric types doesn't always work # so we have to do this weird approach of extracting type and then index `parameters`. target(b::NamedCoupling{Target}) where {Target} = Target deps(b::NamedCoupling{<:Any, Deps}) where {Deps} = Deps -@generated function (nc::NamedCoupling{target, deps, F})(x::NamedTuple) where {target, deps, F} +@generated function with_logabsdet_jacobian(nc::NamedCoupling{target, deps, F}, x::NamedTuple) where {target, deps, F} return quote b = nc.f($([:(x.$d) for d in deps]...)) - return merge(x, ($target = b(x.$target), )) + x_target, logjac = with_logabsdet_jacobian(b, x.$target) + return merge(x, ($target = x_target, )), logjac end end -@generated function (ni::NamedInverse{<:NamedCoupling{target, deps, F}})( - x::NamedTuple -) where {target, deps, F} +@generated function with_logabsdet_jacobian(ni::Inverse{<:NamedCoupling{target, deps, F}}, x::NamedTuple) where {target, deps, F} return quote - b = ni.orig.f($([:(x.$d) for d in deps]...)) - return merge(x, ($target = inverse(b)(x.$target), )) - end -end - -@generated function logabsdetjac(nc::NamedCoupling{target, deps, F}, x::NamedTuple) where {target, deps, F} - return quote - b = nc.f($([:(x.$d) for d in deps]...)) - return logabsdetjac(b, x.$target) + ib = inverse(ni.orig.f($([:(x.$d) for d in deps]...))) + x_target, logjac = with_logabsdet_jacobian(ib, x.$target) + return merge(x, ($target = x_target, )), logjac end end diff --git a/src/bijectors/normalise.jl b/src/bijectors/normalise.jl index c49863c3..6b131d27 100644 --- a/src/bijectors/normalise.jl +++ b/src/bijectors/normalise.jl @@ -6,7 +6,7 @@ using Statistics: mean istraining() = false -mutable struct InvertibleBatchNorm{T1,T2,T3} <: Bijector{1} +mutable struct InvertibleBatchNorm{T1,T2,T3} <: Bijector b :: T1 # bias logs :: T1 # log-scale m :: T2 # moving mean @@ -38,15 +38,7 @@ function InvertibleBatchNorm( ) end -# define numerical parameters -# TODO: replace with `Functors.@functor InvertibleBatchNorm (b, logs)` when -# https://github.com/FluxML/Functors.jl/pull/7 is merged -function Functors.functor(::Type{<:InvertibleBatchNorm}, x) - function reconstruct_invertiblebatchnorm(xs) - return InvertibleBatchNorm(xs.b, xs.logs, x.m, x.v, x.eps, x.mtm) - end - return (b = x.b, logs = x.logs), reconstruct_invertiblebatchnorm -end +Functors.@functor InvertibleBatchNorm (b, logs) function with_logabsdet_jacobian(bn::InvertibleBatchNorm, x) dims = ndims(x) @@ -72,16 +64,15 @@ function with_logabsdet_jacobian(bn::InvertibleBatchNorm, x) v = reshape(bn.v, as...) end - rv = s .* (x .- m) ./ sqrt.(v .+ bn.eps) .+ b + result = s .* (x .- m) ./ sqrt.(v .+ bn.eps) .+ b logabsdetjac = ( fill(sum(logs - log.(v .+ bn.eps) / 2), size(x, dims)) ) - return (rv, logabsdetjac) + return (result, logabsdetjac) end logabsdetjac(bn::InvertibleBatchNorm, x) = last(with_logabsdet_jacobian(bn, x)) - -(bn::InvertibleBatchNorm)(x) = first(with_logabsdet_jacobian(bn, x)) +transform(bn::InvertibleBatchNorm, x) = first(with_logabsdet_jacobian(bn, x)) function with_logabsdet_jacobian(invbn::Inverse{<:InvertibleBatchNorm}, y) @assert !istraining() "`with_logabsdet_jacobian(::Inverse{InvertibleBatchNorm})` is only available in test mode." @@ -97,7 +88,7 @@ function with_logabsdet_jacobian(invbn::Inverse{<:InvertibleBatchNorm}, y) return (x, -logabsdetjac(bn, x)) end -(bn::Inverse{<:InvertibleBatchNorm})(y) = first(with_logabsdet_jacobian(bn, y)) +transform(bn::Inverse{<:InvertibleBatchNorm}, y) = first(with_logabsdet_jacobian(bn, y)) function Base.show(io::IO, l::InvertibleBatchNorm) print(io, "InvertibleBatchNorm($(join(size(l.b), ", ")))") diff --git a/src/bijectors/ordered.jl b/src/bijectors/ordered.jl index d1bfd8f0..2a43a661 100644 --- a/src/bijectors/ordered.jl +++ b/src/bijectors/ordered.jl @@ -7,7 +7,7 @@ A bijector mapping ordered vectors in ℝᵈ to unordered vectors in ℝᵈ. - [Stan's documentation](https://mc-stan.org/docs/2_27/reference-manual/ordered-vector.html) - Note that this transformation and its inverse are the _opposite_ of in this reference. """ -struct OrderedBijector <: Bijector{1} end +struct OrderedBijector <: Bijector end """ ordered(d::Distribution) @@ -16,7 +16,9 @@ Return a `Distribution` whose support are ordered vectors, i.e., vectors with in """ ordered(d::ContinuousMultivariateDistribution) = Bijectors.transformed(d, OrderedBijector()) -(b::OrderedBijector)(y::AbstractVecOrMat) = _transform_ordered(y) +with_logabsdet_jacobian(b::OrderedBijector, x) = transform(b, x), logabsdetjac(b, x) + +transform(b::OrderedBijector, y::AbstractVecOrMat) = _transform_ordered(y) function _transform_ordered(y::AbstractVector) x = similar(y) @@ -45,8 +47,7 @@ function _transform_ordered(y::AbstractMatrix) return x end -(ib::Inverse{<:OrderedBijector})(x::AbstractVecOrMat) = _transform_inverse_ordered(x) - +transform(ib::Inverse{OrderedBijector}, x::AbstractVecOrMat) = _transform_inverse_ordered(x) function _transform_inverse_ordered(x::AbstractVector) y = similar(x) @assert !isempty(y) diff --git a/src/bijectors/pd.jl b/src/bijectors/pd.jl index 332a8b36..5b57f55b 100644 --- a/src/bijectors/pd.jl +++ b/src/bijectors/pd.jl @@ -1,4 +1,4 @@ -struct PDBijector <: Bijector{2} end +struct PDBijector <: Bijector end # This function has custom adjoints defined for Tracker, Zygote and ReverseDiff. # I couldn't find a mutation-free implementation that maintains TrackedArrays in Tracker @@ -7,19 +7,17 @@ function replace_diag(f, X) g(i, j) = ifelse(i == j, f(X[i, i]), X[i, j]) return g.(1:size(X, 1), (1:size(X, 2))') end -(b::PDBijector)(X::AbstractMatrix{<:Real}) = pd_link(X) +transform(b::PDBijector, X::AbstractMatrix{<:Real}) = pd_link(X) function pd_link(X) Y = lower(parent(cholesky(X; check = true).L)) return replace_diag(log, Y) end -(b::PDBijector)(X::AbstractArray{<:AbstractMatrix{<:Real}}) = map(b, X) lower(A::AbstractMatrix) = convert(typeof(A), LowerTriangular(A)) -function (ib::Inverse{<:PDBijector})(Y::AbstractMatrix{<:Real}) +function transform(ib::Inverse{PDBijector}, Y::AbstractMatrix{<:Real}) X = replace_diag(exp, Y) return getpd(X) end -(ib::Inverse{<:PDBijector})(X::AbstractArray{<:AbstractMatrix{<:Real}}) = map(ib, X) getpd(X) = LowerTriangular(X) * LowerTriangular(X)' function logabsdetjac(b::PDBijector, X::AbstractMatrix{<:Real}) @@ -37,7 +35,3 @@ function logabsdetjac(b::PDBijector, Xcf::Cholesky) d = size(U, 1) return - sum((d .- (1:d) .+ 2) .* log.(diag(U))) - d * log(T(2)) end - -logabsdetjac(b::PDBijector, X::AbstractArray{<:AbstractMatrix{<:Real}}) = mapvcat(X) do x - logabsdetjac(b, x) -end diff --git a/src/bijectors/permute.jl b/src/bijectors/permute.jl index 5ba9e5cf..a2b49aa7 100644 --- a/src/bijectors/permute.jl +++ b/src/bijectors/permute.jl @@ -2,7 +2,7 @@ using SparseArrays using ArgCheck """ - Permute{A} <: Bijector{1} + Permute{A} <: Bijector A bijector implementation of a permutation. The permutation is performed using a matrix of type `A`. There are a couple of different ways to construct `Permute`: @@ -81,7 +81,7 @@ julia> inverse(b1)(b1([1., 2., 3.])) 3.0 ``` """ -struct Permute{A} <: Bijector{1} +struct Permute{A} <: Bijector A::A end @@ -150,8 +150,9 @@ function Permute(n::Int, indices::Pair{Vector{Int}, Vector{Int}}...) end -@inline (b::Permute)(x::AbstractVecOrMat) = b.A * x -@inline inverse(b::Permute) = Permute(transpose(b.A)) +transform(b::Permute, x::AbstractVecOrMat) = b.A * x +inverse(b::Permute) = Permute(transpose(b.A)) logabsdetjac(b::Permute, x::AbstractVector) = zero(eltype(x)) -logabsdetjac(b::Permute, x::AbstractMatrix) = zero(eltype(x), size(x, 2)) + +with_logabsdet_jacobian(b::Permute, x) = transform(b, x), logabsdetjac(b, x) diff --git a/src/bijectors/planar_layer.jl b/src/bijectors/planar_layer.jl index a4f7be1e..be46de3a 100644 --- a/src/bijectors/planar_layer.jl +++ b/src/bijectors/planar_layer.jl @@ -10,7 +10,7 @@ # TODO: add docstring -struct PlanarLayer{T1<:AbstractVector{<:Real}, T2<:Union{Real, AbstractVector{<:Real}}} <: Bijector{1} +struct PlanarLayer{T1<:AbstractVector{<:Real}, T2<:Union{Real, AbstractVector{<:Real}}} <: Bijector w::T1 u::T1 b::T2 @@ -29,6 +29,8 @@ end # all fields are numerical parameters Functors.@functor PlanarLayer +Base.show(io::IO, b::PlanarLayer) = print(io, "PlanarLayer(w = $(b.w), u = $(b.u), b = $(b.b))") + """ get_u_hat(u::AbstractVector{<:Real}, w::AbstractVector{<:Real}) @@ -74,7 +76,7 @@ function _transform(flow::PlanarLayer, z::AbstractVecOrMat{<:Real}) return (transformed = transformed, wT_û = wT_û, wT_z = wT_z) end -(b::PlanarLayer)(z) = _transform(b, z).transformed +transform(b::PlanarLayer, z) = _transform(b, z).transformed #= Log-determinant of the Jacobian of the planar layer @@ -101,10 +103,10 @@ function with_logabsdet_jacobian(flow::PlanarLayer, z::AbstractVecOrMat{<:Real}) b = first(flow.b) log_det_jacobian = log1p.(wT_û .* abs2.(sech.(_vec(wT_z) .+ b))) - return (transformed, log_det_jacobian) + return (result = transformed, logabsdetjac = log_det_jacobian) end -function (ib::Inverse{<:PlanarLayer})(y::AbstractVecOrMat{<:Real}) +function transform(ib::Inverse{<:PlanarLayer}, y::AbstractVecOrMat{<:Real}) flow = ib.orig w = flow.w b = first(flow.b) diff --git a/src/bijectors/radial_layer.jl b/src/bijectors/radial_layer.jl index e486f800..d4156f01 100644 --- a/src/bijectors/radial_layer.jl +++ b/src/bijectors/radial_layer.jl @@ -8,7 +8,7 @@ # RadialLayer # ############### -mutable struct RadialLayer{T1<:Union{Real, AbstractVector{<:Real}}, T2<:AbstractVector{<:Real}} <: Bijector{1} +mutable struct RadialLayer{T1<:Union{Real, AbstractVector{<:Real}}, T2<:AbstractVector{<:Real}} <: Bijector α_::T1 β::T1 z_0::T2 @@ -27,6 +27,8 @@ end # all fields are numerical parameters Functors.@functor RadialLayer +Base.show(io::IO, b::RadialLayer) = print(io, "RadialLayer(α_ = $(b.α_), β = $(b.β), z_0 = $(b.z_0))") + h(α, r) = 1 ./ (α .+ r) # for radial flow from eq(14) #dh(α, r) = .- (1 ./ (α .+ r)) .^ 2 # for radial flow; derivative of h() @@ -46,8 +48,8 @@ function _radial_transform(α_, β, z_0, z) return (transformed = transformed, α = α, β_hat = β_hat, r = r) end -(b::RadialLayer)(z::AbstractMatrix{<:Real}) = _transform(b, z).transformed -(b::RadialLayer)(z::AbstractVector{<:Real}) = vec(_transform(b, z).transformed) +transform(b::RadialLayer, z::AbstractVector{<:Real}) = vec(_transform(b, z).transformed) +transform(b::RadialLayer, z::AbstractMatrix{<:Real}) = _transform(b, z).transformed function with_logabsdet_jacobian(flow::RadialLayer, z::AbstractVecOrMat) transformed, α, β_hat, r = _transform(flow, z) @@ -63,10 +65,10 @@ function with_logabsdet_jacobian(flow::RadialLayer, z::AbstractVecOrMat) (d - 1) * log(1 + β_hat * h_) + log(1 + β_hat * h_ + β_hat * (- h_ ^ 2) * r) ) # from eq(14) - return (transformed, log_det_jacobian) + return (result = transformed, logabsdetjac = log_det_jacobian) end -function (ib::Inverse{<:RadialLayer})(y::AbstractVector{<:Real}) +function transform(ib::Inverse{<:RadialLayer}, y::AbstractVector{<:Real}) flow = ib.orig z0 = flow.z_0 α = LogExpFunctions.log1pexp(first(flow.α_)) # from A.2 @@ -80,7 +82,7 @@ function (ib::Inverse{<:RadialLayer})(y::AbstractVector{<:Real}) return z0 .+ γ .* y_minus_z0 end -function (ib::Inverse{<:RadialLayer})(y::AbstractMatrix{<:Real}) +function transform(ib::Inverse{<:RadialLayer}, y::AbstractMatrix{<:Real}) flow = ib.orig z0 = flow.z_0 α = LogExpFunctions.log1pexp(first(flow.α_)) # from A.2 diff --git a/src/bijectors/rational_quadratic_spline.jl b/src/bijectors/rational_quadratic_spline.jl index ef34c436..6c0dd601 100644 --- a/src/bijectors/rational_quadratic_spline.jl +++ b/src/bijectors/rational_quadratic_spline.jl @@ -1,6 +1,5 @@ """ - RationalQuadraticSpline{T, 0} <: Bijector{0} - RationalQuadraticSpline{T, 1} <: Bijector{1} + RationalQuadraticSpline{T} <: Bijector Implementation of the Rational Quadratic Spline flow [1]. @@ -29,53 +28,51 @@ There are two constructors for `RationalQuadraticSpline`: # Examples ## Univariate -```julia-repl +```jldoctest +julia> using StableRNGs: StableRNG; rng = StableRNG(42); # For reproducibility. + julia> using Bijectors: RationalQuadraticSpline julia> K = 3; B = 2; julia> # Monotonic spline on '[-B, B]' with `K` intermediate knots/"connection points". - b = RationalQuadraticSpline(randn(K), randn(K), randn(K - 1), B); + b = RationalQuadraticSpline(randn(rng, K), randn(rng, K), randn(rng, K - 1), B); julia> b(0.5) # inside of `[-B, B]` → transformed -1.412300607463467 +1.1943325397834206 julia> b(5.) # outside of `[-B, B]` → not transformed 5.0 -``` -Or we can use the constructor with the parameters correctly constrained: -```julia-repl + julia> b = RationalQuadraticSpline(b.widths, b.heights, b.derivatives); julia> b(0.5) # inside of `[-B, B]` → transformed -1.412300607463467 -``` -## Multivariate -```julia-repl +1.1943325397834206 + julia> d = 2; K = 3; B = 2; -julia> b = RationalQuadraticSpline(randn(d, K), randn(d, K), randn(d, K - 1), B); +julia> b = RationalQuadraticSpline(randn(rng, d, K), randn(rng, d, K), randn(rng, d, K - 1), B); julia> b([-1., 1.]) -2-element Array{Float64,1}: - -1.2568224171342797 - 0.5537259740554675 +2-element Vector{Float64}: + -1.5660106244288925 + 0.5384702734738573 julia> b([-5., 5.]) -2-element Array{Float64,1}: +2-element Vector{Float64}: -5.0 5.0 julia> b([-1., 5.]) -2-element Array{Float64,1}: - -1.2568224171342797 +2-element Vector{Float64}: + -1.5660106244288925 5.0 ``` # References [1] Durkan, C., Bekasov, A., Murray, I., & Papamakarios, G., Neural Spline Flows, CoRR, arXiv:1906.04032 [stat.ML], (2019). """ -struct RationalQuadraticSpline{T, N} <: Bijector{N} +struct RationalQuadraticSpline{T} <: Bijector widths::T # K widths heights::T # K heights derivatives::T # K derivatives, with endpoints being ones @@ -89,7 +86,7 @@ struct RationalQuadraticSpline{T, N} <: Bijector{N} @assert length(widths) == length(heights) == length(derivatives) @assert all(derivatives .> 0) "derivatives need to be positive" - return new{T, 0}(widths, heights, derivatives) + return new{T}(widths, heights, derivatives) end function RationalQuadraticSpline( @@ -99,7 +96,7 @@ struct RationalQuadraticSpline{T, N} <: Bijector{N} ) where {T<:AbstractMatrix} @assert size(widths, 2) == size(heights, 2) == size(derivatives, 2) @assert all(derivatives .> 0) "derivatives need to be positive" - return new{T, 1}(widths, heights, derivatives) + return new{T}(widths, heights, derivatives) end end @@ -176,18 +173,15 @@ end # univariate -function (b::RationalQuadraticSpline{<:AbstractVector, 0})(x::Real) +function transform(b::RationalQuadraticSpline{<:AbstractVector}, x::Real) return rqs_univariate(b.widths, b.heights, b.derivatives, x) end -(b::RationalQuadraticSpline{<:AbstractVector, 0})(x::AbstractVector) = b.(x) # multivariate -function (b::RationalQuadraticSpline{<:AbstractMatrix, 1})(x::AbstractVector) +# TODO: Improve. +function transform(b::RationalQuadraticSpline{<:AbstractMatrix}, x::AbstractVector) return [rqs_univariate(b.widths[i, :], b.heights[i, :], b.derivatives[i, :], x[i]) for i = 1:length(x)] end -function (b::RationalQuadraticSpline{<:AbstractMatrix, 1})(x::AbstractMatrix) - return eachcolmaphcat(b, x) -end ########################## ### Inverse evaluation ### @@ -231,18 +225,15 @@ function rqs_univariate_inverse(widths, heights, derivatives, y::Real) return ξ * w + w_k end -function (ib::Inverse{<:RationalQuadraticSpline, 0})(y::Real) +function transform(ib::Inverse{<:RationalQuadraticSpline}, y::Real) return rqs_univariate_inverse(ib.orig.widths, ib.orig.heights, ib.orig.derivatives, y) end -(ib::Inverse{<:RationalQuadraticSpline, 0})(y::AbstractVector) = ib.(y) -function (ib::Inverse{<:RationalQuadraticSpline, 1})(y::AbstractVector) +# TODO: Improve. +function transform(ib::Inverse{<:RationalQuadraticSpline}, y::AbstractVector) b = ib.orig return [rqs_univariate_inverse(b.widths[i, :], b.heights[i, :], b.derivatives[i, :], y[i]) for i = 1:length(y)] end -function (ib::Inverse{<:RationalQuadraticSpline, 1})(y::AbstractMatrix) - return eachcolmaphcat(ib, y) -end ###################### ### `logabsdetjac` ### @@ -312,21 +303,17 @@ function rqs_logabsdetjac( return log(numerator) - 2 * log(denominator) end -function logabsdetjac(b::RationalQuadraticSpline{<:AbstractVector, 0}, x::Real) +function logabsdetjac(b::RationalQuadraticSpline{<:AbstractVector}, x::Real) return rqs_logabsdetjac(b.widths, b.heights, b.derivatives, x) end -function logabsdetjac(b::RationalQuadraticSpline{<:AbstractVector, 0}, x::AbstractVector) - return logabsdetjac.(b, x) -end -function logabsdetjac(b::RationalQuadraticSpline{<:AbstractMatrix, 1}, x::AbstractVector) + +# TODO: Improve. +function logabsdetjac(b::RationalQuadraticSpline{<:AbstractMatrix}, x::AbstractVector) return sum([ rqs_logabsdetjac(b.widths[i, :], b.heights[i, :], b.derivatives[i, :], x[i]) for i = 1:length(x) ]) end -function logabsdetjac(b::RationalQuadraticSpline{<:AbstractMatrix, 1}, x::AbstractMatrix) - return mapvcat(x -> logabsdetjac(b, x), eachcol(x)) -end ################# ### `forward` ### @@ -379,6 +366,10 @@ function rqs_forward( return (y, logjac) end -function with_logabsdet_jacobian(b::RationalQuadraticSpline{<:AbstractVector, 0}, x::Real) +function with_logabsdet_jacobian(b::RationalQuadraticSpline{<:AbstractVector}, x::Real) return rqs_forward(b.widths, b.heights, b.derivatives, x) end + +function with_logabsdet_jacobian(b::RationalQuadraticSpline{<:AbstractMatrix}, x::AbstractVector) + return transform(b, x), logabsdetjac(b, x) +end diff --git a/src/bijectors/scale.jl b/src/bijectors/scale.jl index e19b33f5..bff549e5 100644 --- a/src/bijectors/scale.jl +++ b/src/bijectors/scale.jl @@ -1,57 +1,36 @@ -struct Scale{T, N} <: Bijector{N} +struct Scale{T} <: Bijector a::T end -Base.:(==)(b1::Scale{<:Any, N}, b2::Scale{<:Any, N}) where {N} = b1.a == b2.a +Base.:(==)(b1::Scale, b2::Scale) = b1.a == b2.a -function Scale(a::Union{Real,AbstractArray}; dim::Val{D} = Val(ndims(a))) where D - return Scale{typeof(a), D}(a) -end +Functors.@functor Scale -# field is a numerical parameter -function Functors.functor(::Type{<:Scale{<:Any,N}}, x) where N - function reconstruct_scale(xs) - return Scale{typeof(xs.a),N}(xs.a) - end - return (a = x.a,), reconstruct_scale -end +Base.show(io::IO, b::Scale) = print(io, "Scale($(b.a))") -up1(b::Scale{T, N}) where {N, T} = Scale{T, N + 1}(a) +with_logabsdet_jacobian(b::Scale, x) = transform(b, x), logabsdetjac(b, x) -(b::Scale)(x) = b.a .* x -(b::Scale{<:AbstractMatrix, 1})(x::AbstractVecOrMat) = b.a * x -(b::Scale{<:AbstractMatrix, 2})(x::AbstractMatrix) = b.a * x -(ib::Inverse{<:Scale})(y) = Scale(inv(ib.orig.a))(y) -(ib::Inverse{<:Scale{<:AbstractVector}})(y) = Scale(inv.(ib.orig.a))(y) -(ib::Inverse{<:Scale{<:AbstractMatrix, 1}})(y::AbstractVecOrMat) = ib.orig.a \ y -(ib::Inverse{<:Scale{<:AbstractMatrix, 2}})(y::AbstractMatrix) = ib.orig.a \ y +transform(b::Scale, x) = b.a .* x +transform(b::Scale{<:AbstractMatrix}, x::AbstractVecOrMat) = b.a * x +transform(ib::Inverse{<:Scale}, y) = transform(Scale(inv(ib.orig.a)), y) +transform(ib::Inverse{<:Scale{<:AbstractVector}}, y) = transform(Scale(inv.(ib.orig.a)), y) +transform(ib::Inverse{<:Scale{<:AbstractMatrix}}, y::AbstractVecOrMat) = ib.orig.a \ y # We're going to implement custom adjoint for this -logabsdetjac(b::Scale{<:Any, N}, x) where {N} = _logabsdetjac_scale(b.a, x, Val(N)) +logabsdetjac(b::Scale, x::Real) = _logabsdetjac_scale(b.a, x, Val(0)) +function logabsdetjac(b::Scale, x::AbstractArray{<:Real, N}) where {N} + return _logabsdetjac_scale(b.a, x, Val(N)) +end +# Scalar: single input. _logabsdetjac_scale(a::Real, x::Real, ::Val{0}) = log(abs(a)) -_logabsdetjac_scale(a::Real, x::AbstractVector, ::Val{0}) = fill(log(abs(a)), length(x)) _logabsdetjac_scale(a::Real, x::AbstractVector, ::Val{1}) = log(abs(a)) * length(x) -_logabsdetjac_scale(a::Real, x::AbstractMatrix, ::Val{1}) = fill(log(abs(a)) * size(x, 1), size(x, 2)) _logabsdetjac_scale(a::Real, x::AbstractMatrix, ::Val{2}) = log(abs(a)) * length(x) -_logabsdetjac_scale(a::Real, x::AbstractArray{<:AbstractMatrix}, ::Val{2}) = map(x) do x - _logabsdetjac_scale(a, x, Val(2)) -end -_logabsdetjac_scale(a::AbstractVector, x::AbstractVector, ::Val{1}) = sum(x -> log(abs(x)), a) -_logabsdetjac_scale(a::AbstractVector, x::AbstractMatrix, ::Val{1}) = fill(sum(x -> log(abs(x)), a), size(x, 2)) -_logabsdetjac_scale(a::AbstractVector, x::AbstractMatrix, ::Val{2}) = sum(x -> log(abs(x)), a) -_logabsdetjac_scale(a::AbstractVector, x::AbstractArray{<:AbstractMatrix}, ::Val{2}) = map(x) do x - _logabsdetjac_scale(a, x, Val(2)) -end + +# Vector: single input. +_logabsdetjac_scale(a::AbstractVector, x::AbstractVector, ::Val{1}) = sum(log ∘ abs, a) +_logabsdetjac_scale(a::AbstractVector, x::AbstractMatrix, ::Val{2}) = sum(log ∘ abs, a) + +# Matrix: single input. _logabsdetjac_scale(a::AbstractMatrix, x::AbstractVector, ::Val{1}) = logabsdet(a)[1] -_logabsdetjac_scale(a::AbstractMatrix, x::AbstractMatrix{T}, ::Val{1}) where {T} = logabsdet(a)[1] * ones(T, size(x, 2)) _logabsdetjac_scale(a::AbstractMatrix, x::AbstractMatrix, ::Val{2}) = logabsdet(a)[1] -function _logabsdetjac_scale( - a::AbstractMatrix, - x::AbstractArray{<:AbstractMatrix}, - ::Val{2}, -) - map(x) do x - _logabsdetjac_scale(a, x, Val(2)) - end -end \ No newline at end of file diff --git a/src/bijectors/shift.jl b/src/bijectors/shift.jl index e4e9960c..908815a6 100644 --- a/src/bijectors/shift.jl +++ b/src/bijectors/shift.jl @@ -1,37 +1,24 @@ ################# # Shift & Scale # ################# -struct Shift{T, N} <: Bijector{N} +struct Shift{T} <: Bijector a::T end -Base.:(==)(b1::Shift{<:Any, N}, b2::Shift{<:Any, N}) where {N} = b1.a == b2.a +Base.:(==)(b1::Shift, b2::Shift) = b1.a == b2.a -function Shift(a::Union{Real,AbstractArray}; dim::Val{D} = Val(ndims(a))) where D - return Shift{typeof(a), D}(a) -end - -# field is a numerical parameter -function Functors.functor(::Type{<:Shift{<:Any,N}}, x) where N - function reconstruct_shift(xs) - return Shift{typeof(xs.a),N}(xs.a) - end - return (a = x.a,), reconstruct_shift -end +Functors.@functor Shift -up1(b::Shift{T, N}) where {T, N} = Shift{T, N + 1}(b.a) +inverse(b::Shift) = Shift(-b.a) -(b::Shift)(x) = b.a .+ x -(b::Shift{<:Any, 2})(x::AbstractArray{<:AbstractMatrix}) = map(b, x) - -inverse(b::Shift{T, N}) where {T, N} = Shift{T, N}(-b.a) +transform(b::Shift, x) = b.a .+ x # FIXME: implement custom adjoint to ensure we don't get tracking -logabsdetjac(b::Shift{T, N}, x) where {T, N} = _logabsdetjac_shift(b.a, x, Val(N)) +function logabsdetjac(b::Shift, x::Union{Real, AbstractArray{<:Real}}) + return _logabsdetjac_shift(b.a, x) +end + +_logabsdetjac_shift(a, x) = zero(eltype(x)) +_logabsdetjac_shift_array_batch(a, x) = zeros(eltype(x), size(x, ndims(x))) -_logabsdetjac_shift(a::Real, x::Real, ::Val{0}) = zero(eltype(x)) -_logabsdetjac_shift(a::Real, x::AbstractVector{T}, ::Val{0}) where {T<:Real} = zeros(T, length(x)) -_logabsdetjac_shift(a::T1, x::AbstractVector{T2}, ::Val{1}) where {T1<:Union{Real, AbstractVector}, T2<:Real} = zero(T2) -_logabsdetjac_shift(a::T1, x::AbstractMatrix{T2}, ::Val{1}) where {T1<:Union{Real, AbstractVector}, T2<:Real} = zeros(T2, size(x, 2)) -_logabsdetjac_shift(a::T1, x::AbstractMatrix{T2}, ::Val{2}) where {T1<:Union{Real, AbstractVector}, T2<:Real} = zero(T2) -_logabsdetjac_shift(a::T1, x::AbstractArray{<:AbstractMatrix{T2}}, ::Val{2}) where {T1<:Union{Real, AbstractVector}, T2<:Real} = zeros(T2, size(x)) +with_logabsdet_jacobian(b::Shift, x) = transform(b, x), logabsdetjac(b, x) diff --git a/src/bijectors/simplex.jl b/src/bijectors/simplex.jl index 10ba4db3..1fbc28d2 100644 --- a/src/bijectors/simplex.jl +++ b/src/bijectors/simplex.jl @@ -1,20 +1,18 @@ #################### # Simplex bijector # #################### -struct SimplexBijector{N, T} <: Bijector{N} end -SimplexBijector() = SimplexBijector{1}() -SimplexBijector{N}() where {N} = SimplexBijector{N,true}() +struct SimplexBijector{T} <: Bijector end +SimplexBijector() = SimplexBijector{true}() -# Special case `N = 1` -SimplexBijector{true}() = SimplexBijector{1,true}() -SimplexBijector{false}() = SimplexBijector{1,false}() +with_logabsdet_jacobian(b::SimplexBijector, x) = transform(b, x), logabsdetjac(b, x) -(b::SimplexBijector{1})(x::AbstractVector) = _simplex_bijector(x, b) -(b::SimplexBijector{1})(y::AbstractVector, x::AbstractVector) = _simplex_bijector!(y, x, b) -function _simplex_bijector(x::AbstractVector, b::SimplexBijector{1}) - return _simplex_bijector!(similar(x), x, b) -end -function _simplex_bijector!(y, x::AbstractVector, ::SimplexBijector{1, proj}) where {proj} +transform(b::SimplexBijector, x) = _simplex_bijector(x, b) +transform!(b::SimplexBijector, y, x) = _simplex_bijector!(y, x, b) + +_simplex_bijector(x::AbstractArray, b::SimplexBijector) = _simplex_bijector!(similar(x), x, b) + +# Vector implementation. +function _simplex_bijector!(y, x::AbstractVector, ::SimplexBijector{proj}) where {proj} K = length(x) @assert K > 1 "x needs to be of length greater than 1" T = eltype(x) @@ -39,24 +37,8 @@ function _simplex_bijector!(y, x::AbstractVector, ::SimplexBijector{1, proj}) wh return y end -# Vectorised implementation of the above. -function (b::SimplexBijector{1})(X::AbstractMatrix) - _simplex_bijector(X, b) -end -function (b::SimplexBijector{1})( - Y::AbstractMatrix, - X::AbstractMatrix, -) - _simplex_bijector!(Y, X, b) -end -function (b::SimplexBijector{2, proj})(X::AbstractMatrix) where {proj} - SimplexBijector{1, proj}()(X) -end -(b::SimplexBijector{2})(X::AbstractArray{<:AbstractMatrix}) = map(b, X) -function _simplex_bijector(X::AbstractMatrix, b::SimplexBijector{1}) - _simplex_bijector!(similar(X), X, b) -end -function _simplex_bijector!(Y, X::AbstractMatrix, ::SimplexBijector{1, proj}) where {proj} +# Matrix implementation. +function _simplex_bijector!(Y, X::AbstractMatrix, ::SimplexBijector{proj}) where {proj} K, N = size(X, 1), size(X, 2) @assert K > 1 "x needs to be of length greater than 1" T = eltype(X) @@ -81,19 +63,19 @@ function _simplex_bijector!(Y, X::AbstractMatrix, ::SimplexBijector{1, proj}) wh return Y end -function (ib::Inverse{<:SimplexBijector{1, proj}})(y::AbstractVector{T}) where {T, proj} - _simplex_inv_bijector(y, ib.orig) -end -function (ib::Inverse{<:SimplexBijector{1}})( - x::AbstractVector{T}, - y::AbstractVector{T}, +# Inverse. +transform(ib::Inverse{<:SimplexBijector}, y::AbstractArray) = _simplex_inv_bijector(y, ib.orig) +function transform!( + ib::Inverse{<:SimplexBijector}, + x::AbstractArray{T}, + y::AbstractArray{T}, ) where {T} - _simplex_inv_bijector!(x, y, ib.orig) + return _simplex_inv_bijector!(x, y, ib.orig) end -function _simplex_inv_bijector(y::AbstractVector, b::SimplexBijector{1}) - return _simplex_inv_bijector!(similar(y), y, b) -end -function _simplex_inv_bijector!(x, y::AbstractVector, b::SimplexBijector{1, proj}) where {proj} + +_simplex_inv_bijector(y, b) = _simplex_inv_bijector!(similar(y), y, b) + +function _simplex_inv_bijector!(x, y::AbstractVector, b::SimplexBijector{proj}) where {proj} K = length(y) @assert K > 1 "x needs to be of length greater than 1" T = eltype(y) @@ -116,27 +98,7 @@ function _simplex_inv_bijector!(x, y::AbstractVector, b::SimplexBijector{1, proj return x end -# Vectorised implementation of the above. -function (ib::Inverse{<:SimplexBijector{1}})(Y::AbstractMatrix) - _simplex_inv_bijector(Y, ib.orig) -end -function (ib::Inverse{<:SimplexBijector{1}})( - X::AbstractMatrix{T}, - Y::AbstractMatrix{T}, -) where {T <: Real} - _simplex_inv_bijector!(X, Y, ib.orig) -end -function (ib::Inverse{<:SimplexBijector{2, proj}})(Y::AbstractMatrix) where {proj} - inverse(SimplexBijector{1, proj}())(Y) -end -function (ib::Inverse{<:SimplexBijector{2, proj}})(X::AbstractMatrix, Y::AbstractMatrix) where {proj} - inverse(SimplexBijector{1, proj}())(X, Y) -end -(ib::Inverse{<:SimplexBijector{2}})(Y::AbstractArray{<:AbstractMatrix}) = map(ib, Y) -function _simplex_inv_bijector(Y::AbstractMatrix, b::SimplexBijector{1}) - _simplex_inv_bijector!(similar(Y), Y, b) -end -function _simplex_inv_bijector!(X, Y::AbstractMatrix, b::SimplexBijector{1, proj}) where {proj} +function _simplex_inv_bijector!(X, Y::AbstractMatrix, b::SimplexBijector{proj}) where {proj} K, N = size(Y, 1), size(Y, 2) @assert K > 1 "x needs to be of length greater than 1" T = eltype(Y) @@ -160,7 +122,7 @@ function _simplex_inv_bijector!(X, Y::AbstractMatrix, b::SimplexBijector{1, proj return X end -function logabsdetjac(b::SimplexBijector{1}, x::AbstractVector{T}) where {T} +function logabsdetjac(b::SimplexBijector, x::AbstractVector{T}) where {T} ϵ = _eps(T) lp = zero(T) @@ -211,30 +173,7 @@ function simplex_logabsdetjac_gradient(x::AbstractVector) end return g end -function logabsdetjac(b::SimplexBijector{1}, x::AbstractMatrix{T}) where {T} - ϵ = _eps(T) - nlp = similar(x, T, size(x, 2)) - nlp .= zero(T) - K = size(x, 1) - for col in 1:size(x, 2) - sum_tmp = zero(eltype(x)) - z = x[1,col] - nlp[col] -= log(max(z, ϵ)) + log(max(one(T) - z, ϵ)) - for k in 2:(K - 1) - sum_tmp += x[k-1,col] - z = x[k,col] / max(one(T) - sum_tmp, ϵ) - nlp[col] -= log(max(z, ϵ)) + log(max(one(T) - z, ϵ)) + log(max(one(T) - sum_tmp, ϵ)) - end - end - return nlp -end -function logabsdetjac(b::SimplexBijector{2, proj}, x::AbstractMatrix) where {proj} - return sum(logabsdetjac(SimplexBijector{1, proj}(), x)) -end -function logabsdetjac(b::SimplexBijector{2}, x::AbstractArray{<:AbstractMatrix}) - return map(x -> logabsdetjac(b, x), x) -end function simplex_logabsdetjac_gradient(x::AbstractMatrix) T = eltype(x) ϵ = _eps(T) @@ -303,7 +242,7 @@ function simplex_link_jacobian( end return UpperTriangular(dydxt)' end -function jacobian(b::SimplexBijector{1, proj}, x::AbstractVector{T}) where {proj, T} +function jacobian(b::SimplexBijector{proj}, x::AbstractVector{T}) where {proj, T} return simplex_link_jacobian(x, Val(proj)) end @@ -425,7 +364,7 @@ function simplex_invlink_jacobian( return LowerTriangular(dxdy) end # jacobian -function jacobian(ib::Inverse{<:SimplexBijector{1, proj}}, y::AbstractVector{T}) where {proj, T} +function jacobian(ib::Inverse{<:SimplexBijector{proj}}, y::AbstractVector{T}) where {proj, T} return simplex_invlink_jacobian(y, Val(proj)) end diff --git a/src/bijectors/stacked.jl b/src/bijectors/stacked.jl index 83605188..eec82eab 100644 --- a/src/bijectors/stacked.jl +++ b/src/bijectors/stacked.jl @@ -1,7 +1,7 @@ """ Stacked(bs) Stacked(bs, ranges) - stack(bs::Bijector{0}...) # where `0` means 0-dim `Bijector` + stack(bs::Bijector...) A `Bijector` which stacks bijectors together which can then be applied to a vector where `bs[i]::Bijector` is applied to `x[ranges[i]]::UnitRange{Int}`. @@ -16,27 +16,24 @@ where `bs[i]::Bijector` is applied to `x[ranges[i]]::UnitRange{Int}`. # Examples ``` b1 = Logit(0.0, 1.0) -b2 = Identity{0}() +b2 = Identity() b = stack(b1, b2) b([0.0, 1.0]) == [b1(0.0), 1.0] # => true ``` """ -struct Stacked{Bs, Rs} <: Bijector{1} +struct Stacked{Bs, Rs} <: Transform bs::Bs ranges::Rs end Stacked(bs::Tuple) = Stacked(bs, ntuple(i -> i:i, length(bs))) Stacked(bs::AbstractArray) = Stacked(bs, [i:i for i in 1:length(bs)]) -# define nested numerical parameters -# TODO: replace with `Functors.@functor Stacked (bs,)` when -# https://github.com/FluxML/Functors.jl/pull/7 is merged -function Functors.functor(::Type{<:Stacked}, x) - function reconstruct_stacked(xs) - return Stacked(xs.bs, x.ranges) - end - return (bs = x.bs,), reconstruct_stacked -end +# Avoid mixing tuples and arrays. +Stacked(bs::Tuple, ranges::AbstractArray) = Stacked(collect(bs), ranges) + +Functors.@functor Stacked (bs,) + +Base.show(io::IO, b::Stacked) = print(io, "Stacked($(b.bs), $(b.ranges))") function Base.:(==)(b1::Stacked, b2::Stacked) bs1, bs2 = b1.bs, b2.bs @@ -48,7 +45,9 @@ end isclosedform(b::Stacked) = all(isclosedform, b.bs) -stack(bs::Bijector{0}...) = Stacked(bs) +isinvertible(b::Stacked) = all(isinvertible, b.bs) + +stack(bs...) = Stacked(bs) # For some reason `inverse.(sb.bs)` was unstable... This works though. inverse(sb::Stacked) = Stacked(map(inverse, sb.bs), sb.ranges) @@ -62,24 +61,24 @@ inverse(sb::Stacked) = Stacked(map(inverse, sb.bs), sb.ranges) :(Stacked(($(exprs...), ), sb.ranges)) end -@generated function _transform(x, rs::NTuple{N, UnitRange{Int}}, bs::Bijector...) where N +@generated function _transform(x, rs::NTuple{N, UnitRange{Int}}, bs...) where N exprs = [] for i = 1:N push!(exprs, :(bs[$i](x[rs[$i]]))) end return :(vcat($(exprs...))) end -function _transform(x, rs::NTuple{1, UnitRange{Int}}, b::Bijector) +function _transform(x, rs::NTuple{1, UnitRange{Int}}, b) @assert rs[1] == 1:length(x) return b(x) end -function (sb::Stacked{<:Tuple,<:Tuple})(x::AbstractVector{<:Real}) +function transform(sb::Stacked{<:Tuple,<:Tuple}, x::AbstractVector{<:Real}) y = _transform(x, sb.ranges, sb.bs...) @assert size(y) == size(x) "x is size $(size(x)) but y is $(size(y))" return y end # The Stacked{<:AbstractArray} version is not TrackedArray friendly -function (sb::Stacked)(x::AbstractVector{<:Real}) +function transform(sb::Stacked{<:AbstractArray}, x::AbstractVector{<:Real}) N = length(sb.bs) N == 1 && return sb.bs[1](x[sb.ranges[1]]) @@ -90,8 +89,6 @@ function (sb::Stacked)(x::AbstractVector{<:Real}) return y end -(sb::Stacked)(x::AbstractMatrix{<:Real}) = eachcolmaphcat(sb, x) - function logabsdetjac( b::Stacked, x::AbstractVector{<:Real} @@ -123,12 +120,6 @@ function logabsdetjac( end end -function logabsdetjac(b::Stacked, x::AbstractMatrix{<:Real}) - return map(eachcol(x)) do c - logabsdetjac(b, c) - end -end - # Generates something similar to: # # quote diff --git a/src/bijectors/truncated.jl b/src/bijectors/truncated.jl index d8d082e1..52c09f4d 100644 --- a/src/bijectors/truncated.jl +++ b/src/bijectors/truncated.jl @@ -1,60 +1,22 @@ ####################################################### # Constrained to unconstrained distribution bijectors # ####################################################### -struct TruncatedBijector{N, T1, T2} <: Bijector{N} +struct TruncatedBijector{T1, T2} <: Bijector lb::T1 ub::T2 end -TruncatedBijector(lb, ub) = TruncatedBijector{0}(lb, ub) -function TruncatedBijector{N}(lb::T1, ub::T2) where {N, T1, T2} - return TruncatedBijector{N, T1, T2}(lb, ub) -end -# field are numerical parameters -function Functors.functor(::Type{<:TruncatedBijector{N}}, x) where N - function reconstruct_truncatedbijector(xs) - return TruncatedBijector{N}(xs.lb, xs.ub) - end - return (lb = x.lb, ub = x.ub,), reconstruct_truncatedbijector -end - -up1(b::TruncatedBijector{N}) where {N} = TruncatedBijector{N + 1}(b.lb, b.ub) +Functors.@functor TruncatedBijector function Base.:(==)(b1::TruncatedBijector, b2::TruncatedBijector) return b1.lb == b2.lb && b1.ub == b2.ub end -function (b::TruncatedBijector{0})(x::Real) +function transform(b::TruncatedBijector, x) a, b = b.lb, b.ub - truncated_link(_clamp(x, a, b), a, b) + return truncated_link.(_clamp.(x, a, b), a, b) end -function (b::TruncatedBijector{0})(x::AbstractArray{<:Real}) - a, b = b.lb, b.ub - truncated_link.(_clamp.(x, a, b), a, b) -end -function (b::TruncatedBijector{1})(x::AbstractVecOrMat{<:Real}) - a, b = b.lb, b.ub - if a isa AbstractVector - @assert b isa AbstractVector - maporbroadcast(x, a, b) do x, a, b - truncated_link(_clamp(x, a, b), a, b) - end - else - truncated_link.(_clamp.(x, a, b), a, b) - end -end -function (b::TruncatedBijector{2})(x::AbstractMatrix{<:Real}) - a, b = b.lb, b.ub - if a isa AbstractMatrix - @assert b isa AbstractMatrix - maporbroadcast(x, a, b) do x, a, b - truncated_link(_clamp(x, a, b), a, b) - end - else - truncated_link.(_clamp.(x, a, b), a, b) - end -end -(b::TruncatedBijector{2})(x::AbstractArray{<:AbstractMatrix{<:Real}}) = map(b, x) + function truncated_link(x::Real, a, b) lowerbounded, upperbounded = isfinite(a), isfinite(b) if lowerbounded && upperbounded @@ -68,37 +30,11 @@ function truncated_link(x::Real, a, b) end end -function (ib::Inverse{<:TruncatedBijector{0}})(y::Real) - a, b = ib.orig.lb, ib.orig.ub - _clamp(truncated_invlink(y, a, b), a, b) -end -function (ib::Inverse{<:TruncatedBijector{0}})(y::AbstractArray{<:Real}) - a, b = ib.orig.lb, ib.orig.ub - _clamp.(truncated_invlink.(y, a, b), a, b) -end -function (ib::Inverse{<:TruncatedBijector{1}})(y::AbstractVecOrMat{<:Real}) +function transform(ib::Inverse{<:TruncatedBijector}, y) a, b = ib.orig.lb, ib.orig.ub - if a isa AbstractVector - @assert b isa AbstractVector - maporbroadcast(y, a, b) do y, a, b - _clamp(truncated_invlink(y, a, b), a, b) - end - else - _clamp.(truncated_invlink.(y, a, b), a, b) - end -end -function (ib::Inverse{<:TruncatedBijector{2}})(y::AbstractMatrix{<:Real}) - a, b = ib.orig.lb, ib.orig.ub - if a isa AbstractMatrix - @assert b isa AbstractMatrix - return maporbroadcast(y, a, b) do y, a, b - _clamp(truncated_invlink(y, a, b), a, b) - end - else - return _clamp.(truncated_invlink.(y, a, b), a, b) - end + return _clamp.(truncated_invlink.(y, a, b), a, b) end -(ib::Inverse{<:TruncatedBijector{2}})(y::AbstractArray{<:AbstractMatrix{<:Real}}) = map(ib, y) + function truncated_invlink(y, a, b) lowerbounded, upperbounded = isfinite(a), isfinite(b) if lowerbounded && upperbounded @@ -112,31 +48,11 @@ function truncated_invlink(y, a, b) end end -function logabsdetjac(b::TruncatedBijector{0}, x::Real) +function logabsdetjac(b::TruncatedBijector, x) a, b = b.lb, b.ub - truncated_logabsdetjac(_clamp(x, a, b), a, b) -end -function logabsdetjac(b::TruncatedBijector{0}, x::AbstractArray{<:Real}) - a, b = b.lb, b.ub - truncated_logabsdetjac.(_clamp.(x, a, b), a, b) -end -function logabsdetjac(b::TruncatedBijector{1}, x::AbstractVector{<:Real}) - a, b = b.lb, b.ub - sum(truncated_logabsdetjac.(_clamp.(x, a, b), a, b)) -end -function logabsdetjac(b::TruncatedBijector{1}, x::AbstractMatrix{<:Real}) - a, b = b.lb, b.ub - vec(sum(truncated_logabsdetjac.(_clamp.(x, a, b), a, b), dims = 1)) -end -function logabsdetjac(b::TruncatedBijector{2}, x::AbstractMatrix{<:Real}) - a, b = b.lb, b.ub - sum(truncated_logabsdetjac.(_clamp.(x, a, b), a, b)) -end -function logabsdetjac(b::TruncatedBijector{2}, x::AbstractArray{<:AbstractMatrix{<:Real}}) - map(x) do x - logabsdetjac(b, x) - end + return sum(truncated_logabsdetjac.(_clamp.(x, a, b), a, b)) end + function truncated_logabsdetjac(x, a, b) lowerbounded, upperbounded = isfinite(a), isfinite(b) if lowerbounded && upperbounded @@ -149,3 +65,5 @@ function truncated_logabsdetjac(x, a, b) return zero(x) end end + +with_logabsdet_jacobian(b::TruncatedBijector, x) = transform(b, x), logabsdetjac(b, x) diff --git a/src/compat/distributionsad.jl b/src/compat/distributionsad.jl index 5402ce00..8abc5b36 100644 --- a/src/compat/distributionsad.jl +++ b/src/compat/distributionsad.jl @@ -9,28 +9,28 @@ using Distributions: AbstractMvLogNormal bijector(::TuringDirichlet) = SimplexBijector() bijector(::TuringWishart) = PDBijector() bijector(::TuringInverseWishart) = PDBijector() -bijector(::TuringScalMvNormal) = Identity{1}() -bijector(::TuringDiagMvNormal) = Identity{1}() -bijector(::TuringDenseMvNormal) = Identity{1}() +bijector(::TuringScalMvNormal) = Identity() +bijector(::TuringDiagMvNormal) = Identity() +bijector(::TuringDenseMvNormal) = Identity() -bijector(d::FillVectorOfUnivariate{Continuous}) = up1(bijector(d.v.value)) -bijector(d::FillMatrixOfUnivariate{Continuous}) = up1(up1(bijector(d.dists.value))) -bijector(d::MatrixOfUnivariate{Discrete}) = Identity{2}() -bijector(d::MatrixOfUnivariate{Continuous}) = TruncatedBijector{2}(_minmax(d.dists)...) -bijector(d::VectorOfMultivariate{Discrete}) = Identity{2}() +bijector(d::FillVectorOfUnivariate{Continuous}) = bijector(d.v.value) +bijector(d::FillMatrixOfUnivariate{Continuous}) = up1(bijector(d.dists.value)) +bijector(d::MatrixOfUnivariate{Discrete}) = Identity() +bijector(d::MatrixOfUnivariate{Continuous}) = TruncatedBijector(_minmax(d.dists)...) +bijector(d::VectorOfMultivariate{Discrete}) = Identity() for T in (:VectorOfMultivariate, :FillVectorOfMultivariate) @eval begin - bijector(d::$T{Continuous, <:MvNormal}) = Identity{2}() - bijector(d::$T{Continuous, <:TuringScalMvNormal}) = Identity{2}() - bijector(d::$T{Continuous, <:TuringDiagMvNormal}) = Identity{2}() - bijector(d::$T{Continuous, <:TuringDenseMvNormal}) = Identity{2}() - bijector(d::$T{Continuous, <:MvNormalCanon}) = Identity{2}() - bijector(d::$T{Continuous, <:AbstractMvLogNormal}) = Log{2}() - bijector(d::$T{Continuous, <:SimplexDistribution}) = SimplexBijector{2}() - bijector(d::$T{Continuous, <:TuringDirichlet}) = SimplexBijector{2}() + bijector(d::$T{Continuous, <:MvNormal}) = Identity() + bijector(d::$T{Continuous, <:TuringScalMvNormal}) = Identity() + bijector(d::$T{Continuous, <:TuringDiagMvNormal}) = Identity() + bijector(d::$T{Continuous, <:TuringDenseMvNormal}) = Identity() + bijector(d::$T{Continuous, <:MvNormalCanon}) = Identity() + bijector(d::$T{Continuous, <:AbstractMvLogNormal}) = Log() + bijector(d::$T{Continuous, <:SimplexDistribution}) = SimplexBijector() + bijector(d::$T{Continuous, <:TuringDirichlet}) = SimplexBijector() end end -bijector(d::FillVectorOfMultivariate{Continuous}) = up1(bijector(d.dists.value)) +bijector(d::FillVectorOfMultivariate{Continuous}) = bijector(d.dists.value) isdirichlet(::VectorOfMultivariate{Continuous, <:Dirichlet}) = true isdirichlet(::VectorOfMultivariate{Continuous, <:TuringDirichlet}) = true diff --git a/src/compat/reversediff.jl b/src/compat/reversediff.jl index 116d8531..86130418 100644 --- a/src/compat/reversediff.jl +++ b/src/compat/reversediff.jl @@ -4,7 +4,7 @@ using ..ReverseDiff: ReverseDiff, @grad, value, track, TrackedReal, TrackedVecto TrackedMatrix using Requires, LinearAlgebra -using ..Bijectors: Log, SimplexBijector, maphcat, simplex_link_jacobian, +using ..Bijectors: Elementwise, SimplexBijector, maphcat, simplex_link_jacobian, simplex_invlink_jacobian, simplex_logabsdetjac_gradient, ADBijector, ReverseDiffAD, Inverse import ..Bijectors: _eps, logabsdetjac, _logabsdetjac_scale, _simplex_bijector, @@ -49,13 +49,10 @@ function Base.maximum(d::LocationScale{<:TrackedReal}) end end -logabsdetjac(b::Log{1}, x::Union{TrackedVector, TrackedMatrix}) = track(logabsdetjac, b, x) -@grad function logabsdetjac(b::Log{1}, x::AbstractVector) +logabsdetjac(b::Elementwise{typeof(log)}, x::Union{TrackedVector, TrackedMatrix}) = track(logabsdetjac, b, x) +@grad function logabsdetjac(b::Elementwise{typeof(log)}, x::AbstractVector) return -sum(log, value(x)), Δ -> (nothing, -Δ ./ value(x)) end -@grad function logabsdetjac(b::Log{1}, x::AbstractMatrix) - return -vec(sum(log, value(x); dims = 1)), Δ -> (nothing, .- Δ' ./ value(x)) -end function _logabsdetjac_scale(a::TrackedReal, x::Real, ::Val{0}) return track(_logabsdetjac_scale, a, value(x), Val(0)) end @@ -100,30 +97,22 @@ end Jᵀ = repeat(inv.(da), 1, size(x, 2)) return _logabsdetjac_scale(da, value(x), Val(1)), Δ -> (Jᵀ * Δ, nothing, nothing) end -function _simplex_bijector(X::Union{TrackedVector, TrackedMatrix}, b::SimplexBijector{1}) +function _simplex_bijector(X::Union{TrackedVector, TrackedMatrix}, b::SimplexBijector) return track(_simplex_bijector, X, b) end -@grad function _simplex_bijector(Y::AbstractVector, b::SimplexBijector{1}) +@grad function _simplex_bijector(Y::AbstractVector, b::SimplexBijector) Yd = value(Y) return _simplex_bijector(Yd, b), Δ -> (simplex_link_jacobian(Yd)' * Δ, nothing) end -@grad function _simplex_bijector(Y::AbstractMatrix, b::SimplexBijector{1}) - Yd = value(Y) - return _simplex_bijector(Yd, b), Δ -> begin - maphcat(eachcol(Yd), eachcol(Δ)) do c1, c2 - simplex_link_jacobian(c1)' * c2 - end, nothing - end -end -function _simplex_inv_bijector(X::Union{TrackedVector, TrackedMatrix}, b::SimplexBijector{1}) +function _simplex_inv_bijector(X::Union{TrackedVector, TrackedMatrix}, b::SimplexBijector) return track(_simplex_inv_bijector, X, b) end -@grad function _simplex_inv_bijector(Y::AbstractVector, b::SimplexBijector{1}) +@grad function _simplex_inv_bijector(Y::AbstractVector, b::SimplexBijector) Yd = value(Y) return _simplex_inv_bijector(Yd, b), Δ -> (simplex_invlink_jacobian(Yd)' * Δ, nothing) end -@grad function _simplex_inv_bijector(Y::AbstractMatrix, b::SimplexBijector{1}) +@grad function _simplex_inv_bijector(Y::AbstractMatrix, b::SimplexBijector) Yd = value(Y) return _simplex_inv_bijector(Yd, b), Δ -> begin maphcat(eachcol(Yd), eachcol(Δ)) do c1, c2 @@ -154,21 +143,13 @@ replace_diag(::typeof(exp), X::TrackedMatrix) = track(replace_diag, exp, X) end end -logabsdetjac(b::SimplexBijector{1}, x::Union{TrackedVector, TrackedMatrix}) = track(logabsdetjac, b, x) -@grad function logabsdetjac(b::SimplexBijector{1}, x::AbstractVector) +logabsdetjac(b::SimplexBijector, x::Union{TrackedVector, TrackedMatrix}) = track(logabsdetjac, b, x) +@grad function logabsdetjac(b::SimplexBijector, x::AbstractVector) xd = value(x) return logabsdetjac(b, xd), Δ -> begin (nothing, simplex_logabsdetjac_gradient(xd) * Δ) end end -@grad function logabsdetjac(b::SimplexBijector{1}, x::AbstractMatrix) - xd = value(x) - return logabsdetjac(b, xd), Δ -> begin - (nothing, maphcat(eachcol(xd), Δ) do c, g - simplex_logabsdetjac_gradient(c) * g - end) - end -end getpd(X::TrackedMatrix) = track(getpd, X) @grad function getpd(X::AbstractMatrix) diff --git a/src/compat/tracker.jl b/src/compat/tracker.jl index eed09521..53727813 100644 --- a/src/compat/tracker.jl +++ b/src/compat/tracker.jl @@ -12,8 +12,8 @@ using ..Tracker: Tracker, param import ..Bijectors -using ..Bijectors: Log, SimplexBijector, ADBijector, - TrackerAD, Inverse, Stacked, Exp +using ..Bijectors: Elementwise, SimplexBijector, ADBijector, + TrackerAD, Inverse, Stacked import ChainRulesCore import LogExpFunctions @@ -91,13 +91,10 @@ end # Log bijector -@grad function Bijectors.logabsdetjac(b::Log{1}, x::AbstractVector) +@grad function Bijectors.logabsdetjac(b::Elementwise{typeof(log)}, x::AbstractVector) return -sum(log, data(x)), Δ -> (nothing, -Δ ./ data(x)) end -@grad function Bijectors.logabsdetjac(b::Log{1}, x::AbstractMatrix) - return -vec(sum(log, data(x); dims = 1)), Δ -> (nothing, .- Δ' ./ data(x)) -end -@grad function Bijectors.logabsdetjac(b::Log{2}, x::AbstractMatrix) +@grad function Bijectors.logabsdetjac(b::Elementwise{typeof(log)}, x::AbstractMatrix) return -sum(log, data(x)), Δ -> (nothing, -Δ ./ data(x)) end @@ -154,49 +151,23 @@ end # @grad function _logabsdetjac_scale(a::TrackedMatrix, x::AbstractVector, ::Val{1}) # throw # end -# implementations for Stacked bijector -function Bijectors.logabsdetjac(b::Stacked, x::TrackedMatrix{<:Real}) - return map(eachcol(x)) do c - Bijectors.logabsdetjac(b, c) - end -end -# TODO: implement custom adjoint since we can exploit block-diagonal nature of `Stacked` -function (sb::Stacked)(x::TrackedMatrix{<:Real}) - return Bijectors.eachcolmaphcat(sb, x) -end + # Simplex adjoints -function Bijectors._simplex_bijector(X::TrackedVecOrMat, b::SimplexBijector{1}) +function Bijectors._simplex_bijector(X::TrackedVecOrMat, b::SimplexBijector) return track(Bijectors._simplex_bijector, X, b) end -function Bijectors._simplex_inv_bijector(Y::TrackedVecOrMat, b::SimplexBijector{1}) +function Bijectors._simplex_inv_bijector(Y::TrackedVecOrMat, b::SimplexBijector) return track(Bijectors._simplex_inv_bijector, Y, b) end -@grad function Bijectors._simplex_bijector(X::AbstractVector, b::SimplexBijector{1}) +@grad function Bijectors._simplex_bijector(X::AbstractVector, b::SimplexBijector) Xd = data(X) return Bijectors._simplex_bijector(Xd, b), Δ -> (Bijectors.simplex_link_jacobian(Xd)' * Δ, nothing) end -@grad function Bijectors._simplex_inv_bijector(Y::AbstractVector, b::SimplexBijector{1}) +@grad function Bijectors._simplex_inv_bijector(Y::AbstractVector, b::SimplexBijector) Yd = data(Y) return Bijectors._simplex_inv_bijector(Yd, b), Δ -> (Bijectors.simplex_invlink_jacobian(Yd)' * Δ, nothing) end -@grad function Bijectors._simplex_bijector(X::AbstractMatrix, b::SimplexBijector{1}) - Xd = data(X) - return Bijectors._simplex_bijector(Xd, b), Δ -> begin - Bijectors.maphcat(eachcol(Xd), eachcol(Δ)) do c1, c2 - Bijectors.simplex_link_jacobian(c1)' * c2 - end, nothing - end -end -@grad function Bijectors._simplex_inv_bijector(Y::AbstractMatrix, b::SimplexBijector{1}) - Yd = data(Y) - return Bijectors._simplex_inv_bijector(Yd, b), Δ -> begin - Bijectors.maphcat(eachcol(Yd), eachcol(Δ)) do c1, c2 - Bijectors.simplex_invlink_jacobian(c1)' * c2 - end, nothing - end -end - Bijectors.replace_diag(::typeof(log), X::TrackedMatrix) = track(Bijectors.replace_diag, log, X) @grad function Bijectors.replace_diag(::typeof(log), X) Xd = data(X) @@ -219,21 +190,13 @@ Bijectors.replace_diag(::typeof(exp), X::TrackedMatrix) = track(Bijectors.replac end end -Bijectors.logabsdetjac(b::SimplexBijector{1}, x::TrackedVecOrMat) = track(Bijectors.logabsdetjac, b, x) -@grad function Bijectors.logabsdetjac(b::SimplexBijector{1}, x::AbstractVector) +Bijectors.logabsdetjac(b::SimplexBijector, x::TrackedVecOrMat) = track(Bijectors.logabsdetjac, b, x) +@grad function Bijectors.logabsdetjac(b::SimplexBijector, x::AbstractVector) xd = data(x) return Bijectors.logabsdetjac(b, xd), Δ -> begin (nothing, Bijectors.simplex_logabsdetjac_gradient(xd) * Δ) end end -@grad function Bijectors.logabsdetjac(b::SimplexBijector{1}, x::AbstractMatrix) - xd = data(x) - return Bijectors.logabsdetjac(b, xd), Δ -> begin - (nothing, Bijectors.maphcat(eachcol(xd), Δ) do c, g - Bijectors.simplex_logabsdetjac_gradient(c) * g - end) - end -end for header in [ (:(α_::TrackedReal), :β, :z_0, :(z::AbstractVector)), @@ -327,18 +290,12 @@ function vectorof(::Type{TrackedReal{T}}) where {T<:Real} return TrackedArray{T,1,Vector{T}} end -(b::Exp{0})(x::TrackedVector) = exp.(x)::vectorof(float(eltype(x))) -(b::Exp{1})(x::TrackedVector) = exp.(x)::vectorof(float(eltype(x))) -(b::Exp{1})(x::TrackedMatrix) = exp.(x)::matrixof(float(eltype(x))) -(b::Exp{2})(x::TrackedMatrix) = exp.(x)::matrixof(float(eltype(x))) - -(b::Log{0})(x::TrackedVector) = log.(x)::vectorof(float(eltype(x))) -(b::Log{1})(x::TrackedVector) = log.(x)::vectorof(float(eltype(x))) -(b::Log{1})(x::TrackedMatrix) = log.(x)::matrixof(float(eltype(x))) -(b::Log{2})(x::TrackedMatrix) = log.(x)::matrixof(float(eltype(x))) +(b::Elementwise{typeof(exp)})(x::TrackedVector) = exp.(x)::vectorof(float(eltype(x))) +(b::Elementwise{typeof(exp)})(x::TrackedVector) = exp.(x)::vectorof(float(eltype(x))) +(b::Elementwise{typeof(exp)})(x::TrackedMatrix) = exp.(x)::matrixof(float(eltype(x))) -Bijectors.logabsdetjac(b::Log{0}, x::TrackedVector) = .-log.(x)::vectorof(float(eltype(x))) -Bijectors.logabsdetjac(b::Log{1}, x::TrackedMatrix) = - vec(sum(log.(x); dims = 1)) +(b::Elementwise{typeof(log)})(x::TrackedVector) = log.(x)::vectorof(float(eltype(x))) +(b::Elementwise{typeof(log)})(x::TrackedMatrix) = log.(x)::matrixof(float(eltype(x))) Bijectors.getpd(X::TrackedMatrix) = track(Bijectors.getpd, X) @grad function Bijectors.getpd(X::AbstractMatrix) diff --git a/src/compat/zygote.jl b/src/compat/zygote.jl index 362915a7..0b19467b 100644 --- a/src/compat/zygote.jl +++ b/src/compat/zygote.jl @@ -27,12 +27,9 @@ end return pullback(g, f, x1, x2) end -@adjoint function logabsdetjac(b::Log{1}, x::AbstractVector) +@adjoint function logabsdetjac(b::Elementwise{typeof(log)}, x::AbstractVector) return -sum(log, x), Δ -> (nothing, -Δ ./ x) end -@adjoint function logabsdetjac(b::Log{1}, x::AbstractMatrix) - return -vec(sum(log, x; dims = 1)), Δ -> (nothing, .- Δ' ./ x) -end # AD implementations function jacobian( @@ -121,21 +118,21 @@ end # Simplex adjoints -@adjoint function _simplex_bijector(X::AbstractVector, b::SimplexBijector{1}) +@adjoint function _simplex_bijector(X::AbstractVector, b::SimplexBijector) return _simplex_bijector(X, b), Δ -> (simplex_link_jacobian(X)' * Δ, nothing) end -@adjoint function _simplex_inv_bijector(Y::AbstractVector, b::SimplexBijector{1}) +@adjoint function _simplex_inv_bijector(Y::AbstractVector, b::SimplexBijector) return _simplex_inv_bijector(Y, b), Δ -> (simplex_invlink_jacobian(Y)' * Δ, nothing) end -@adjoint function _simplex_bijector(X::AbstractMatrix, b::SimplexBijector{1}) +@adjoint function _simplex_bijector(X::AbstractMatrix, b::SimplexBijector) return _simplex_bijector(X, b), Δ -> begin maphcat(eachcol(X), eachcol(Δ)) do c1, c2 simplex_link_jacobian(c1)' * c2 end, nothing end end -@adjoint function _simplex_inv_bijector(Y::AbstractMatrix, b::SimplexBijector{1}) +@adjoint function _simplex_inv_bijector(Y::AbstractMatrix, b::SimplexBijector) return _simplex_inv_bijector(Y, b), Δ -> begin maphcat(eachcol(Y), eachcol(Δ)) do c1, c2 simplex_invlink_jacobian(c1)' * c2 @@ -143,18 +140,11 @@ end end end -@adjoint function logabsdetjac(b::SimplexBijector{1}, x::AbstractVector) +@adjoint function logabsdetjac(b::SimplexBijector, x::AbstractVector) return logabsdetjac(b, x), Δ -> begin (nothing, simplex_logabsdetjac_gradient(x) * Δ) end end -@adjoint function logabsdetjac(b::SimplexBijector{1}, x::AbstractMatrix) - return logabsdetjac(b, x), Δ -> begin - (nothing, maphcat(eachcol(x), Δ) do c, g - simplex_logabsdetjac_gradient(c) * g - end) - end -end # LocationScale fix @@ -292,4 +282,3 @@ end return z, pullback_link_chol_lkj end - diff --git a/src/interface.jl b/src/interface.jl index 276ce631..5487d01a 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -3,6 +3,21 @@ import Base: ∘ import Random: AbstractRNG import Distributions: logpdf, rand, rand!, _rand!, _logpdf +const Elementwise{F} = Base.Fix1{<:Union{typeof(map),typeof(broadcast)}, F} +""" + elementwise(f) + +Alias for `Base.Fix1(broadcast, f)`. + +In the case where `f::ComposedFunction`, the result is +`Base.Fix1(broadcast, f.outer) ∘ Base.Fix1(broadcast, f.inner)` rather than +`Base.Fix1(broadcast, f)`. +""" +elementwise(f) = Base.Fix1(broadcast, f) +# TODO: This is makes dispatching quite a bit easier, but uncertain if this is really +# the way to go. +elementwise(f::ComposedFunction) = ComposedFunction(elementwise(f.outer), elementwise(f.inner)) + ####################################### # AD stuff "extracted" from Turing.jl # ####################################### @@ -31,111 +46,168 @@ ADBackend(::Val) = error("The requested AD backend is not available. Make sure t ###################### # Bijector interface # ###################### -"Abstract type for a bijector." -abstract type AbstractBijector end +""" + +Abstract type for a transformation. -"Abstract type of bijectors with fixed dimensionality." -abstract type Bijector{N} <: AbstractBijector end +# Implementing -dimension(b::Bijector{N}) where {N} = N -dimension(b::Type{<:Bijector{N}}) where {N} = N +A subtype of `Transform` of should at least implement [`transform(b, x)`](@ref). -Broadcast.broadcastable(b::Bijector) = Ref(b) +If the `Transform` is also invertible: +- Required: + - _Either_ of the following: + - `transform(::Inverse{<:MyTransform}, x)`: the `transform` for its inverse. + - `InverseFunctions.inverse(b::MyTransform)`: returns an existing `Transform`. + - [`logabsdetjac`](@ref): computes the log-abs-det jacobian factor. +- Optional: + - `with_logabsdet_jacobian`: `transform` and `logabsdetjac` combined. Useful in cases where we + can exploit shared computation in the two. + +For the above methods, there are mutating versions which can _optionally_ be implemented: +- [`with_logabsdet_jacobian!`](@ref) +- [`logabsdetjac!`](@ref) +- [`with_logabsdet_jacobian!`](@ref) +""" +abstract type Transform end + +(t::Transform)(x) = transform(t, x) + +Broadcast.broadcastable(b::Transform) = Ref(b) + +""" + transform(b, x) +Transform `x` using `b`, treating `x` as a single input. """ - isclosedform(b::Bijector)::bool - isclosedform(b⁻¹::Inverse{<:Bijector})::bool +transform(f::F, x) where {F<:Function} = f(x) +transform(t::Transform, x) = first(with_logabsdet_jacobian(t, x)) + +""" + transform!(b, x[, y]) + +Transform `x` using `b`, storing the result in `y`. + +If `y` is not provided, `x` is used as the output. +""" +transform!(b, x) = transform!(b, x, x) +transform!(b, x, y) = copyto!(y, transform(b, x)) + +""" + logabsdetjac(b, x) + +Return `log(abs(det(J(b, x))))`, where `J(b, x)` is the jacobian of `b` at `x`. +""" +logabsdetjac(b, x) = last(with_logabsdet_jacobian(b, x)) + +""" + logabsdetjac!(b, x[, logjac]) + +Compute `log(abs(det(J(b, x))))` and store the result in `logjac`, where `J(b, x)` is the jacobian of `b` at `x`. +""" +logabsdetjac!(b, x) = logabsdetjac!(b, x, zero(eltype(x))) +logabsdetjac!(b, x, logjac) = (logjac += logabsdetjac(b, x)) + +""" + with_logabsdet_jacobian!(b, x[, y, logjac]) + +Compute `transform(b, x)` and `logabsdetjac(b, x)`, storing the result +in `y` and `logjac`, respetively. + +If `y` is not provided, then `x` will be used in its place. + +Defaults to calling `with_logabsdet_jacobian(b, x)` and updating `y` and `logjac` with the result. +""" +with_logabsdet_jacobian!(b, x) = with_logabsdet_jacobian!(b, x, x) +with_logabsdet_jacobian!(b, x, y) = with_logabsdet_jacobian!(b, x, y, zero(eltype(x))) +function with_logabsdet_jacobian!(b, x, y, logjac) + y_, logjac_ = with_logabsdet_jacobian(b, x) + y .= y_ + return (y, logjac + logjac_) +end + +""" + isclosedform(b::Transform)::bool + isclosedform(b⁻¹::Inverse{<:Transform})::bool Returns `true` or `false` depending on whether or not evaluation of `b` has a closed-form implementation. -Most bijectors have closed-form evaluations, but there are cases where +Most transformations have closed-form evaluations, but there are cases where this is not the case. For example the *inverse* evaluation of `PlanarLayer` requires an iterative procedure to evaluate. """ -isclosedform(b::Bijector) = true +isclosedform(t::Transform) = true + +""" + isinvertible(t) +Return `true` if `t` is invertible, and `false` otherwise. """ - inverse(b::Bijector) - Inverse(b::Bijector) +isinvertible(t) = inverse(t) !== InverseFunctions.NoInverse() -A `Bijector` representing the inverse transform of `b`. """ -struct Inverse{B <: Bijector, N} <: Bijector{N} - orig::B + inverse(b::Transform) + Inverse(b::Transform) - Inverse(b::B) where {N, B<:Bijector{N}} = new{B, N}(b) +A `Transform` representing the inverse transform of `b`. +""" +struct Inverse{T<:Transform} <: Transform + orig::T + + function Inverse(orig::Transform) + if !isinvertible(orig) + error("$(orig) is not invertible") + end + + return new{typeof(orig)}(orig) + end end -# field contains nested numerical parameters Functors.@functor Inverse -up1(b::Inverse) = Inverse(up1(b.orig)) - -inverse(b::Bijector) = Inverse(b) -inverse(ib::Inverse{<:Bijector}) = ib.orig -Base.:(==)(b1::Inverse{<:Bijector}, b2::Inverse{<:Bijector}) = b1.orig == b2.orig +""" + inverse(t::Transform) +Returns the inverse of transform `t`. """ - logabsdetjac(b::Bijector, x) - logabsdetjac(ib::Inverse{<:Bijector}, y) +inverse(t::Transform) = Inverse(t) +inverse(ib::Inverse) = ib.orig -Computes the log(abs(det(J(b(x))))) where J is the jacobian of the transform. -Similarily for the inverse-transform. +Base.:(==)(b1::Inverse, b2::Inverse) = b1.orig == b2.orig -Default implementation for `Inverse{<:Bijector}` is implemented as -`- logabsdetjac` of original `Bijector`. -""" -logabsdetjac(ib::Inverse{<:Bijector}, y) = - logabsdetjac(ib.orig, ib(y)) +"Abstract type of a bijector, i.e. differentiable bijection with differentiable inverse." +abstract type Bijector <: Transform end -""" - with_logabsdet_jacobian(b::Bijector, x) +isinvertible(::Bijector) = true -Computes both `transform` and `logabsdetjac` in one forward pass, and -returns a named tuple `(b(x), logabsdetjac(b, x))`. +# Default implementation for inverse of a `Bijector`. +logabsdetjac(ib::Inverse{<:Transform}, y) = -logabsdetjac(ib.orig, transform(ib, y)) -This defaults to the call above, but often one can re-use computation -in the computation of the forward pass and the computation of the -`logabsdetjac`. `forward` allows the user to take advantange of such -efficiencies, if they exist. -""" -with_logabsdet_jacobian(b::Bijector, x) = (b(x), logabsdetjac(b, x)) +function with_logabsdet_jacobian(ib::Inverse{<:Transform}, y) + x = transform(ib, y) + return x, -logabsdetjac(inverse(ib), x) +end """ - logabsdetjacinv(b::Bijector, y) + logabsdetjacinv(b, y) Just an alias for `logabsdetjac(inverse(b), y)`. """ -logabsdetjacinv(b::Bijector, y) = logabsdetjac(inverse(b), y) +logabsdetjacinv(b, y) = logabsdetjac(inverse(b), y) ############################## # Example bijector: Identity # ############################## +Identity() = identity -struct Identity{N} <: Bijector{N} end -(::Identity)(x) = copy(x) -inverse(b::Identity) = b -up1(::Identity{N}) where {N} = Identity{N + 1}() - -logabsdetjac(::Identity{0}, x::Real) = zero(eltype(x)) -@generated function logabsdetjac( - b::Identity{N1}, - x::AbstractArray{T2, N2} -) where {N1, T2, N2} - if N1 == N2 - return :(zero(eltype(x))) - elseif N1 + 1 == N2 - return :(zeros(eltype(x), size(x, $N2))) - else - return :(throw(MethodError(logabsdetjac, (b, x)))) - end -end -logabsdetjac(::Identity{2}, x::AbstractArray{<:AbstractMatrix}) = zeros(eltype(x[1]), size(x)) +# Here we don't need to separate between batched version and non-batched, and so +# we can just overload `transform`, etc. directly. +transform(::typeof(identity), x) = copy(x) +transform!(::typeof(identity), x, y) = copy!(y, x) -######################## -# Convenient constants # -######################## -const ZeroOrOneDimBijector = Union{Bijector{0}, Bijector{1}} +logabsdetjac(::typeof(identity), x) = zero(eltype(x)) +logabsdetjac!(::typeof(identity), x, logjac) = logjac ###################### # Bijectors includes # diff --git a/src/transformed_distribution.jl b/src/transformed_distribution.jl index f51013c3..e515d65d 100644 --- a/src/transformed_distribution.jl +++ b/src/transformed_distribution.jl @@ -1,20 +1,20 @@ # Transformed distributions -struct TransformedDistribution{D, B, V} <: Distribution{V, Continuous} where {D<:Distribution{V, Continuous}, B<:Bijector} +struct TransformedDistribution{D, B, V} <: Distribution{V, Continuous} where {D<:Distribution{V, Continuous}, B} dist::D transform::B - TransformedDistribution(d::UnivariateDistribution, b::Bijector{0}) = new{typeof(d), typeof(b), Univariate}(d, b) - TransformedDistribution(d::MultivariateDistribution, b::Bijector{1}) = new{typeof(d), typeof(b), Multivariate}(d, b) - TransformedDistribution(d::MatrixDistribution, b::Bijector{2}) = new{typeof(d), typeof(b), Matrixvariate}(d, b) + TransformedDistribution(d::UnivariateDistribution, b) = new{typeof(d), typeof(b), Univariate}(d, b) + TransformedDistribution(d::MultivariateDistribution, b) = new{typeof(d), typeof(b), Multivariate}(d, b) + TransformedDistribution(d::MatrixDistribution, b) = new{typeof(d), typeof(b), Matrixvariate}(d, b) end # fields may contain nested numerical parameters Functors.@functor TransformedDistribution -const UnivariateTransformed = TransformedDistribution{<:Distribution, <:Bijector, Univariate} -const MultivariateTransformed = TransformedDistribution{<:Distribution, <:Bijector, Multivariate} +const UnivariateTransformed = TransformedDistribution{<:Distribution,<:Any,Univariate} +const MultivariateTransformed = TransformedDistribution{<:Distribution,<:Any,Multivariate} const MvTransformed = MultivariateTransformed -const MatrixTransformed = TransformedDistribution{<:Distribution, <:Bijector, Matrixvariate} +const MatrixTransformed = TransformedDistribution{<:Distribution,<:Any,Matrixvariate} const Transformed = TransformedDistribution @@ -27,7 +27,7 @@ Couples distribution `d` with the bijector `b` by returning a `TransformedDistri If no bijector is provided, i.e. `transformed(d)` is called, then `transformed(d, bijector(d))` is returned. """ -transformed(d::Distribution, b::Bijector) = TransformedDistribution(d, b) +transformed(d::Distribution, b) = TransformedDistribution(d, b) transformed(d) = transformed(d, bijector(d)) """ @@ -36,12 +36,12 @@ transformed(d) = transformed(d, bijector(d)) Returns the constrained-to-unconstrained bijector for distribution `d`. """ bijector(td::TransformedDistribution) = bijector(td.dist) ∘ inverse(td.transform) -bijector(d::DiscreteUnivariateDistribution) = Identity{0}() -bijector(d::DiscreteMultivariateDistribution) = Identity{1}() +bijector(d::DiscreteUnivariateDistribution) = Identity() +bijector(d::DiscreteMultivariateDistribution) = Identity() bijector(d::ContinuousUnivariateDistribution) = TruncatedBijector(minimum(d), maximum(d)) -bijector(d::Product{Discrete}) = Identity{1}() +bijector(d::Product{Discrete}) = Identity() function bijector(d::Product{Continuous}) - return TruncatedBijector{1}(_minmax(d.v)...) + return TruncatedBijector(_minmax(d.v)...) end @generated function _minmax(d::AbstractArray{T}) where {T} try @@ -52,16 +52,16 @@ end end end -bijector(d::Normal) = Identity{0}() -bijector(d::Distributions.AbstractMvNormal) = Identity{1}() -bijector(d::Distributions.AbstractMvLogNormal) = Log{1}() -bijector(d::PositiveDistribution) = Log{0}() -bijector(d::SimplexDistribution) = SimplexBijector{1}() +bijector(d::Normal) = Identity() +bijector(d::Distributions.AbstractMvNormal) = Identity() +bijector(d::Distributions.AbstractMvLogNormal) = elementwise(log) +bijector(d::PositiveDistribution) = elementwise(log) +bijector(d::SimplexDistribution) = SimplexBijector() bijector(d::KSOneSided) = Logit(zero(eltype(d)), one(eltype(d))) bijector_bounded(d, a=minimum(d), b=maximum(d)) = Logit(a, b) -bijector_lowerbounded(d, a=minimum(d)) = Log() ∘ Shift(-a) -bijector_upperbounded(d, b=maximum(d)) = Log() ∘ Shift(b) ∘ Scale(- one(typeof(b))) +bijector_lowerbounded(d, a=minimum(d)) = elementwise(log) ∘ Shift(-a) +bijector_upperbounded(d, b=maximum(d)) = elementwise(log) ∘ Shift(b) ∘ Scale(- one(typeof(b))) const BoundedDistribution = Union{ Arcsine, Biweight, Cosine, Epanechnikov, Beta, NoncentralBeta @@ -151,123 +151,6 @@ function _rand!(rng::AbstractRNG, td::MatrixTransformed, x::DenseMatrix{<:Real}) x .= td.transform(x) end -############################################################# -# Additional useful functions for `TransformedDistribution` # -############################################################# -""" - logpdf_with_jac(td::UnivariateTransformed, y::Real) - logpdf_with_jac(td::MvTransformed, y::AbstractVector{<:Real}) - logpdf_with_jac(td::MatrixTransformed, y::AbstractMatrix{<:Real}) - -Makes use of the `forward` method to potentially re-use computation -and returns a tuple `(logpdf, logabsdetjac)`. -""" -function logpdf_with_jac(td::UnivariateTransformed, y::Real) - x, logjac = with_logabsdet_jacobian(inverse(td.transform), y) - return (logpdf(td.dist, x) + logjac, logjac) -end - -# TODO: implement more efficiently for flows in the case of `Matrix` -function logpdf_with_jac(td::MvTransformed, y::AbstractVector{<:Real}) - x, logjac = with_logabsdet_jacobian(inverse(td.transform), y) - return (logpdf(td.dist, x) + logjac, logjac) -end - -function logpdf_with_jac(td::MvTransformed, y::AbstractMatrix{<:Real}) - x, logjac = with_logabsdet_jacobian(inverse(td.transform), y) - return (logpdf(td.dist, x) + logjac, logjac) -end - -function logpdf_with_jac(td::MvTransformed{<:Dirichlet}, y::AbstractVector{<:Real}) - T = eltype(y) - ϵ = _eps(T) - - x, logjac = with_logabsdet_jacobian(inverse(td.transform), y) - lp = logpdf(td.dist, mappedarray(x->x+ϵ, x)) + logjac - return (lp, logjac) -end - -# TODO: should eventually drop using `logpdf_with_trans` -function logpdf_with_jac(td::MatrixTransformed, y::AbstractMatrix{<:Real}) - x, logjac = with_logabsdet_jacobian(inverse(td.transform), y) - return (logpdf_with_trans(td.dist, x, true), logjac) -end - -""" - logpdf_forward(td::Transformed, x) - logpdf_forward(td::Transformed, x, logjac) - -Computes the `logpdf` using the forward pass of the bijector rather than using -the inverse transform to compute the necessary `logabsdetjac`. - -This is similar to `logpdf_with_trans`. -""" -# TODO: implement more efficiently for flows in the case of `Matrix` -logpdf_forward(td::Transformed, x, logjac) = logpdf(td.dist, x) - logjac -logpdf_forward(td::Transformed, x) = logpdf_forward(td, x, logabsdetjac(td.transform, x)) - -function logpdf_forward(td::MvTransformed{<:Dirichlet}, x, logjac) - T = eltype(x) - ϵ = _eps(T) - - return logpdf(td.dist, mappedarray(z->z+ϵ, x)) - logjac -end - - -# forward function -const GLOBAL_RNG = Distributions.GLOBAL_RNG - -function _forward(d::UnivariateDistribution, x) - y, logjac = with_logabsdet_jacobian(Identity{0}(), x) - return (x = x, y = y, logabsdetjac = logjac, logpdf = logpdf.(d, x)) -end - -forward(rng::AbstractRNG, d::Distribution) = _forward(d, rand(rng, d)) -function forward(rng::AbstractRNG, d::Distribution, num_samples::Int) - return _forward(d, rand(rng, d, num_samples)) -end -function _forward(d::Distribution, x) - y, logjac = with_logabsdet_jacobian(Identity{length(size(d))}(), x) - return (x = x, y = y, logabsdetjac = logjac, logpdf = logpdf(d, x)) -end - -function _forward(td::Transformed, x) - y, logjac = with_logabsdet_jacobian(td.transform, x) - return ( - x = x, - y = y, - logabsdetjac = logjac, - logpdf = logpdf_forward(td, x, logjac) - ) -end -function forward(rng::AbstractRNG, td::Transformed) - return _forward(td, rand(rng, td.dist)) -end -function forward(rng::AbstractRNG, td::Transformed, num_samples::Int) - return _forward(td, rand(rng, td.dist, num_samples)) -end - -""" - forward(d::Distribution) - forward(d::Distribution, num_samples::Int) - -Returns a `NamedTuple` with fields `x`, `y`, `logabsdetjac` and `logpdf`. - -In the case where `d isa TransformedDistribution`, this means -- `x = rand(d.dist)` -- `y = d.transform(x)` -- `logabsdetjac` is the logabsdetjac of the "forward" transform. -- `logpdf` is the logpdf of `y`, not `x` - -In the case where `d isa Distribution`, this means -- `x = rand(d)` -- `y = x` -- `logabsdetjac = 0.0` -- `logpdf` is logpdf of `x` -""" -forward(d::Distribution) = forward(GLOBAL_RNG, d) -forward(d::Distribution, num_samples::Int) = forward(GLOBAL_RNG, d, num_samples) - # utility stuff Distributions.params(td::Transformed) = Distributions.params(td.dist) function Base.maximum(td::UnivariateTransformed) @@ -281,19 +164,3 @@ function Base.minimum(td::UnivariateTransformed) return max < min ? max : min end -# logabsdetjac for distributions -logabsdetjacinv(d::UnivariateDistribution, x::T) where T <: Real = zero(T) -logabsdetjacinv(d::MultivariateDistribution, x::AbstractVector{T}) where {T<:Real} = zero(T) - - -""" - logabsdetjacinv(td::UnivariateTransformed, y::Real) - logabsdetjacinv(td::MultivariateTransformed, y::AbstractVector{<:Real}) - -Computes the `logabsdetjac` of the _inverse_ transformation, since `rand(td)` returns -the _transformed_ random variable. -""" -logabsdetjacinv(td::UnivariateTransformed, y::Real) = logabsdetjac(inverse(td.transform), y) -function logabsdetjacinv(td::MvTransformed, y::AbstractVector{<:Real}) - return logabsdetjac(inverse(td.transform), y) -end diff --git a/test/ad/flows.jl b/test/ad/flows.jl index 335f6333..bfcbaacc 100644 --- a/test/ad/flows.jl +++ b/test/ad/flows.jl @@ -3,23 +3,27 @@ test_ad(randn(7)) do θ layer = PlanarLayer(θ[1:2], θ[3:4], θ[5:5]) flow = transformed(MvNormal(zeros(2), I), layer) - return logpdf_forward(flow, θ[6:7]) + x = θ[6:7] + return logpdf(flow.dist, x) - logabsdetjac(flow.transform, x) end test_ad(randn(11)) do θ layer = PlanarLayer(θ[1:2], θ[3:4], θ[5:5]) flow = transformed(MvNormal(zeros(2), I), layer) - return sum(logpdf_forward(flow, reshape(θ[6:end], 2, :))) + x = reshape(θ[6:end], 2, :) + return sum(logpdf(flow.dist, x) - logabsdetjac(flow.transform, x)) end # logpdf of a flow with the inverse of a planar layer and two-dimensional inputs test_ad(randn(7)) do θ layer = PlanarLayer(θ[1:2], θ[3:4], θ[5:5]) flow = transformed(MvNormal(zeros(2), I), inverse(layer)) - return logpdf_forward(flow, θ[6:7]) + x = θ[6:7] + return logpdf(flow.dist, x) - logabsdetjac(flow.transform, x) end test_ad(randn(11)) do θ layer = PlanarLayer(θ[1:2], θ[3:4], θ[5:5]) flow = transformed(MvNormal(zeros(2), I), inverse(layer)) - return sum(logpdf_forward(flow, reshape(θ[6:end], 2, :))) + x = reshape(θ[6:end], 2, :) + return sum(logpdf(flow.dist, x) - logabsdetjac(flow.transform, x)) end end diff --git a/test/ad/utils.jl b/test/ad/utils.jl index 479145ba..6bf8365f 100644 --- a/test/ad/utils.jl +++ b/test/ad/utils.jl @@ -8,7 +8,9 @@ function test_ad(f, x, broken = (); rtol = 1e-6, atol = 1e-6) if :Tracker in broken @test_broken Tracker.data(Tracker.gradient(f, x)[1]) ≈ finitediff rtol=rtol atol=atol else - @test Tracker.data(Tracker.gradient(f, x)[1]) ≈ finitediff rtol=rtol atol=atol + ∇tracker = Tracker.gradient(f, x)[1] + @test Tracker.data(∇tracker) ≈ finitediff rtol=rtol atol=atol + @test Tracker.istracked(∇tracker) end end @@ -24,7 +26,8 @@ function test_ad(f, x, broken = (); rtol = 1e-6, atol = 1e-6) if :Zygote in broken @test_broken Zygote.gradient(f, x)[1] ≈ finitediff rtol=rtol atol=atol else - @test Zygote.gradient(f, x)[1] ≈ finitediff rtol=rtol atol=atol + ∇zygote = Zygote.gradient(f, x)[1] + @test (all(finitediff .== 0) && ∇zygote === nothing) || isapprox(∇zygote, finitediff, rtol=rtol, atol=atol) end end diff --git a/test/bijectors/coupling.jl b/test/bijectors/coupling.jl index 38ebd763..5d6c6607 100644 --- a/test/bijectors/coupling.jl +++ b/test/bijectors/coupling.jl @@ -54,8 +54,12 @@ using Bijectors: # With `Scale` cl = Coupling(x -> Scale(x[1]), m) - x = hcat([-1., -2., -3.], [1., 2., 3.]) - y = hcat([2., -2., -3.], [2., 2., 3.]) - test_bijector(cl, x, y, log.([2., 2.])) + x = [-1., -2., -3.] + y = [2., -2., -3.] + test_bijector(cl, x; y=y, logjac=log(2)) + + x = [1., 2., 3.] + y = [2., 2., 3.] + test_bijector(cl, x; y=y, logjac=log(2)) end end diff --git a/test/bijectors/leaky_relu.jl b/test/bijectors/leaky_relu.jl index e110a046..a51f33e5 100644 --- a/test/bijectors/leaky_relu.jl +++ b/test/bijectors/leaky_relu.jl @@ -3,84 +3,48 @@ using Test using Bijectors using Bijectors: LeakyReLU -using LinearAlgebra -using ForwardDiff - -true_logabsdetjac(b::Bijector{0}, x::Real) = (log ∘ abs)(ForwardDiff.derivative(b, x)) -true_logabsdetjac(b::Bijector{0}, x::AbstractVector) = (log ∘ abs).(ForwardDiff.derivative.(b, x)) -true_logabsdetjac(b::Bijector{1}, x::AbstractVector) = logabsdet(ForwardDiff.jacobian(b, x))[1] -true_logabsdetjac(b::Bijector{1}, xs::AbstractMatrix) = mapreduce(z -> true_logabsdetjac(b, z), vcat, eachcol(xs)) - @testset "0-dim parameter, 0-dim input" begin - b = LeakyReLU(0.1; dim=Val(0)) - x = 1. - @test inverse(b)(b(x)) == x - @test inverse(b)(b(-x)) == -x - - # Mixing of types - # 1. Changes in input-type - @assert eltype(b(Float32(1.))) == Float64 - @assert eltype(b(Float64(1.))) == Float64 + b = LeakyReLU(0.1) - # 2. Changes in parameter-type - b = LeakyReLU(Float32(0.1); dim=Val(0)) - @assert eltype(b(Float32(1.))) == Float32 - @assert eltype(b(Float64(1.))) == Float64 + # < 0 + x = -1.0 + test_bijector(b, x) - # logabsdetjac - @test logabsdetjac(b, x) == true_logabsdetjac(b, x) - @test logabsdetjac(b, Float32(x)) == true_logabsdetjac(b, x) + # ≥ 0 + x = 1.0 + test_bijector(b, x; test_not_identity=false, test_types=true) - # Batch - xs = randn(10) - @test logabsdetjac(b, xs) == true_logabsdetjac(b, xs) - @test logabsdetjac(b, Float32.(x)) == true_logabsdetjac(b, Float32.(x)) + # Float32 + b = LeakyReLU(Float32(b.α)) - @test logabsdetjac(b, -xs) == true_logabsdetjac(b, -xs) - @test logabsdetjac(b, -Float32.(xs)) == true_logabsdetjac(b, -Float32.(xs)) + # < 0 + x = -1f0 + test_bijector(b, x) - # Forward - f = with_logabsdet_jacobian(b, xs) - @test f[2] ≈ logabsdetjac(b, xs) - @test f[1] ≈ b(xs) - - f = with_logabsdet_jacobian(b, Float32.(xs)) - @test f[2] == logabsdetjac(b, Float32.(xs)) - @test f[1] ≈ b(Float32.(xs)) + # ≥ 0 + x = 1f0 + test_bijector(b, x; test_not_identity=false, test_types=true) end @testset "0-dim parameter, 1-dim input" begin d = 2 + b = LeakyReLU(0.1) - b = LeakyReLU(0.1; dim=Val(1)) - x = ones(d) - @test inverse(b)(b(x)) == x - @test inverse(b)(b(-x)) == -x + # < 0 + x = -ones(d) + test_bijector(b, x) - # Batch - xs = randn(d, 10) - @test logabsdetjac(b, xs) == true_logabsdetjac(b, xs) - @test logabsdetjac(b, Float32.(x)) == true_logabsdetjac(b, Float32.(x)) - - @test logabsdetjac(b, -xs) == true_logabsdetjac(b, -xs) - @test logabsdetjac(b, -Float32.(xs)) == true_logabsdetjac(b, -Float32.(xs)) - - # Forward - f = with_logabsdet_jacobian(b, xs) - @test f[2] ≈ logabsdetjac(b, xs) - @test f[1] ≈ b(xs) - - f = with_logabsdet_jacobian(b, Float32.(xs)) - @test f[2] == logabsdetjac(b, Float32.(xs)) - @test f[1] ≈ b(Float32.(xs)) + # ≥ 0 + x = ones(d) + test_bijector(b, x; test_not_identity=false) - # Mixing of types - # 1. Changes in input-type - @assert eltype(b(ones(Float32, 2))) == Float64 - @assert eltype(b(ones(Float64, 2))) == Float64 + # Float32 + b = LeakyReLU(Float32.(b.α)) + # < 0 + x = -ones(Float32, d) + test_bijector(b, x; test_types=true) - # 2. Changes in parameter-type - b = LeakyReLU(Float32(0.1); dim=Val(1)) - @assert eltype(b(ones(Float32, 2))) == Float32 - @assert eltype(b(ones(Float64, 2))) == Float64 + # ≥ 0 + x = ones(Float32, d) + test_bijector(b, x; test_not_identity=false, test_types=true) end diff --git a/test/bijectors/named_bijector.jl b/test/bijectors/named_bijector.jl index a7248fae..fbdce0ec 100644 --- a/test/bijectors/named_bijector.jl +++ b/test/bijectors/named_bijector.jl @@ -1,26 +1,12 @@ using Test using Bijectors -using Bijectors: Exp, Log, Logit, AbstractNamedBijector, NamedBijector, NamedInverse, NamedCoupling, NamedComposition, Shift +using Bijectors: Logit, AbstractNamedTransform, NamedTransform, NamedCoupling, Shift -@testset "NamedBijector" begin - b = NamedBijector((a = Exp(), b = Log())) +@testset "NamedTransform" begin + b = NamedTransform((a = elementwise(exp), b = elementwise(log))) @test b((a = 0.0, b = exp(1.0))) == (a = 1.0, b = 1.0) -end - -@testset "NamedComposition" begin - b = NamedBijector((a = Exp(), )) - x = (a = 0., b = 1.) - - nc1 = NamedComposition((b, b)) - @test nc1(x) == b(b(x)) - @test logabsdetjac(nc1, x) ≈ logabsdetjac(b, x) + logabsdetjac(b, b(x)) - - nc2 = b ∘ b - @test nc1 == nc2 - inc2 = inverse(nc2) - @test (inc2 ∘ nc2)(x) == x - @test logabsdetjac((inc2 ∘ nc2), x) ≈ 0.0 + with_logabsdet_jacobian(b, (a = 0.0, b = exp(1.0))) end @testset "NamedCoupling" begin diff --git a/test/bijectors/ordered.jl b/test/bijectors/ordered.jl index 1bf53931..058ee77d 100644 --- a/test/bijectors/ordered.jl +++ b/test/bijectors/ordered.jl @@ -5,18 +5,12 @@ import Bijectors: OrderedBijector # Length 1 x = randn(1) - y = b(x) - test_bijector(b, hcat(x, x), hcat(y, y), zeros(2)) + test_bijector(b, x; test_not_identity=false) # Larger x = randn(5) - xs = hcat(x, x) - test_bijector(b, xs) + test_bijector(b, x) y = b(x) @test sort(y) == y - - ys = b(xs) - @test sort(ys[:, 1]) == ys[:, 1] - @test sort(ys[:, 2]) == ys[:, 2] end diff --git a/test/bijectors/rational_quadratic_spline.jl b/test/bijectors/rational_quadratic_spline.jl index a80c4bc3..fe0ddc31 100644 --- a/test/bijectors/rational_quadratic_spline.jl +++ b/test/bijectors/rational_quadratic_spline.jl @@ -38,23 +38,23 @@ using Bijectors: RationalQuadraticSpline # Inside of domain x = 0.5 - test_bijector(b, [-x, x]) + test_bijector(b, -x) + test_bijector(b, x) # Outside of domain - x = 5. - test_bijector(b, [-x, x], [-x, x], [0., 0.]) + x = 5.0 + test_bijector(b, -x; y=-x, logjac=0) + test_bijector(b, x; y=x, logjac=0) # multivariate b = b_mv # Inside of domain x = [-0.5, 0.5] - x = hcat(x, -x, x) # batch test_bijector(b, x) # Outside of domain x = [-5., 5.] - x = hcat(x, -x, x) # batch - test_bijector(b, x, x, zeros(size(x, 2))) + test_bijector(b, x; y=x, logjac=zero(eltype(x))) end end diff --git a/test/bijectors/utils.jl b/test/bijectors/utils.jl index 85b7a1d9..cf1283dd 100644 --- a/test/bijectors/utils.jl +++ b/test/bijectors/utils.jl @@ -1,184 +1,84 @@ -function test_bijector_reals( - b::Bijector{0}, - x_true::Real, - y_true::Real, - logjac_true::Real; - isequal = true, - tol = 1e-6 -) - ib = @inferred inverse(b) - y = @inferred b(x_true) - logjac = @inferred logabsdetjac(b, x_true) - ilogjac = @inferred logabsdetjac(ib, y_true) - res = @inferred with_logabsdet_jacobian(b, x_true) +# Allows us to run `ChangesOfVariables.test_with_logabsdet_jacobian` +include(joinpath(dirname(pathof(ChangesOfVariables)), "..", "test", "getjacobian.jl")) - # If `isequal` is false, then we use the computed `y`, - # but if it's true, we use the true `y`. - ires = isequal ? @inferred(with_logabsdet_jacobian(inverse(b), y_true)) : @inferred(with_logabsdet_jacobian(inverse(b), y)) +test_bijector(b, x; kwargs...) = test_bijector(b, x, getjacobian; kwargs...) - # Always want the following to hold - @test ires[1] ≈ x_true atol=tol - @test ires[2] ≈ -logjac atol=tol - - if isequal - @test y ≈ y_true atol=tol # forward - @test (@inferred ib(y_true)) ≈ x_true atol=tol # inverse - @test logjac ≈ logjac_true # logjac forward - @test res[1] ≈ y_true atol=tol # forward using `forward` - @test res[2] ≈ logjac_true atol=tol # logjac using `forward` - else - @test y ≠ y_true # forward - @test (@inferred ib(y)) ≈ x_true atol=tol # inverse - @test logjac ≠ logjac_true # logjac forward - @test res[1] ≠ y_true # forward using `forward` - @test res[2] ≠ logjac_true # logjac using `forward` - end -end - -function test_bijector_arrays( - b::Bijector, - xs_true::AbstractArray{<:Real}, - ys_true::AbstractArray{<:Real}, - logjacs_true::Union{Real, AbstractArray{<:Real}}; - isequal = true, - tol = 1e-6 +# TODO: Should we move this into `src/`? +function test_bijector( + b, + x, + getjacobian; + y=nothing, + logjac=nothing, + test_not_identity=isnothing(y) && isnothing(logjac), + test_types=false, + compare=isapprox, + kwargs... ) + # Ensure that everything is type-stable. ib = @inferred inverse(b) - ys = @inferred b(xs_true) - logjacs = @inferred logabsdetjac(b, xs_true) - res = @inferred with_logabsdet_jacobian(b, xs_true) - # If `isequal` is false, then we use the computed `y`, - # but if it's true, we use the true `y`. - ires = isequal ? @inferred(with_logabsdet_jacobian(inverse(b), ys_true)) : @inferred(with_logabsdet_jacobian(inverse(b), ys)) - - # always want the following to hold - @test ys isa typeof(ys_true) - @test logjacs isa typeof(logjacs_true) - @test mean(abs, ires[1] - xs_true) ≤ tol - @test mean(abs, ires[2] + logjacs) ≤ tol - - if isequal - @test mean(abs, ys - ys_true) ≤ tol # forward - @test mean(abs, (ib(ys_true)) - xs_true) ≤ tol # inverse - @test mean(abs, logjacs - logjacs_true) ≤ tol # logjac forward - @test mean(abs, res[1] - ys_true) ≤ tol # forward using `forward` - @test mean(abs, res[2] - logjacs_true) ≤ tol # logjac `forward` - @test mean(abs, ires[2] + logjacs_true) ≤ tol # inverse logjac `forward` + logjac_test = @inferred logabsdetjac(b, x) + res = @inferred with_logabsdet_jacobian(b, x) + + y_test = @inferred b(x) + ilogjac_test = !isnothing(y) ? @inferred(logabsdetjac(ib, y)) : @inferred(logabsdetjac(ib, y_test)) + ires = if !isnothing(y) + @inferred(with_logabsdet_jacobian(inverse(b), y)) else - # Don't want the following to be equal to their "true" values - @test mean(abs, ys - ys_true) > tol # forward - @test mean(abs, logjacs - logjacs_true) > tol # logjac forward - @test mean(abs, res[1] - ys_true) > tol # forward using `forward` - - # Still want the following to be equal to the COMPUTED values - @test mean(abs, ib(ys) - xs_true) ≤ tol # inverse - @test mean(abs, res[2] - logjacs) ≤ tol # logjac forward using `forward` + @inferred(with_logabsdet_jacobian(inverse(b), y_test)) end -end - -""" - test_bijector(b::Bijector, xs::Array; kwargs...) - test_bijector(b::Bijector, xs::Array, ys::Array, logjacs::Array; kwargs...) - -Tests the bijector `b` on the inputs `xs` against the, optionally, provided `ys` -and `logjacs`. - -If `ys` and `logjacs` are NOT provided, `isequal` will be set to `false` and -`ys` and `logjacs` will be set to `zeros`. These `ys` and `logjacs` will be -treated as "counter-examples", i.e. values NOT to match. - -# Arguments -- `b::Bijector`: the bijector to test -- `xs`: inputs (has to be several!!!)(has to be several, i.e. a batch!!!) to test -- `ys`: outputs (has to be several, i.e. a batch!!!) to test against -- `logjacs`: `logabsdetjac` outputs (has to be several!!!)(has to be several, i.e. - a batch!!!) to test against - -# Keywords -- `isequal = true`: if `false`, it will be assumed that the given values are - provided as "counter-examples" in the sense that the inputs `xs` should NOT map - to the given outputs. This is useful in cases where one might not know the expected - output, but still wants to test that the evaluation, etc. works. - This is set to `true` by default if `ys` and `logjacs` are not provided. -- `tol = 1e-6`: the absolute tolerance used for the checks. This is also used to check - arrays where we check that the L1-norm is sufficiently small. -""" -function test_bijector(b::Bijector{0}, xs::AbstractVector{<:Real}) - return test_bijector(b, xs, zeros(length(xs)), zeros(length(xs)); isequal = false) -end -function test_bijector(b::Bijector{1}, xs::AbstractMatrix{<:Real}) - return test_bijector(b, xs, zeros(size(xs)), zeros(size(xs, 2)); isequal = false) -end + # ChangesOfVariables.jl + ChangesOfVariables.test_with_logabsdet_jacobian(b, x, getjacobian; compare=compare, kwargs...) + ChangesOfVariables.test_with_logabsdet_jacobian(ib, isnothing(y) ? y_test : y, getjacobian; compare=compare, kwargs...) -function test_bijector( - b::Bijector{0}, - xs_true::AbstractVector{<:Real}, - ys_true::AbstractVector{<:Real}, - logjacs_true::AbstractVector{<:Real}; - kwargs... -) - ib = inverse(b) + # InverseFunctions.jl + InverseFunctions.test_inverse(b, x; compare, kwargs...) + InverseFunctions.test_inverse(ib, isnothing(y) ? y_test : y; compare=compare, kwargs...) - # Batch - test_bijector_arrays(b, xs_true, ys_true, logjacs_true; kwargs...) - - # Test `logabsdetjac` against jacobians - test_logabsdetjac(b, xs_true) - test_logabsdetjac(b, ys_true) - - for (x_true, y_true, logjac_true) in zip(xs_true, ys_true, logjacs_true) - test_bijector_reals(b, x_true, y_true, logjac_true; kwargs...) + # Always want the following to hold + @test compare(ires[1], x; kwargs...) + @test compare(ires[2], -logjac_test; kwargs...) + + # Verify values. + if !isnothing(y) + @test compare(y_test, y; kwargs...) + @test compare((@inferred ib(y)), x; kwargs...) # inverse + @test compare(res[1], y; kwargs...) # forward using `forward` + end - # Test AD - test_ad(x -> b(first(x)), [x_true, ]) + if !isnothing(logjac) + # We've already checked `ires[2]` against `res[2]`, so if `res[2]` is correct, then so is `ires[2]`. + @test compare(logjac_test, logjac; kwargs...) # logjac forward + @test compare(res[2], logjac; kwargs...) # logjac using `forward` + end - y = b(x_true) - test_ad(x -> ib(first(x)), [y, ]) + # Useful for testing when you don't know the true outputs but know that + # `b` is definitively not identity. + if test_not_identity + @test y_test ≠ x + @test logjac_test ≠ zero(eltype(x)) + @test res[2] ≠ zero(eltype(x)) + end - test_ad(x -> logabsdetjac(b, first(x)), [x_true, ]) + if test_types + @test typeof(first(res)) === typeof(x) + @test typeof(res) === typeof(ires) + @test typeof(y_test) === typeof(x) + @test typeof(logjac_test) === typeof(ilogjac_test) end end +make_jacobian_function(f, xs::AbstractVector) = f, xs +function make_jacobian_function(f, xs::AbstractArray) + xs_new = vec(xs) + s = size(xs) -function test_bijector( - b::Bijector{1}, - xs_true::AbstractMatrix{<:Real}, - ys_true::AbstractMatrix{<:Real}, - logjacs_true::AbstractVector{<:Real}; - kwargs... -) - ib = inverse(b) - - # Batch - test_bijector_arrays(b, xs_true, ys_true, logjacs_true; kwargs...) - - # Test `logabsdetjac` against jacobians - test_logabsdetjac(b, xs_true) - test_logabsdetjac(b, ys_true) - - for (x_true, y_true, logjac_true) in zip(eachcol(xs_true), eachcol(ys_true), logjacs_true) - # HACK: collect to avoid dealing with sub-arrays and thus allowing us to compare the - # type of the computed output to the "true" output. - test_bijector_arrays(b, collect(x_true), collect(y_true), logjac_true; kwargs...) - - # Test AD - test_ad(x -> sum(b(x)), collect(x_true)) - y = b(x_true) - test_ad(x -> sum(ib(x)), y) - - test_ad(x -> logabsdetjac(b, x), x_true) + function g(x) + return vec(f(reshape(x, s))) end -end - -function test_logabsdetjac(b::Bijector{1}, xs::AbstractMatrix; tol=1e-6) - logjac_ad = [logabsdet(ForwardDiff.jacobian(b, x))[1] for x in eachcol(xs)] - @test mean(logabsdetjac(b, xs) - logjac_ad) ≤ tol -end -function test_logabsdetjac(b::Bijector{0}, xs::AbstractVector; tol=1e-6) - logjac_ad = [log(abs(ForwardDiff.derivative(b, x))) for x in xs] - @test mean(logabsdetjac(b, xs) - logjac_ad) ≤ tol + return g, xs_new end # Check if `Functors.functor` works properly @@ -187,3 +87,19 @@ function test_functor(x, xs) @test x == re(_xs) @test _xs == xs end + +function test_bijector_parameter_gradient(b::Bijectors.Transform, x, y = b(x)) + args, re = Functors.functor(b) + recon(k, param) = re(merge(args, NamedTuple{(k, )}((param, )))) + + # Compute the gradient wrt. one argument at the time. + for (k, v) in pairs(args) + test_ad(p -> sum(transform(recon(k, p), x)), v) + test_ad(p -> logabsdetjac(recon(k, p), x), v) + + if Bijectors.isinvertible(b) + test_ad(p -> sum(transform(inv(recon(k, p)), y)), v) + test_ad(p -> logabsdetjac(inv(recon(k, p)), y), v) + end + end +end diff --git a/test/interface.jl b/test/interface.jl index 11fc27f6..e975e486 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -1,3 +1,6 @@ +# using Pkg; Pkg.activate("..") +# using TestEnv; TestEnv.activate() + using Test using Random using LinearAlgebra @@ -7,23 +10,24 @@ using Tracker using DistributionsAD using Bijectors -using Bijectors: Log, Exp, Shift, Scale, Logit, SimplexBijector, PDBijector, Permute, PlanarLayer, RadialLayer, Stacked, TruncatedBijector, ADBijector, RationalQuadraticSpline, LeakyReLU +using Bijectors: Shift, Scale, Logit, SimplexBijector, PDBijector, Permute, PlanarLayer, RadialLayer, Stacked, TruncatedBijector, ADBijector, RationalQuadraticSpline, LeakyReLU Random.seed!(123) -struct MyADBijector{AD, N, B <: Bijector{N}} <: ADBijector{AD, N} +struct MyADBijector{AD,B} <: ADBijector{AD} b::B end MyADBijector(d::Distribution) = MyADBijector{Bijectors.ADBackend()}(d) MyADBijector{AD}(d::Distribution) where {AD} = MyADBijector{AD}(bijector(d)) -MyADBijector{AD}(b::B) where {AD, N, B <: Bijector{N}} = MyADBijector{AD, N, B}(b) +MyADBijector{AD}(b) where {AD} = MyADBijector{AD, typeof(b)}(b) (b::MyADBijector)(x) = b.b(x) -(b::Inverse{<:MyADBijector})(x) = inverse(b.orig.b)(x) +Bijectors.transform(b::MyADBijector, x) = b.b(x) +Bijectors.transform(b::Inverse{<:MyADBijector}, x) = inverse(b.orig.b)(x) -struct NonInvertibleBijector{AD} <: ADBijector{AD, 1} end +struct NonInvertibleBijector{AD} <: ADBijector{AD} end contains(predicate::Function, b::Bijector) = predicate(b) -contains(predicate::Function, b::Composed) = any(contains.(predicate, b.ts)) +contains(predicate::Function, b::ComposedFunction) = any(contains.(predicate, b.ts)) contains(predicate::Function, b::Stacked) = any(contains.(predicate, b.bs)) # Scalar tests @@ -77,11 +81,6 @@ end @test y ≈ @inferred td.transform(x) @test @inferred(logpdf(td, y)) ≈ @inferred(logpdf_with_trans(dist, x, true)) - # logpdf_with_jac - lp, logjac = logpdf_with_jac(td, y) - @test lp ≈ logpdf(td, y) - @test logjac ≈ logabsdetjacinv(td.transform, y) - # multi-sample y = @inferred rand(td, 10) x = inverse(td.transform).(y) @@ -95,14 +94,6 @@ end @test logpdf(d, inverse(b)(y)) + logabsdetjacinv(b, y) ≈ logpdf_with_trans(d, x, true) @test logpdf(d, x) - logabsdetjac(b, x) ≈ logpdf_with_trans(d, x, true) - # forward - f = @inferred forward(td) - @test f.x ≈ inverse(td.transform)(f.y) - @test f.y ≈ td.transform(f.x) - @test f.logabsdetjac ≈ logabsdetjac(td.transform, f.x) - @test f.logpdf ≈ logpdf_with_trans(td.dist, f.x, true) - @test f.logpdf ≈ logpdf(td.dist, f.x) - f.logabsdetjac - # verify against AD d = dist b = bijector(d) @@ -142,234 +133,6 @@ end end end -@testset "Batch computation" begin - bs_xs = [ - (Scale(2.0), randn(3)), - (Scale([1.0, 2.0]), randn(2, 3)), - (Shift(2.0), randn(3)), - (Shift([1.0, 2.0]), randn(2, 3)), - (Log{0}(), exp.(randn(3))), - (Log{1}(), exp.(randn(2, 3))), - (Exp{0}(), randn(3)), - (Exp{1}(), randn(2, 3)), - (Log{1}() ∘ Exp{1}(), randn(2, 3)), - (inverse(Logit(-1.0, 1.0)), randn(3)), - (Identity{0}(), randn(3)), - (Identity{1}(), randn(2, 3)), - (PlanarLayer(2), randn(2, 3)), - (RadialLayer(2), randn(2, 3)), - (PlanarLayer(2) ∘ RadialLayer(2), randn(2, 3)), - (Exp{1}() ∘ PlanarLayer(2) ∘ RadialLayer(2), randn(2, 3)), - (SimplexBijector(), mapslices(z -> normalize(z, 1), rand(2, 3); dims = 1)), - (stack(Exp{0}(), Scale(2.0)), randn(2, 3)), - (Stacked((Exp{1}(), SimplexBijector()), (1:1, 2:3)), - mapslices(z -> normalize(z, 1), rand(3, 2); dims = 1)), - (RationalQuadraticSpline(randn(3), randn(3), randn(3 - 1), 2.), [-0.5, 0.5]), - (LeakyReLU(0.1), randn(3)), - (LeakyReLU(Float32(0.1)), randn(3)), - (LeakyReLU(0.1; dim = Val(1)), randn(2, 3)), - ] - - for (b, xs) in bs_xs - @testset "$b" begin - D = @inferred Bijectors.dimension(b) - ib = @inferred inverse(b) - - @test Bijectors.dimension(ib) == D - - x = D == 0 ? xs[1] : xs[:, 1] - - y = @inferred b(x) - ys = @inferred b(xs) - @inferred(b(param(xs))) - - x_ = @inferred ib(y) - xs_ = @inferred ib(ys) - @inferred(ib(param(ys))) - - result = @inferred with_logabsdet_jacobian(b, x) - results = @inferred with_logabsdet_jacobian(b, xs) - - iresult = @inferred with_logabsdet_jacobian(ib, y) - iresults = @inferred with_logabsdet_jacobian(ib, ys) - - # Sizes - @test size(y) == size(x) - @test size(ys) == size(xs) - - @test size(x_) == size(x) - @test size(xs_) == size(xs) - - @test size(result[1]) == size(x) - @test size(results[1]) == size(xs) - - @test size(iresult[1]) == size(y) - @test size(iresults[1]) == size(ys) - - # Values - @test ys ≈ hcat([b(xs[:, i]) for i = 1:size(xs, 2)]...) - @test ys ≈ results[1] - - if D == 0 - # Sizes - @test y == ys[1] - - @test length(logabsdetjac(b, xs)) == length(xs) - @test length(logabsdetjac(ib, ys)) == length(xs) - - @test @inferred(logabsdetjac(b, param(xs))) isa Union{Array, TrackedArray} - @test @inferred(logabsdetjac(ib, param(ys))) isa Union{Array, TrackedArray} - - @test size(results[2]) == size(xs, ) - @test size(iresults[2]) == size(ys, ) - - # Values - b_logjac_ad = [(log ∘ abs)(ForwardDiff.derivative(b, xs[i])) for i = 1:length(xs)] - ib_logjac_ad = [(log ∘ abs)(ForwardDiff.derivative(ib, ys[i])) for i = 1:length(ys)] - @test logabsdetjac.(b, xs) == @inferred(logabsdetjac(b, xs)) - @test @inferred(logabsdetjac(b, xs)) ≈ b_logjac_ad atol=1e-9 - @test logabsdetjac.(ib, ys) == @inferred(logabsdetjac(ib, ys)) - @test @inferred(logabsdetjac(ib, ys)) ≈ ib_logjac_ad atol=1e-9 - - @test logabsdetjac.(b, param(xs)) == @inferred(logabsdetjac(b, param(xs))) - @test logabsdetjac.(ib, param(ys)) == @inferred(logabsdetjac(ib, param(ys))) - - @test results[2] ≈ vec(logabsdetjac.(b, xs)) - @test iresults[2] ≈ vec(logabsdetjac.(ib, ys)) - elseif D == 1 - @test y == ys[:, 1] - # Comparing sizes instead of lengths ensures we catch errors s.t. - # length(x) == 3 when size(x) == (1, 3). - # Sizes - @test size(logabsdetjac(b, xs)) == (size(xs, 2), ) - @test size(logabsdetjac(ib, ys)) == (size(xs, 2), ) - - @test @inferred(logabsdetjac(b, param(xs))) isa Union{Array, TrackedArray} - @test @inferred(logabsdetjac(ib, param(ys))) isa Union{Array, TrackedArray} - - @test size(results[2]) == (size(xs, 2), ) - @test size(iresults[2]) == (size(ys, 2), ) - - # Test all values - @test @inferred(logabsdetjac(b, xs)) ≈ vec(mapslices(z -> logabsdetjac(b, z), xs; dims = 1)) - @test @inferred(logabsdetjac(ib, ys)) ≈ vec(mapslices(z -> logabsdetjac(ib, z), ys; dims = 1)) - - @test results[2] ≈ vec(mapslices(z -> logabsdetjac(b, z), xs; dims = 1)) - @test iresults[2] ≈ vec(mapslices(z -> logabsdetjac(ib, z), ys; dims = 1)) - - # FIXME: `SimplexBijector` results in ∞ gradient if not in the domain - if !contains(t -> t isa SimplexBijector, b) - b_logjac_ad = [logabsdet(ForwardDiff.jacobian(b, xs[:, i]))[1] for i = 1:size(xs, 2)] - @test logabsdetjac(b, xs) ≈ b_logjac_ad atol=1e-9 - - ib_logjac_ad = [logabsdet(ForwardDiff.jacobian(ib, ys[:, i]))[1] for i = 1:size(ys, 2)] - @test logabsdetjac(ib, ys) ≈ ib_logjac_ad atol=1e-9 - end - else - error("tests not implemented yet") - end - end - end - - @testset "Composition" begin - @test_throws DimensionMismatch (Exp{1}() ∘ Log{0}()) - - # Check that type-stable composition stays type-stable - cb1 = Composed((Exp(), Log())) ∘ Exp() - @test cb1 isa Composed{<:Tuple} - cb2 = Exp() ∘ Composed((Exp(), Log())) - @test cb2 isa Composed{<:Tuple} - cb3 = cb1 ∘ cb2 - @test cb3 isa Composed{<:Tuple} - - @test logabsdetjac(cb1, 1.) isa Real - @test logabsdetjac(cb1, 1.) == 1. - - @test inverse(cb1) isa Composed{<:Tuple} - @test inverse(cb2) isa Composed{<:Tuple} - @test inverse(cb3) isa Composed{<:Tuple} - - # Check that type-unstable composition stays type-unstable - cb1 = Composed([Exp(), Log()]) ∘ Exp() - @test cb1 isa Composed{<:AbstractArray} - cb2 = Exp() ∘ Composed([Exp(), Log()]) - @test cb2 isa Composed{<:AbstractArray} - cb3 = cb1 ∘ cb2 - @test cb3 isa Composed{<:AbstractArray} - - @test logabsdetjac(cb1, 1.) isa Real - @test logabsdetjac(cb1, 1.) == 1. - - @test inverse(cb1) isa Composed{<:AbstractArray} - @test inverse(cb2) isa Composed{<:AbstractArray} - @test inverse(cb3) isa Composed{<:AbstractArray} - - # combining the two - @test_throws ErrorException (Log() ∘ Exp()) ∘ cb1 - @test_throws ErrorException cb1 ∘ (Log() ∘ Exp()) - end - - @testset "Batch-computation with Tracker.jl" begin - @testset "Scale" begin - # 0-dim with `Real` parameter - b = Scale(param(2.0)) - lj = logabsdetjac(b, 1.0) - Tracker.back!(lj, 1.0) - @test Tracker.extract_grad!(b.a) == 0.5 - - # 0-dim with `Real` parameter for batch-computation - lj = logabsdetjac(b, [1.0, 2.0, 3.0]) - Tracker.back!(lj, [1.0, 1.0, 1.0]) - @test Tracker.extract_grad!(b.a) == sum([0.5, 0.5, 0.5]) - - - # 1-dim with `Vector` parameter - x = [3.0, 4.0, 5.0] - xs = [3.0 4.0; 4.0 7.0; 5.0 8.0] - a = [2.0, 3.0, 5.0] - - b = Scale(param(a)) - lj = logabsdetjac(b, x) - Tracker.back!(lj) - @test Tracker.extract_grad!(b.a) == ForwardDiff.gradient(a -> logabsdetjac(Scale(a), x), a) - - # batch - lj = logabsdetjac(b, xs) - Tracker.back!(mean(lj), 1.0) - @test Tracker.extract_grad!(b.a) == ForwardDiff.gradient(a -> mean(logabsdetjac(Scale(a), xs)), a) - - # Forward when doing a composition - y, logjac = logabsdetjac(b, xs) - Tracker.back!(mean(logjac), 1.0) - @test Tracker.extract_grad!(b.a) == ForwardDiff.gradient(a -> mean(logabsdetjac(Scale(a), xs)), a) - end - - @testset "Shift" begin - b = Shift(param(1.0)) - lj = logabsdetjac(b, 1.0) - Tracker.back!(lj, 1.0) - @test Tracker.extract_grad!(b.a) == 0.0 - - # 0-dim with `Real` parameter for batch-computation - lj = logabsdetjac(b, [1.0, 2.0, 3.0]) - @test lj isa TrackedArray - Tracker.back!(lj, [1.0, 1.0, 1.0]) - @test Tracker.extract_grad!(b.a) == 0.0 - - # 1-dim with `Vector` parameter - b = Shift(param([2.0, 3.0, 5.0])) - lj = logabsdetjac(b, [3.0, 4.0, 5.0]) - Tracker.back!(lj) - @test Tracker.extract_grad!(b.a) == zeros(3) - - lj = logabsdetjac(b, [3.0 4.0 5.0; 6.0 7.0 8.0]) - @test lj isa TrackedArray - Tracker.back!(lj, [1.0, 1.0, 1.0]) - @test Tracker.extract_grad!(b.a) == zeros(3) - end - end -end - @testset "Truncated" begin d = truncated(Normal(), -1, 1) b = bijector(d) @@ -421,29 +184,15 @@ end @test td.transform(param(x)) isa TrackedArray @test logpdf(td, y) ≈ logpdf_with_trans(dist, x, true) - # logpdf_with_jac - lp, logjac = logpdf_with_jac(td, y) - @test lp ≈ logpdf(td, y) - @test logjac ≈ logabsdetjacinv(td.transform, y) - - # multi-sample - y = rand(td, 10) - x = inverse(td.transform)(y) - @test inverse(td.transform)(param(y)) isa TrackedArray - @test logpdf(td, y) ≈ logpdf_with_trans(dist, x, true) - - # forward - f = forward(td) - @test f.x ≈ inverse(td.transform)(f.y) - @test f.y ≈ td.transform(f.x) - @test f.logabsdetjac ≈ logabsdetjac(td.transform, f.x) - @test f.logpdf ≈ logpdf_with_trans(td.dist, f.x, true) - # verify against AD # similar to what we do in test/transform.jl for Dirichlet if dist isa Dirichlet - b = Bijectors.SimplexBijector{1, false}() - x = rand(dist) + b = Bijectors.SimplexBijector{false}() + # HACK(torfjelde): Calling `rand(dist)` will sometimes lead to `[0.999..., 0.0]` + # which in turn will lead to differences between `ForwardDiff.jacobian` + # and `logabsdetjac` due to how we handle the boundary values in `SimplexBijector`. + # We therefore test the realizations _on_ the boundary rather if we're near the boundary. + x = any(rand(dist) .> 0.9999) ? [0.0, 1.0][sortperm(rand(dist))] : rand(dist) y = b(x) @test b(param(x)) isa TrackedArray @test log(abs(det(ForwardDiff.jacobian(b, x)))) ≈ logabsdetjac(b, x) @@ -490,80 +239,10 @@ end # lp, logjac = logpdf_with_jac(td, y) # @test lp ≈ logpdf(td, y) # @test logjac ≈ logabsdetjacinv(td.transform, y) - - # multi-sample - y = rand(td, 10) - x = inverse(td.transform)(y) - @test inverse(td.transform)(param.(y)) isa Vector{<:TrackedArray} - @test logpdf(td, y) ≈ logpdf_with_trans(dist, x, true) end end end -@testset "Composition <: Bijector" begin - d = Beta() - td = transformed(d) - - x = rand(d) - y = td.transform(x) - - b = @inferred Bijectors.composel(td.transform, Bijectors.Identity{0}()) - ib = @inferred inverse(b) - - @test with_logabsdet_jacobian(b, x) == with_logabsdet_jacobian(td.transform, x) - @test with_logabsdet_jacobian(ib, y) == with_logabsdet_jacobian(inverse(td.transform), y) - - @test with_logabsdet_jacobian(b, x) == with_logabsdet_jacobian(Bijectors.composer(b.ts...), x) - - # inverse works fine for composition - cb = @inferred b ∘ ib - @test cb(x) ≈ x - - cb2 = @inferred cb ∘ cb - @test cb(x) ≈ x - - # ensures that the `logabsdetjac` is correct - x = rand(d) - b = inverse(bijector(d)) - @test logabsdetjac(b ∘ b, x) ≈ logabsdetjac(b, b(x)) + logabsdetjac(b, x) - - # order of composed evaluation - b1 = MyADBijector(d) - b2 = MyADBijector(Gamma()) - - cb = inverse(b1) ∘ b2 - @test cb(x) ≈ inverse(b1)(b2(x)) - - # contrived example - b = bijector(d) - cb = @inferred inverse(b) ∘ b - cb = @inferred cb ∘ cb - @test @inferred(cb ∘ cb ∘ cb ∘ cb ∘ cb)(x) ≈ x - - # forward for tuple and array - d = Beta() - b = @inferred inverse(bijector(d)) - b⁻¹ = @inferred inverse(b) - x = rand(d) - - cb_t = b⁻¹ ∘ b⁻¹ - f_t = with_logabsdet_jacobian(cb_t, x) - - cb_a = Composed([b⁻¹, b⁻¹]) - f_a = with_logabsdet_jacobian(cb_a, x) - - @test f_t == f_a - - # `composer` and `composel` - cb_l = Bijectors.composel(b⁻¹, b⁻¹, b) - cb_r = Bijectors.composer(reverse(cb_l.ts)...) - y = cb_l(x) - @test y == Bijectors.composel(cb_r.ts...)(x) - - k = length(cb_l.ts) - @test all([cb_l.ts[i] == cb_r.ts[i] for i = 1:k]) -end - @testset "Stacked <: Bijector" begin # `logabsdetjac` withOUT AD d = Beta() @@ -609,7 +288,7 @@ end # value-test x = ones(3) - sb = @inferred stack(Bijectors.Exp(), Bijectors.Log(), Bijectors.Shift(5.0)) + sb = @inferred stack(elementwise(exp), elementwise(log), Shift(5.0)) res = with_logabsdet_jacobian(sb, x) @test sb(param(x)) isa TrackedArray @test sb(x) == [exp(x[1]), log(x[2]), x[3] + 5.0] @@ -619,7 +298,7 @@ end # TODO: change when we have dimensionality in the type - sb = @inferred Stacked((Bijectors.Exp(), Bijectors.SimplexBijector()), (1:1, 2:3)) + sb = @inferred Stacked((elementwise(exp), SimplexBijector()), (1:1, 2:3)) x = ones(3) ./ 3.0 res = @inferred with_logabsdet_jacobian(sb, x) @test sb(param(x)) isa TrackedArray @@ -632,7 +311,7 @@ end @test_throws AssertionError sb(x) # Array-version - sb = Stacked([Bijectors.Exp(), Bijectors.SimplexBijector()], [1:1, 2:3]) + sb = Stacked([elementwise(exp), SimplexBijector()], [1:1, 2:3]) x = ones(3) ./ 3.0 res = with_logabsdet_jacobian(sb, x) @test sb(param(x)) isa TrackedArray @@ -646,7 +325,7 @@ end # Mixed versions # Tuple, Array - sb = Stacked([Bijectors.Exp(), Bijectors.SimplexBijector()], (1:1, 2:3)) + sb = Stacked([elementwise(exp), SimplexBijector()], (1:1, 2:3)) x = ones(3) ./ 3.0 res = with_logabsdet_jacobian(sb, x) @test sb(param(x)) isa TrackedArray @@ -659,7 +338,7 @@ end @test_throws AssertionError sb(x) # Array, Tuple - sb = Stacked((Bijectors.Exp(), Bijectors.SimplexBijector()), [1:1, 2:3]) + sb = Stacked((elementwise(exp), SimplexBijector()), [1:1, 2:3]) x = ones(3) ./ 3.0 res = with_logabsdet_jacobian(sb, x) @test sb(param(x)) isa TrackedArray @@ -712,7 +391,6 @@ end @test td isa Distribution{Multivariate, Continuous} # check that wrong ranges fails - @test_throws MethodError stack(ibs...) sb = Stacked(ibs) x = rand(d) @test_throws AssertionError sb(x) @@ -801,15 +479,9 @@ end @testset "Equality" begin bs = [ - Identity{0}(), - Identity{1}(), - Identity{2}(), - Exp{0}(), - Exp{1}(), - Exp{2}(), - Log{0}(), - Log{1}(), - Log{2}(), + Identity(), + elementwise(exp), + elementwise(log), Scale(2.0), Scale(3.0), Scale(rand(2,2)), @@ -832,14 +504,12 @@ end RadialLayer(2), RadialLayer(3), SimplexBijector(), - Stacked((Exp{0}(), Log{0}())), - Stacked((Log{0}(), Exp{0}())), - Stacked([Exp{0}(), Log{0}()]), - Stacked([Log{0}(), Exp{0}()]), - Composed((Exp{0}(), Log{0}())), - Composed((Log{0}(), Exp{0}())), - # Composed([Exp{0}(), Log{0}()]), - # Composed([Log{0}(), Exp{0}()]), + Stacked((elementwise(exp), elementwise(log))), + Stacked((elementwise(log), elementwise(exp))), + Stacked([elementwise(exp), elementwise(log)]), + Stacked([elementwise(log), elementwise(exp)]), + elementwise(exp) ∘ elementwise(log), + elementwise(log) ∘ elementwise(exp), TruncatedBijector(1.0, 2.0), TruncatedBijector(1.0, 3.0), TruncatedBijector(0.0, 2.0), @@ -855,18 +525,9 @@ end end @testset "test_inverse and test_with_logabsdet_jacobian" begin - b = Bijectors.Scale{Float64,0}(4.2) - x = 0.3 - - test_inverse(b, x) - test_with_logabsdet_jacobian(b, x, (f::Bijectors.Scale, x) -> f.a) -end - - -@testset "deprecations" begin - b = Bijectors.Exp() + b = Bijectors.Scale{Float64,}(4.2) x = 0.3 - @test @test_deprecated(forward(b, x)) == NamedTuple{(:rv, :logabsdetjac)}(with_logabsdet_jacobian(b, x)) - @test @test_deprecated(inv(b)) == inverse(b) + InverseFunctions.test_inverse(b, x) + ChangesOfVariables.test_with_logabsdet_jacobian(b, x, (f::Bijectors.Scale, x) -> f.a) end diff --git a/test/norm_flows.jl b/test/norm_flows.jl index fdf79676..eb07488d 100644 --- a/test/norm_flows.jl +++ b/test/norm_flows.jl @@ -103,7 +103,7 @@ end x = rand(d) y = flow.transform(x) res = with_logabsdet_jacobian(flow.transform, x) - lp = logpdf_forward(flow, x, res[2]) + lp = logpdf(d, x) - res[2] @test res[1] ≈ y @test logpdf(flow, y) ≈ lp rtol=0.1 diff --git a/test/runtests.jl b/test/runtests.jl index 9be0948c..7fcd6dfc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,11 +13,11 @@ using Zygote using Random, LinearAlgebra, Test -using Bijectors: Log, Exp, Shift, Scale, Logit, SimplexBijector, PDBijector, Permute, +using Bijectors: Shift, Scale, Logit, SimplexBijector, PDBijector, Permute, PlanarLayer, RadialLayer, Stacked, TruncatedBijector, ADBijector -using ChangesOfVariables: test_with_logabsdet_jacobian -using InverseFunctions: test_inverse +using ChangesOfVariables: ChangesOfVariables +using InverseFunctions: InverseFunctions const GROUP = get(ENV, "GROUP", "All") diff --git a/test/transform.jl b/test/transform.jl index e119c1b2..7be147d3 100644 --- a/test/transform.jl +++ b/test/transform.jl @@ -1,6 +1,6 @@ using Test using Bijectors -using ForwardDiff: derivative, jacobian +using ForwardDiff: ForwardDiff using LinearAlgebra: logabsdet, I, norm using Random @@ -63,23 +63,6 @@ function single_sample_tests(dist) @test logpdf(dist, x) == logpdf_with_trans(dist, x, false) @test all(isfinite, logpdf.(Ref(dist), [invlink(dist, _rand_real(x)) for _ in 1:100])) end - # This is a quirk of the current implementation, of which it would be nice to be rid. - @test typeof(x) == typeof(y) -end - -# Standard tests for all distributions involving multiple samples. xs should be whatever -# the appropriate repeated version of x is for the distribution in question. ie. for -# univariate distributions, just a vector of identical values. For vector-valued -# distributions, a matrix whose columns are identical. -function multi_sample_tests(dist, x, xs, N) - ys = @inferred(link(dist, copy(xs))) - @test @inferred(invlink(dist, link(dist, copy(xs)))) ≈ xs atol=1e-9 - @test @inferred(link(dist, invlink(dist, copy(ys)))) ≈ ys atol=1e-9 - @test logpdf_with_trans(dist, xs, true) == fill(logpdf_with_trans(dist, x, true), N) - @test logpdf_with_trans(dist, xs, false) == fill(logpdf_with_trans(dist, x, false), N) - - # This is a quirk of the current implementation, of which it would be nice to be rid. - @test typeof(xs) == typeof(ys) end # Scalar tests @@ -116,13 +99,7 @@ let ] for dist in uni_dists - single_sample_tests(dist, derivative) - - # specialised multi-sample tests. - N = 10 - x = rand(dist) - xs = fill(x, N) - multi_sample_tests(dist, x, xs, N) + single_sample_tests(dist, ForwardDiff.derivative) end end end @@ -155,7 +132,7 @@ let ϵ = eps(Float64) end logpdf_turing = logpdf_with_trans(dist, x, true) - J = jacobian(x->link(dist, x, Val(false)), x) + J = ForwardDiff.jacobian(x->link(dist, x, Val(false)), x) @test logpdf(dist, x .+ ϵ) - _logabsdet(J) ≈ logpdf_turing # Issue #12 @@ -164,14 +141,8 @@ let ϵ = eps(Float64) x = [logpdf_with_trans(dist, invlink(dist, link(dist, rand(dist)) .+ randn(dim) .* stepsize), true) for _ in 1:1_000] @test !any(isinf, x) && !any(isnan, x) else - single_sample_tests(dist, jacobian) + single_sample_tests(dist, ForwardDiff.jacobian) end - - # Multi-sample tests. Columns are observations due to Distributions.jl conventions. - N = 10 - x = rand(dist) - xs = repeat(x, 1, N) - multi_sample_tests(dist, x, xs, N) end end end @@ -191,15 +162,9 @@ let lowerinds = [LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if I[1] >= I[2]] upperinds = [LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if I[2] >= I[1]] logpdf_turing = logpdf_with_trans(dist, x, true) - J = jacobian(x->link(dist, x), x) + J = ForwardDiff.jacobian(x->link(dist, x), x) J = J[lowerinds, upperinds] @test logpdf(dist, x) - _logabsdet(J) ≈ logpdf_turing - - # Multi-sample tests comprising vectors of matrices. - N = 10 - x = rand(dist) - xs = [x for _ in 1:N] - multi_sample_tests(dist, x, xs, N) end end end @@ -216,17 +181,10 @@ end x = d .* x .* d' upperinds = [LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if I[2] > I[1]] - J = jacobian(x->link(dist, x), x) + J = ForwardDiff.jacobian(x->link(dist, x), x) J = J[upperinds, upperinds] logpdf_turing = logpdf_with_trans(dist, x, true) @test logpdf(dist, x) - _logabsdet(J) ≈ logpdf_turing - - # Multi-sample tests comprising vectors of matrices. - N = 10 - x = rand(dist) - xs = [x for _ in 1:N] - multi_sample_tests(dist, x, xs, N) - end ################################## Miscelaneous old tests ################################## @@ -279,10 +237,10 @@ end g1 = y -> invlink(dist, y, Val(true)) g2 = y -> invlink(dist, y, Val(false)) - @test @aeq jacobian(f1, x) @inferred(Bijectors.simplex_link_jacobian(x, Val(true))) - @test @aeq jacobian(f2, x) @inferred(Bijectors.simplex_link_jacobian(x, Val(false))) - @test @aeq jacobian(g1, y) @inferred(Bijectors.simplex_invlink_jacobian(y, Val(true))) - @test @aeq jacobian(g2, y) @inferred(Bijectors.simplex_invlink_jacobian(y, Val(false))) + @test @aeq ForwardDiff.jacobian(f1, x) @inferred(Bijectors.simplex_link_jacobian(x, Val(true))) + @test @aeq ForwardDiff.jacobian(f2, x) @inferred(Bijectors.simplex_link_jacobian(x, Val(false))) + @test @aeq ForwardDiff.jacobian(g1, y) @inferred(Bijectors.simplex_invlink_jacobian(y, Val(true))) + @test @aeq ForwardDiff.jacobian(g2, y) @inferred(Bijectors.simplex_invlink_jacobian(y, Val(false))) @test @aeq Bijectors.simplex_link_jacobian(x, Val(false)) * Bijectors.simplex_invlink_jacobian(y, Val(false)) I end for i in 1:4