diff --git a/Project.toml b/Project.toml index 562180c08..6714b1ddb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.21.6" +version = "0.22.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" @@ -24,7 +24,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" AbstractMCMC = "2, 3.0, 4" AbstractPPL = "0.5.3" BangBang = "0.3" -Bijectors = "0.5.2, 0.6, 0.7, 0.8, 0.9, 0.10" +Bijectors = "0.11" ChainRulesCore = "0.9.7, 0.10, 1" ConstructionBase = "1" Distributions = "0.23.8, 0.24, 0.25" diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 8c1dd88d4..acd51e288 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -405,7 +405,7 @@ end # Vector-based ones. function link!!( - t::StaticTransformation{<:Bijectors.Bijector{1}}, + t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, spl::AbstractSampler, model::Model, @@ -420,7 +420,7 @@ function link!!( end function invlink!!( - t::StaticTransformation{<:Bijectors.Bijector{1}}, + t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, spl::AbstractSampler, model::Model, @@ -452,9 +452,8 @@ julia> using DynamicPPL, Distributions, Bijectors julia> @model demo() = x ~ Normal() demo (generic function with 2 methods) -julia> # By subtyping `Bijector{1}`, we inherit the `(inv)link!!` defined for - # bijectors which acts on 1-dimensional arrays, i.e. vectors. - struct MyBijector <: Bijectors.Bijector{1} end +julia> # By subtyping `Transform`, we inherit the `(inv)link!!`. + struct MyBijector <: Bijectors.Transform end julia> # Define some dummy `inverse` which will be used in the `link!!` call. Bijectors.inverse(f::MyBijector) = identity diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 4b345a6ff..a445bf87a 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -648,7 +648,7 @@ Distributions.loglikelihood(model::Model, θ) = loglikelihood(model, SimpleVarIn # Allow usage of `NamedBijector` too. function link!!( - t::StaticTransformation{<:Bijectors.NamedBijector}, + t::StaticTransformation{<:Bijectors.NamedTransform}, vi::SimpleVarInfo{<:NamedTuple}, spl::AbstractSampler, model::Model, @@ -663,7 +663,7 @@ function link!!( end function invlink!!( - t::StaticTransformation{<:Bijectors.NamedBijector}, + t::StaticTransformation{<:Bijectors.NamedTransform}, vi::SimpleVarInfo{<:NamedTuple}, spl::AbstractSampler, model::Model, diff --git a/src/test_utils.jl b/src/test_utils.jl index b5f3a80b3..45b9fff07 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -686,7 +686,7 @@ Simple model for which [`default_transformation`](@ref) returns a [`StaticTransf end function DynamicPPL.default_transformation(::Model{typeof(demo_static_transformation)}) - b = Bijectors.stack(Bijectors.Exp{0}(), Bijectors.Identity{0}()) + b = Bijectors.stack(Bijectors.elementwise(exp), identity) return DynamicPPL.StaticTransformation(b) end diff --git a/test/Project.toml b/test/Project.toml index 283b37fda..29bf12ebb 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -23,7 +23,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] AbstractMCMC = "2.1, 3.0, 4" AbstractPPL = "0.5.1, 0.6" -Bijectors = "0.9.5, 0.10" +Bijectors = "0.11" Distributions = "0.25" DistributionsAD = "0.6.3" Documenter = "0.26.1, 0.27" diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 9a7f6e549..a5b57f5f6 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -64,6 +64,7 @@ @testset "$(typeof(vi))" for vi in ( SimpleVarInfo(Dict()), SimpleVarInfo(values_constrained), VarInfo(model) ) + vi = SimpleVarInfo(values_constrained) for vn in DynamicPPL.TestUtils.varnames(model) vi = DynamicPPL.setindex!!(vi, get(values_constrained, vn), vn) end @@ -108,6 +109,8 @@ @testset "SimpleVarInfo on $(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS + model = DynamicPPL.TestUtils.demo_dot_assume_matrix_dot_observe_matrix() + # We might need to pre-allocate for the variable `m`, so we need # to see whether this is the case. svi_nt = SimpleVarInfo(rand(NamedTuple, model))