Commit

Fix initialization of parameters for algorithms that use `SampleFromUniform` (#232)

This PR is a quick fix for TuringLang/Turing.jl#1563 and TuringLang/Turing.jl#1588.

As explained in TuringLang/Turing.jl#1588 (comment), the problem is that `SampleFromUniform` currently resamples all variables on every model evaluation, so initial parameters provided by users are resampled as well in https://github.com/TuringLang/DynamicPPL.jl/blob/9d4137eb33e83f34c484bf78f9a57f828b3c92a0/src/sampler.jl#L80.

As mentioned in TuringLang/Turing.jl#1588 (comment), a better long-term solution would be to fix this inconsistency and use dedicated evaluation and sampling contexts, as suggested in #80.
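To make the resampling behaviour concrete, here is a minimal sketch (not part of this commit; the `demo` model and the seed are arbitrary choices) using the callable-model API that src/sampler.jl relies on, assuming DynamicPPL 0.10's exported `VarInfo`, `@varname`, and sampler types:

using Distributions
using DynamicPPL
using Random

@model function demo()
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
end

model = demo()
rng = MersenneTwister(42)

# Populate an empty VarInfo with draws from the prior.
vi = VarInfo()
model(rng, vi, SampleFromPrior())
s_old = vi[@varname(s)]

# Evaluating again with SampleFromPrior keeps the values stored in `vi`
# and only recomputes the accumulated log probability ...
model(rng, vi, SampleFromPrior())
vi[@varname(s)] == s_old  # true

# ... whereas SampleFromUniform draws fresh values on every run, which is
# what overwrote user-supplied `init_params` in `AbstractMCMC.step`.
model(rng, vi, SampleFromUniform())
vi[@varname(s)] == s_old  # false (almost surely)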
devmotion committed Apr 18, 2021
1 parent 9d4137e commit 7c8edab
Showing 3 changed files with 74 additions and 58 deletions.
Project.toml (2 changes: 1 addition & 1 deletion)
@@ -1,6 +1,6 @@
 name = "DynamicPPL"
 uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
-version = "0.10.15"
+version = "0.10.16"
 
 [deps]
 AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
src/sampler.jl (10 changes: 9 additions & 1 deletion)
@@ -77,7 +77,15 @@ function AbstractMCMC.step(
         initialize_parameters!(vi, kwargs[:init_params], spl)
 
         # Update joint log probability.
-        model(rng, vi, _spl)
+        # TODO: fix properly by using sampler and evaluation contexts
+        # This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588
+        # and https://github.com/TuringLang/Turing.jl/issues/1563
+        # to avoid resampling existing variables
+        if _spl isa SampleFromUniform
+            model(rng, vi, SampleFromPrior())
+        else
+            model(rng, vi, _spl)
+        end
     end
 
     return initialstep(rng, model, spl, vi; kwargs...)
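For context, the `_spl` checked in the new branch above is the algorithm's initial sampler. A minimal sketch of the hook involved (the `MyAlg` type is hypothetical; the overload mirrors `OnlyInitAlgUniform` in the tests below):

using DynamicPPL

# Hypothetical algorithm type.
struct MyAlg end

# Algorithms opt into uniform initial values by overloading this hook;
# without the overload, DynamicPPL.initialsampler returns SampleFromPrior().
DynamicPPL.initialsampler(::Sampler{MyAlg}) = SampleFromUniform()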
test/sampler.jl (120 changes: 64 additions & 56 deletions)
@@ -32,76 +32,84 @@
     end
     @testset "Initial parameters" begin
         # dummy algorithm that just returns initial value and does not perform any sampling
-        struct OnlyInitAlg end
+        abstract type OnlyInitAlg end
+        struct OnlyInitAlgDefault <: OnlyInitAlg end
+        struct OnlyInitAlgUniform <: OnlyInitAlg end
         function DynamicPPL.initialstep(
             rng::Random.AbstractRNG,
             model::Model,
-            ::Sampler{OnlyInitAlg},
+            ::Sampler{<:OnlyInitAlg},
             vi::AbstractVarInfo;
             kwargs...,
         )
             return vi, nothing
         end
-        DynamicPPL.getspace(::Sampler{OnlyInitAlg}) = ()
+        DynamicPPL.getspace(::Sampler{<:OnlyInitAlg}) = ()
 
-        # model with one variable: initialization p = 0.2
-        @model function coinflip()
-            p ~ Beta(1, 1)
-            10 ~ Binomial(25, p)
-        end
-        model = coinflip()
-        sampler = Sampler(OnlyInitAlg())
-        lptrue = logpdf(Binomial(25, 0.2), 10)
-        chain = sample(model, sampler, 1; init_params = 0.2, progress = false)
-        @test chain[1].metadata.p.vals == [0.2]
-        @test getlogp(chain[1]) == lptrue
+        # initial samplers
+        DynamicPPL.initialsampler(::Sampler{OnlyInitAlgUniform}) = SampleFromUniform()
+        @test DynamicPPL.initialsampler(Sampler(OnlyInitAlgDefault())) == SampleFromPrior()
 
-        # parallel sampling
-        chains = sample(
-            model, sampler, MCMCThreads(), 1, 10;
-            init_params = 0.2, progress = false,
-        )
-        for c in chains
-            @test c[1].metadata.p.vals == [0.2]
-            @test getlogp(c[1]) == lptrue
-        end
+        for alg in (OnlyInitAlgDefault(), OnlyInitAlgUniform())
+            # model with one variable: initialization p = 0.2
+            @model function coinflip()
+                p ~ Beta(1, 1)
+                10 ~ Binomial(25, p)
+            end
+            model = coinflip()
+            sampler = Sampler(alg)
+            lptrue = logpdf(Binomial(25, 0.2), 10)
+            chain = sample(model, sampler, 1; init_params = 0.2, progress = false)
+            @test chain[1].metadata.p.vals == [0.2]
+            @test getlogp(chain[1]) == lptrue
 
-        # model with two variables: initialization s = 4, m = -1
-        @model function twovars()
-            s ~ InverseGamma(2, 3)
-            m ~ Normal(0, sqrt(s))
-        end
-        model = twovars()
-        lptrue = logpdf(InverseGamma(2, 3), 4) + logpdf(Normal(0, 2), -1)
-        chain = sample(model, sampler, 1; init_params = [4, -1], progress = false)
-        @test chain[1].metadata.s.vals == [4]
-        @test chain[1].metadata.m.vals == [-1]
-        @test getlogp(chain[1]) == lptrue
+            # parallel sampling
+            chains = sample(
+                model, sampler, MCMCThreads(), 1, 10;
+                init_params = 0.2, progress = false,
+            )
+            for c in chains
+                @test c[1].metadata.p.vals == [0.2]
+                @test getlogp(c[1]) == lptrue
+            end
 
-        # parallel sampling
-        chains = sample(
-            model, sampler, MCMCThreads(), 1, 10;
-            init_params = [4, -1], progress = false,
-        )
-        for c in chains
-            @test c[1].metadata.s.vals == [4]
-            @test c[1].metadata.m.vals == [-1]
-            @test getlogp(c[1]) == lptrue
-        end
+            # model with two variables: initialization s = 4, m = -1
+            @model function twovars()
+                s ~ InverseGamma(2, 3)
+                m ~ Normal(0, sqrt(s))
+            end
+            model = twovars()
+            lptrue = logpdf(InverseGamma(2, 3), 4) + logpdf(Normal(0, 2), -1)
+            chain = sample(model, sampler, 1; init_params = [4, -1], progress = false)
+            @test chain[1].metadata.s.vals == [4]
+            @test chain[1].metadata.m.vals == [-1]
+            @test getlogp(chain[1]) == lptrue
 
+            # parallel sampling
+            chains = sample(
+                model, sampler, MCMCThreads(), 1, 10;
+                init_params = [4, -1], progress = false,
+            )
+            for c in chains
+                @test c[1].metadata.s.vals == [4]
+                @test c[1].metadata.m.vals == [-1]
+                @test getlogp(c[1]) == lptrue
+            end
 
-        # set only m = -1
-        chain = sample(model, sampler, 1; init_params = [missing, -1], progress = false)
-        @test !ismissing(chain[1].metadata.s.vals[1])
-        @test chain[1].metadata.m.vals == [-1]
+            # set only m = -1
+            chain = sample(model, sampler, 1; init_params = [missing, -1], progress = false)
+            @test !ismissing(chain[1].metadata.s.vals[1])
+            @test chain[1].metadata.m.vals == [-1]
 
-        # parallel sampling
-        chains = sample(
-            model, sampler, MCMCThreads(), 1, 10;
-            init_params = [missing, -1], progress = false,
-        )
-        for c in chains
-            @test !ismissing(c[1].metadata.s.vals[1])
-            @test c[1].metadata.m.vals == [-1]
+            # parallel sampling
+            chains = sample(
+                model, sampler, MCMCThreads(), 1, 10;
+                init_params = [missing, -1], progress = false,
+            )
+            for c in chains
+                @test !ismissing(c[1].metadata.s.vals[1])
+                @test c[1].metadata.m.vals == [-1]
+            end
         end
     end
 end
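For reference, an end-to-end sketch assembled from the tests above (the `UniformInitAlg` type and `twovars` model are test fixtures, not library API): with this commit, a dummy algorithm whose initial sampler is `SampleFromUniform` keeps user-supplied `init_params` instead of resampling them.

using AbstractMCMC
using Distributions
using DynamicPPL
using Random

# Dummy algorithm that just returns the initial values without sampling.
struct UniformInitAlg end
DynamicPPL.initialsampler(::Sampler{UniformInitAlg}) = SampleFromUniform()
DynamicPPL.getspace(::Sampler{UniformInitAlg}) = ()
function DynamicPPL.initialstep(
    rng::Random.AbstractRNG,
    model::Model,
    ::Sampler{UniformInitAlg},
    vi::AbstractVarInfo;
    kwargs...,
)
    return vi, nothing
end

@model function twovars()
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
end

# Before this commit the uniform initial sampler resampled s and m here;
# now the first transition starts exactly at s = 4, m = -1.
chain = sample(twovars(), Sampler(UniformInitAlg()), 1; init_params = [4, -1], progress = false)
@assert chain[1].metadata.s.vals == [4]
@assert chain[1].metadata.m.vals == [-1]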

2 comments on commit 7c8edab

@devmotion
Member Author

@JuliaRegistrator register

@JuliaRegistrator


Registration pull request created: JuliaRegistries/General/34612

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the GitHub interface, or via:

git tag -a v0.10.16 -m "<description of version>" 7c8edab1ab72ce86ffc24e52a96c2aa222ca88c4
git push origin v0.10.16
