diff --git a/Project.toml b/Project.toml index 013c89814..25fcc2c13 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.10.14" +version = "0.10.15" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/varinfo.jl b/src/varinfo.jl index 8d9ffecce..de369cceb 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -722,8 +722,11 @@ function link!(vi::UntypedVarInfo, spl::Sampler) end end function link!(vi::TypedVarInfo, spl::AbstractSampler) + return link!(vi, spl, Val(getspace(spl))) +end +function link!(vi::TypedVarInfo, spl::AbstractSampler, spaceval::Val) vns = _getvns(vi, spl) - return _link!(vi.metadata, vi, vns, Val(getspace(spl))) + return _link!(vi.metadata, vi, vns, spaceval) end @generated function _link!(metadata::NamedTuple{names}, vi, vns, ::Val{space}) where {names, space} expr = Expr(:block) @@ -770,8 +773,11 @@ function invlink!(vi::UntypedVarInfo, spl::AbstractSampler) end end function invlink!(vi::TypedVarInfo, spl::AbstractSampler) + return invlink!(vi, spl, Val(getspace(spl))) +end +function invlink!(vi::TypedVarInfo, spl::AbstractSampler, spaceval::Val) vns = _getvns(vi, spl) - return _invlink!(vi.metadata, vi, vns, Val(getspace(spl))) + return _invlink!(vi.metadata, vi, vns, spaceval) end @generated function _invlink!(metadata::NamedTuple{names}, vi, vns, ::Val{space}) where {names, space} expr = Expr(:block) diff --git a/test/turing/varinfo.jl b/test/turing/varinfo.jl index 4cfa78d4c..3f932a5b9 100644 --- a/test/turing/varinfo.jl +++ b/test/turing/varinfo.jl @@ -67,6 +67,16 @@ @test all(x -> !istrans(vi, x), meta.m.vns) @test meta.s.vals == v_s @test meta.m.vals == v_m + + # Transforming only a subset of the variables + link!(vi, spl, Val((:m, ))) + @test all(x -> !istrans(vi, x), meta.s.vns) + @test all(x -> istrans(vi, x), meta.m.vns) + invlink!(vi, spl, Val((:m, ))) + @test all(x -> !istrans(vi, x), meta.s.vns) + @test all(x -> !istrans(vi, x), meta.m.vns) + @test meta.s.vals == v_s + @test meta.m.vals == v_m end @testset "orders" begin csym = gensym() # unique per model @@ -329,4 +339,4 @@ @test vi.metadata.w.gids[1] == Set([hmc.selector]) @test vi.metadata.u.gids[1] == Set([hmc.selector]) =# end -end \ No newline at end of file +end