diff --git a/Project.toml b/Project.toml index 25fcc2c13..07d68608e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.10.15" +version = "0.10.16" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/sampler.jl b/src/sampler.jl index 9a09f37df..8b32bef9d 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -77,7 +77,15 @@ function AbstractMCMC.step( initialize_parameters!(vi, kwargs[:init_params], spl) # Update joint log probability. - model(rng, vi, _spl) + # TODO: fix properly by using sampler and evaluation contexts + # This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588 + # and https://github.com/TuringLang/Turing.jl/issues/1563 + # to avoid that existing variables are resampled + if _spl isa SampleFromUniform + model(rng, vi, SampleFromPrior()) + else + model(rng, vi, _spl) + end end return initialstep(rng, model, spl, vi; kwargs...) diff --git a/test/sampler.jl b/test/sampler.jl index bd2783104..4959bf845 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -32,76 +32,84 @@ end @testset "Initial parameters" begin # dummy algorithm that just returns initial value and does not perform any sampling - struct OnlyInitAlg end + abstract type OnlyInitAlg end + struct OnlyInitAlgDefault <: OnlyInitAlg end + struct OnlyInitAlgUniform <: OnlyInitAlg end function DynamicPPL.initialstep( rng::Random.AbstractRNG, model::Model, - ::Sampler{OnlyInitAlg}, + ::Sampler{<:OnlyInitAlg}, vi::AbstractVarInfo; kwargs..., ) return vi, nothing end - DynamicPPL.getspace(::Sampler{OnlyInitAlg}) = () + DynamicPPL.getspace(::Sampler{<:OnlyInitAlg}) = () - # model with one variable: initialization p = 0.2 - @model function coinflip() - p ~ Beta(1, 1) - 10 ~ Binomial(25, p) - end - model = coinflip() - sampler = Sampler(OnlyInitAlg()) - lptrue = logpdf(Binomial(25, 0.2), 10) - chain = sample(model, sampler, 1; init_params = 0.2, progress = false) - @test chain[1].metadata.p.vals == [0.2] - @test getlogp(chain[1]) == lptrue + # initial samplers + DynamicPPL.initialsampler(::Sampler{OnlyInitAlgUniform}) = SampleFromUniform() + @test DynamicPPL.initialsampler(Sampler(OnlyInitAlgDefault())) == SampleFromPrior() - # parallel sampling - chains = sample( - model, sampler, MCMCThreads(), 1, 10; - init_params = 0.2, progress = false, - ) - for c in chains - @test c[1].metadata.p.vals == [0.2] - @test getlogp(c[1]) == lptrue - end + for alg in (OnlyInitAlgDefault(), OnlyInitAlgUniform()) + # model with one variable: initialization p = 0.2 + @model function coinflip() + p ~ Beta(1, 1) + 10 ~ Binomial(25, p) + end + model = coinflip() + sampler = Sampler(alg) + lptrue = logpdf(Binomial(25, 0.2), 10) + chain = sample(model, sampler, 1; init_params = 0.2, progress = false) + @test chain[1].metadata.p.vals == [0.2] + @test getlogp(chain[1]) == lptrue - # model with two variables: initialization s = 4, m = -1 - @model function twovars() - s ~ InverseGamma(2, 3) - m ~ Normal(0, sqrt(s)) - end - model = twovars() - lptrue = logpdf(InverseGamma(2, 3), 4) + logpdf(Normal(0, 2), -1) - chain = sample(model, sampler, 1; init_params = [4, -1], progress = false) - @test chain[1].metadata.s.vals == [4] - @test chain[1].metadata.m.vals == [-1] - @test getlogp(chain[1]) == lptrue + # parallel sampling + chains = sample( + model, sampler, MCMCThreads(), 1, 10; + init_params = 0.2, progress = false, + ) + for c in chains + @test c[1].metadata.p.vals == [0.2] + @test getlogp(c[1]) == lptrue + end - # parallel sampling - chains = sample( - model, sampler, MCMCThreads(), 1, 10; - init_params = [4, -1], progress = false, - ) - for c in chains - @test c[1].metadata.s.vals == [4] - @test c[1].metadata.m.vals == [-1] - @test getlogp(c[1]) == lptrue - end + # model with two variables: initialization s = 4, m = -1 + @model function twovars() + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + end + model = twovars() + lptrue = logpdf(InverseGamma(2, 3), 4) + logpdf(Normal(0, 2), -1) + chain = sample(model, sampler, 1; init_params = [4, -1], progress = false) + @test chain[1].metadata.s.vals == [4] + @test chain[1].metadata.m.vals == [-1] + @test getlogp(chain[1]) == lptrue + + # parallel sampling + chains = sample( + model, sampler, MCMCThreads(), 1, 10; + init_params = [4, -1], progress = false, + ) + for c in chains + @test c[1].metadata.s.vals == [4] + @test c[1].metadata.m.vals == [-1] + @test getlogp(c[1]) == lptrue + end - # set only m = -1 - chain = sample(model, sampler, 1; init_params = [missing, -1], progress = false) - @test !ismissing(chain[1].metadata.s.vals[1]) - @test chain[1].metadata.m.vals == [-1] + # set only m = -1 + chain = sample(model, sampler, 1; init_params = [missing, -1], progress = false) + @test !ismissing(chain[1].metadata.s.vals[1]) + @test chain[1].metadata.m.vals == [-1] - # parallel sampling - chains = sample( - model, sampler, MCMCThreads(), 1, 10; - init_params = [missing, -1], progress = false, - ) - for c in chains - @test !ismissing(c[1].metadata.s.vals[1]) - @test c[1].metadata.m.vals == [-1] + # parallel sampling + chains = sample( + model, sampler, MCMCThreads(), 1, 10; + init_params = [missing, -1], progress = false, + ) + for c in chains + @test !ismissing(c[1].metadata.s.vals[1]) + @test c[1].metadata.m.vals == [-1] + end end end end