Skip to content

Commit

Permalink
Try #454:
Browse files Browse the repository at this point in the history
  • Loading branch information
bors[bot] authored Feb 3, 2023
2 parents a2bdc16 + 8d77d78 commit 9d1a707
Show file tree
Hide file tree
Showing 6 changed files with 13 additions and 11 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.21.6"
version = "0.22.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand All @@ -24,7 +24,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
AbstractMCMC = "2, 3.0, 4"
AbstractPPL = "0.5.3"
BangBang = "0.3"
Bijectors = "0.5.2, 0.6, 0.7, 0.8, 0.9, 0.10"
Bijectors = "0.11"
ChainRulesCore = "0.9.7, 0.10, 1"
ConstructionBase = "1"
Distributions = "0.23.8, 0.24, 0.25"
Expand Down
9 changes: 4 additions & 5 deletions src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ end

# Vector-based ones.
function link!!(
t::StaticTransformation{<:Bijectors.Bijector{1}},
t::StaticTransformation{<:Bijectors.Transform},
vi::AbstractVarInfo,
spl::AbstractSampler,
model::Model,
Expand All @@ -420,7 +420,7 @@ function link!!(
end

function invlink!!(
t::StaticTransformation{<:Bijectors.Bijector{1}},
t::StaticTransformation{<:Bijectors.Transform},
vi::AbstractVarInfo,
spl::AbstractSampler,
model::Model,
Expand Down Expand Up @@ -452,9 +452,8 @@ julia> using DynamicPPL, Distributions, Bijectors
julia> @model demo() = x ~ Normal()
demo (generic function with 2 methods)
julia> # By subtyping `Bijector{1}`, we inherit the `(inv)link!!` defined for
# bijectors which acts on 1-dimensional arrays, i.e. vectors.
struct MyBijector <: Bijectors.Bijector{1} end
julia> # By subtyping `Transform`, we inherit the `(inv)link!!`.
struct MyBijector <: Bijectors.Transform end
julia> # Define some dummy `inverse` which will be used in the `link!!` call.
Bijectors.inverse(f::MyBijector) = identity
Expand Down
4 changes: 2 additions & 2 deletions src/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,7 @@ Distributions.loglikelihood(model::Model, θ) = loglikelihood(model, SimpleVarIn

# Allow usage of `NamedBijector` too.
function link!!(
t::StaticTransformation{<:Bijectors.NamedBijector},
t::StaticTransformation{<:Bijectors.NamedTransform},
vi::SimpleVarInfo{<:NamedTuple},
spl::AbstractSampler,
model::Model,
Expand All @@ -663,7 +663,7 @@ function link!!(
end

function invlink!!(
t::StaticTransformation{<:Bijectors.NamedBijector},
t::StaticTransformation{<:Bijectors.NamedTransform},
vi::SimpleVarInfo{<:NamedTuple},
spl::AbstractSampler,
model::Model,
Expand Down
2 changes: 1 addition & 1 deletion src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -686,7 +686,7 @@ Simple model for which [`default_transformation`](@ref) returns a [`StaticTransf
end

function DynamicPPL.default_transformation(::Model{typeof(demo_static_transformation)})
b = Bijectors.stack(Bijectors.Exp{0}(), Bijectors.Identity{0}())
b = Bijectors.stack(Bijectors.elementwise(exp), identity)
return DynamicPPL.StaticTransformation(b)
end

Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
[compat]
AbstractMCMC = "2.1, 3.0, 4"
AbstractPPL = "0.5.1, 0.6"
Bijectors = "0.9.5, 0.10"
Bijectors = "0.11"
Distributions = "0.25"
DistributionsAD = "0.6.3"
Documenter = "0.26.1, 0.27"
Expand Down
3 changes: 3 additions & 0 deletions test/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
@testset "$(typeof(vi))" for vi in (
SimpleVarInfo(Dict()), SimpleVarInfo(values_constrained), VarInfo(model)
)
vi = SimpleVarInfo(values_constrained)
for vn in DynamicPPL.TestUtils.varnames(model)
vi = DynamicPPL.setindex!!(vi, get(values_constrained, vn), vn)
end
Expand Down Expand Up @@ -108,6 +109,8 @@

@testset "SimpleVarInfo on $(nameof(model))" for model in
DynamicPPL.TestUtils.DEMO_MODELS
model = DynamicPPL.TestUtils.demo_dot_assume_matrix_dot_observe_matrix()

# We might need to pre-allocate for the variable `m`, so we need
# to see whether this is the case.
svi_nt = SimpleVarInfo(rand(NamedTuple, model))
Expand Down

0 comments on commit 9d1a707

Please sign in to comment.