From c6059120078666c4a64d4875103a9913d62d8e60 Mon Sep 17 00:00:00 2001 From: "Edward A. Roualdes" Date: Mon, 24 Jul 2023 16:13:24 -0400 Subject: [PATCH] prep DrMALA for run on Rusty --- src/damping_adapter.jl | 6 +++--- src/draws_initializer.jl | 15 ++++++++++++--- src/drmala.jl | 34 ++++++++++++++++++++++++---------- src/metric_adapter.jl | 8 ++++---- src/noise_adapter.jl | 6 +++--- src/onlinepca.jl | 5 +---- src/pca_adapter.jl | 4 ++-- src/reductionfactor_adapter.jl | 4 ++-- src/steps_adapter.jl | 14 +++++++------- src/stepsize_adapter.jl | 16 ++++++++-------- src/stepsize_initializer.jl | 2 +- src/tools.jl | 13 +++++++++++-- src/trace.jl | 10 ---------- 13 files changed, 78 insertions(+), 59 deletions(-) diff --git a/src/damping_adapter.jl b/src/damping_adapter.jl index d5a4ef9..37a4c7d 100644 --- a/src/damping_adapter.jl +++ b/src/damping_adapter.jl @@ -18,7 +18,7 @@ struct DampingECA{T<:AbstractFloat} <: AbstractDampingAdapter{T} end function DampingECA(initial_damping::AbstractVector{T}; kwargs...) where {T} - return DampingECA(initial_damping, zero(initial_damping)) + return DampingECA(copy(initial_damping), copy(initial_damping)) end function update!(deca::DampingECA, m, zpositions, stepsize, idx, args...; kwargs...) @@ -41,7 +41,7 @@ struct DampingConstant{T<:AbstractFloat} <: AbstractDampingAdapter{T} end function DampingConstant(initial_damping::AbstractVector; kwargs...) - return DampingConstant(initial_damping, initial_damping) + return DampingConstant(copy(initial_damping), copy(initial_damping)) end function update!(dc::DampingConstant, args...; kwargs...) end @@ -54,7 +54,7 @@ struct DampingMALT{T<:AbstractFloat} <: AbstractDampingAdapter{T} end function DampingMALT(initial_damping::AbstractVector, args...; kwargs...) - return DampingMALT(initial_damping, initial_damping) + return DampingMALT(copy(initial_damping), copy(initial_damping)) end function update!(dmalt::DampingMALT, m, gamma, stepsize, args...; damping_coefficient=1, kwargs...) diff --git a/src/draws_initializer.jl b/src/draws_initializer.jl index 9fc87ad..42bb277 100644 --- a/src/draws_initializer.jl +++ b/src/draws_initializer.jl @@ -34,7 +34,7 @@ struct DrawsInitializerStan end function initialize_draws!(initializer::DrawsInitializerStan, draws, rngs, ldg!; kwargs...) for chain in axes(draws, 3) - @views draws[1, :, chain] = stan_initialize_draw( + draws[1, :, chain] .= stan_initialize_draw( draws[1, :, chain], ldg!, rngs[chain]; kwargs... ) end @@ -47,6 +47,7 @@ function stan_initialize_draw(position, ldg!, rng; radius=2, attempts=100, kwarg dims = length(position) q = copy(position) gradient = similar(q) + momentum = similar(q) while a < attempts && !initialized q .= radius .* (2 .* rand(rng, T, dims) .- 1) @@ -57,10 +58,18 @@ function stan_initialize_draw(position, ldg!, rng; radius=2, attempts=100, kwarg end g = sum(gradient) - if isfinite(g) && !isnan(g) - initialized &= true + if !isfinite(g) || isnan(g) + initialized = false + continue end + # randn!(momentum) + # ld = leapfrog!(q, momentum, ldg!, gradient, 1, 1) + + # if !isfinite(ld) || isnan(ld) || any(isnan.(q)) || any(isnan.(momentum)) + # initialized = false + # end + a += 1 end diff --git a/src/drmala.jl b/src/drmala.jl index 548e94c..d223ea4 100644 --- a/src/drmala.jl +++ b/src/drmala.jl @@ -10,6 +10,8 @@ struct DrMALA{T} <: AbstractDrMALA{T} reductionfactor::Vector{T} damping::Vector{T} noise::Vector{T} + drift::Vector{T} + acceptanceprob::Vector{T} dims::Int chains::Int end @@ -33,8 +35,10 @@ function DrMALA( pca ./= mapslices(norm, pca, dims = 1) damping = ones(T, 1) noise = exp.(-2 .* damping .* stepsize) + drift = (1 .- noise .^ 2) ./ 2 + acceptanceprob = 2 .* rand(T, chains) .- 1 return DrMALA( - momentum, metric, pca, steps, stepsize, reductionfactor, damping, noise, D, chains + momentum, metric, pca, steps, stepsize, reductionfactor, damping, noise, drift, acceptanceprob, D, chains ) end @@ -45,7 +49,7 @@ function sample!( warmup=iterations, draws_initializer=DrawsInitializerStan(), stepsize_initializer=StepsizeInitializerStan(), - steps_adapter=StepsPCA(sampler.steps), + steps_adapter=StepsPCA(copy(sampler.steps)), stepsize_adapter=StepsizeDualAverage(sampler.stepsize; δ=0.5), reductionfactor_adapter=ReductionFactorDualAverage(sampler.reductionfactor; reductionfactor_δ = 0.9), metric_adapter=MetricOnlineMoments(sampler.metric), @@ -75,19 +79,25 @@ function sample!( end function transition!(sampler::DrMALA, m, ldg!, draws, rngs, trace; - J = 3, kwargs...) + J = 3, nonreversible_update = false, malt = true, + kwargs...) nt = get(kwargs, :threads, Threads.nthreads()) chains = size(draws, 3) - pca = sampler.pca ./ mapslices(norm, sampler.pca, dims = 1) reduction_factor = sampler.reductionfactor[1] + if malt + randn!(sampler.momentum) + end + damping = sampler.damping[1] + metric = sampler.metric[:, 1] + metric ./= maximum(metric) + Threads.@threads for it in 1:nt for chain in it:nt:chains stepsize = sampler.stepsize[chain] steps = sampler.steps[chain] - metric = sampler.metric[:, 1] - metric ./= maximum(metric) noise = sampler.noise[chain] - damping = sampler.damping[1] + drift = sampler.drift[chain] + acceptanceprob = sampler.acceptanceprob[chain:chain] local info acceptstats = zeros(steps) @@ -106,8 +116,11 @@ function transition!(sampler::DrMALA, m, ldg!, draws, rngs, trace; stepsize, 1, noise, + drift, + acceptanceprob, J, reduction_factor, + nonreversible_update, 1000; kwargs..., ) @@ -119,8 +132,6 @@ function transition!(sampler::DrMALA, m, ldg!, draws, rngs, trace; info = (; info..., damping, reductionfactor = reduction_factor, - pca = pca[:, 1], - # previousposition = draws[m, :, chain], acceptstat = mean(acceptstats), finalacceptstat = maybe_mean(finalacceptstats), steps = sum(retried)) @@ -147,6 +158,7 @@ function adapt!( damping_adapter, noise_adapter, drift_adapter; + steps_coefficient = 1, kwargs..., ) @@ -158,10 +170,12 @@ function adapt!( update!(noise_adapter, sampler.damping, sampler.stepsize; kwargs...) set!(sampler, noise_adapter; kwargs...) + sampler.drift .= (1 .- sampler.noise .^ 2) ./ 2 if schedule.firstwindow <= m <= schedule.lastwindow final_accept_stats = trace.finalacceptstat[m + 1, :] update!(reductionfactor_adapter, maybe_mean(final_accept_stats); kwargs...) + set!(sampler, reductionfactor_adapter; kwargs...) positions = draws[m + 1, :, :] @@ -175,7 +189,7 @@ function adapt!( ) lambda = sqrt.(lambda_max(pca_adapter)) - update!(steps_adapter, m + 1, lambda, sampler.stepsize, pca_adapter.opca.n[1]; kwargs...) + update!(steps_adapter, m + 1, steps_coefficient * lambda, sampler.stepsize, pca_adapter.opca.n[1]; kwargs...) update!(damping_adapter, m + 1, lambda, sampler.stepsize; kwargs...) end diff --git a/src/metric_adapter.jl b/src/metric_adapter.jl index 30b88a9..0cfb917 100644 --- a/src/metric_adapter.jl +++ b/src/metric_adapter.jl @@ -23,7 +23,7 @@ function MetricOnlineMoments( ) where {T} dims, metrics = size(initial_metric) om = OnlineMoments(T, dims, metrics) - return MetricOnlineMoments(om, initial_metric) + return MetricOnlineMoments(om, copy(initial_metric)) end function update!( @@ -51,7 +51,7 @@ struct MetricConstant{T<:AbstractFloat} <: AbstractMetricAdapter{T} end function MetricConstant(initial_metric::AbstractMatrix, args...; kwargs...) - return MetricConstant(initial_metric) + return MetricConstant(copy(initial_metric)) end function metric_mean(mc::MetricConstant, args...; kwargs...) @@ -69,7 +69,7 @@ struct MetricECA{T<:AbstractFloat} <: AbstractMetricAdapter{T} end function MetricECA(initial_metric::AbstractMatrix, args...; kwargs...) - return MetricECA(initial_metric) + return MetricECA(copy(initial_metric)) end function update!(meca::MetricECA, sigma, idx, args...; kwargs...) @@ -92,7 +92,7 @@ function MetricFisherDivergence( dims, metrics = size(initial_metric) om = OnlineMoments(T, dims, metrics) og = OnlineMoments(T, dims, metrics) - return MetricFisherDivergence(om, og, initial_metric) + return MetricFisherDivergence(om, og, copy(initial_metric)) end function update!( diff --git a/src/noise_adapter.jl b/src/noise_adapter.jl index 99349ad..3a0a9fe 100644 --- a/src/noise_adapter.jl +++ b/src/noise_adapter.jl @@ -18,7 +18,7 @@ struct NoiseECA{T<:AbstractFloat} <: AbstractNoiseAdapter{T} end function NoiseECA(initial_noise::AbstractVector{T}; kwargs...) where {T} - return NoiseECA(initial_noise, zero(initial_noise)) + return NoiseECA(copy(initial_noise), copy(initial_noise)) end function update!(neca::NoiseECA, damping, args...; kwargs...) @@ -46,7 +46,7 @@ struct NoiseConstant{T<:AbstractFloat} <: AbstractNoiseAdapter{T} end function NoiseConstant(initial_noise::AbstractVector; kwargs...) - return NoiseConstant(initial_noise, initial_noise) + return NoiseConstant(copy(initial_noise), copy(initial_noise)) end # function set!(sampler, nc::NoiseConstant, args...; kwargs...) end @@ -61,7 +61,7 @@ struct NoiseMALT{T<:AbstractFloat} <: AbstractNoiseAdapter{T} end function NoiseMALT(initial_noise::AbstractVector; kwargs...) - return NoiseMALT(initial_noise, initial_noise) + return NoiseMALT(copy(initial_noise), copy(initial_noise)) end function update!(nmalt::NoiseMALT, damping, stepsize, args...; noise_reduction_factor = 1e-2, kwargs...) diff --git a/src/onlinepca.jl b/src/onlinepca.jl index d63ad76..c399afa 100644 --- a/src/onlinepca.jl +++ b/src/onlinepca.jl @@ -62,7 +62,6 @@ function update!(opca::OnlinePCA, x::AbstractMatrix, location::AbstractMatrix, s T = eltype(x) u = Vector{T}(undef, dims) - f = Vector{T}(undef, dims) n = opca.n[1] l = opca.l @@ -70,9 +69,7 @@ function update!(opca::OnlinePCA, x::AbstractMatrix, location::AbstractMatrix, s n += 1 u .= (x[:, chain] .- location[:, metric]) ./ scale[:, metric] w = 1 / n - f .= (n - 1 - l) .* w .* opca.pc[:, pca] - f .+= (1 + l) .* w .* u .* (u' * opca.pc[:, pca]) ./ norm(opca.pc[:, pca]) - opca.pc[:, pca] .= f + opca.pc[:, pca] .= w .* ((n - 1 - l) .* opca.pc[:, pca] .+ (1 + l) .* u .* (u' * opca.pc[:, pca]) ./ norm(opca.pc[:, pca])) end opca.n[1] = n diff --git a/src/pca_adapter.jl b/src/pca_adapter.jl index 5dc2062..b92199f 100644 --- a/src/pca_adapter.jl +++ b/src/pca_adapter.jl @@ -23,7 +23,7 @@ end function PCAOnline(initial_pca::AbstractMatrix{T}; l = 2, kwargs...) where {T} dims, pcas = size(initial_pca) opca = OnlinePCA(T, dims, pcas, convert(T, l)::T) - return PCAOnline(opca, initial_pca) + return PCAOnline(opca, copy(initial_pca)) end PCAOnline(dims; kwargs...) = PCAOnline(Float64, dims; kwargs...) @@ -53,7 +53,7 @@ struct PCAConstant{T<:AbstractFloat} <: AbstractPCAAdapter{T} end function PCAConstant(initial_pca::AbstractMatrix, args...; kwargs...) - return PCAConstant(initial_pca) + return PCAConstant(copy(initial_pca)) end function update!(pca::PCAConstant, args...; kwargs...) end diff --git a/src/reductionfactor_adapter.jl b/src/reductionfactor_adapter.jl index 5f9239d..e425e96 100644 --- a/src/reductionfactor_adapter.jl +++ b/src/reductionfactor_adapter.jl @@ -19,7 +19,7 @@ end function ReductionFactorDualAverage(initial_reductionfactor::AbstractVector{T}, args...; reductionfactor_δ = 0.95, kwargs...) where {T} da = DualAverage(1, T; μ = -4) - return ReductionFactorDualAverage(da, initial_reductionfactor, initial_reductionfactor, fill(convert(T, reductionfactor_δ)::T, 1)) + return ReductionFactorDualAverage(da, copy(initial_reductionfactor), copy(initial_reductionfactor), fill(convert(T, reductionfactor_δ)::T, 1)) end function update!(sa::ReductionFactorDualAverage, α, args...; kwargs...) @@ -41,7 +41,7 @@ struct ReductionFactorConstant{T<:AbstractFloat} <: AbstractReductionFactorAdapt end function ReductionFactorConstant(initial_reductionfactor::AbstractVector, args...; kwargs...) - return ReductionFactorConstant(initial_reductionfactor, initial_reductionfactor) + return ReductionFactorConstant(copy(initial_reductionfactor), copy(initial_reductionfactor)) end function update!(sc::ReductionFactorConstant, args...; kwargs...) diff --git a/src/steps_adapter.jl b/src/steps_adapter.jl index a638cad..b1e3900 100644 --- a/src/steps_adapter.jl +++ b/src/steps_adapter.jl @@ -15,11 +15,11 @@ function update!(sa::StepsPCA, m, lambda_max, stepsize::AbstractVector, n, args. L = Iterators.cycle(1:length(lambda_max)) for (l, i) in zip(L, 1:length(stepsize)) step = lambda_max[l] / stepsize[i] - step = ifelse(isfinite(step), step, 10) - step = w * step + (1 - w) * 10 + step = ifelse(isfinite(step), step, sa.steps[i]) + # step = w * step + (1 - w) * 10 step = clamp(step, 1, max_steps) - step = round(Int, min(m, step)) - sa.steps[i] = step + step = min(m, step) + sa.steps[i] = round(Int, step) end end @@ -40,7 +40,7 @@ struct StepsConstant{T<:Integer} <: AbstractStepsAdapter{T} end function StepsConstant(initial_steps::AbstractVector, args...; kwargs...) - return StepsConstant(initial_steps) + return StepsConstant(copy(initial_steps)) end function update!(sc::StepsConstant, args...; kwargs...) @@ -55,8 +55,8 @@ struct StepsTrajectorylengthDualAverage{T<:AbstractFloat} <: AbstractStepsAdapte end function StepsTrajectorylengthDualAverage(initial_steps::AbstractVector{Int}, stepsize, args...; kwargs...) - tla = TrajectorylengthDualAverageLDG(stepsize) - return StepsTrajectorylengthDualAverage(initial_steps, tla) + tla = TrajectorylengthDualAverageLDG(copy(stepsize)) + return StepsTrajectorylengthDualAverage(copy(initial_steps), tla) end function update!(sa::StepsTrajectorylengthDualAverage, m, αs, previouspositions, proposedpositions, proposedmomentum, stepsize, ldg!, args...; max_steps = 1000, kwargs...) diff --git a/src/stepsize_adapter.jl b/src/stepsize_adapter.jl index 616f119..0b90916 100644 --- a/src/stepsize_adapter.jl +++ b/src/stepsize_adapter.jl @@ -29,8 +29,8 @@ function StepsizeAdam( adam = Adam(chains, warmup, T; kwargs...) return StepsizeAdam( adam, - initial_stepsize, - initial_stepsize, + copy(initial_stepsize), + copy(initial_stepsize), convert(T, δ)::T, convert(T, stepsize_smoothing_factor)::T, ) @@ -49,7 +49,7 @@ function update!(ssa::StepsizeAdam, abar, m, args...; stepsize_smooth=true, kwar end function reset!(ssa::StepsizeAdam, args...; kwargs...) - reset!(ssa.adam; initial_stepsize=ssa.stepsize, kwargs...) + reset!(ssa.adam; initial_stepsize=copy(ssa.stepsize), kwargs...) end struct StepsizeDualAverage{T<:AbstractFloat} <: AbstractStepsizeAdapter{T} @@ -70,8 +70,8 @@ function StepsizeDualAverage( chains = length(initial_stepsize) da = DualAverage(chains, T; kwargs...) return StepsizeDualAverage(da, - initial_stepsize, - initial_stepsize, + copy(initial_stepsize), + copy(initial_stepsize), fill(convert(T, δ)::T, 1) ) end @@ -99,7 +99,7 @@ function StepsizeConstant( initial_stepsize::AbstractVector{T}; kwargs... ) where {T<:AbstractFloat} chains = length(initial_stepsize) - return StepsizeConstant(initial_stepsize, initial_stepsize) + return StepsizeConstant(copy(initial_stepsize), copy(initial_stepsize)) end function update!(ssc::StepsizeConstant, args...; kwargs...) end @@ -119,7 +119,7 @@ end function StepsizeECA( initial_stepsize::AbstractVector{T}; kwargs... ) where {T<:AbstractFloat} - return StepsizeECA(initial_stepsize, initial_stepsize) + return StepsizeECA(copy(initial_stepsize), copy(initial_stepsize)) end function update!(seca::StepsizeECA, ldg!, positions, scale, idx, args...; kwargs...) @@ -156,7 +156,7 @@ function StepsizeGradientPCA( opca = OnlinePCA(T, dims, 1, convert(T, l)::T) om = OnlineMoments(T, dims, 1) ssda = StepsizeDualAverage(0.5 * ones(T, 1); δ = 0.8) - return StepsizeGradientPCA(initial_stepsize, initial_stepsize, opca, om, ssda, stepsize_smoothing_factor) + return StepsizeGradientPCA(copy(initial_stepsize), copy(initial_stepsize), opca, om, ssda, stepsize_smoothing_factor) end function update!(ssg::StepsizeGradientPCA, αs, positions, ldg!, scale, args...; stepsize_factor = 0.5, stepsize_smooth=true, kwargs...) diff --git a/src/stepsize_initializer.jl b/src/stepsize_initializer.jl index 8019ac9..7244f39 100644 --- a/src/stepsize_initializer.jl +++ b/src/stepsize_initializer.jl @@ -90,7 +90,7 @@ function stan_init_stepsize(stepsize, metric, rng, ldg!, position; kwargs...) elseif direction == -1 && !(ΔH < dh) break else - stepsize = direction == 1 ? 2 * stepsize : stepsize / 2 + stepsize = direction == 1 ? 2 * stepsize : 0.5 * stepsize end if stepsize > 1e7 diff --git a/src/tools.jl b/src/tools.jl index f6fd05c..5a0fab8 100644 --- a/src/tools.jl +++ b/src/tools.jl @@ -9,8 +9,11 @@ function drhmc!( stepsize, steps, noise, + drift, + acceptance_probability, J, reduction_factor, + nonreversible_update, maxdeltaH; kwargs..., ) @@ -70,8 +73,7 @@ function drhmc!( ptries[j + 1] = 1 - avec[j + 1] end - # TODO non-reversible update - accepted = rand(rng, T) < avec[j + 1] + accepted = abs(acceptance_probability[]) < avec[j + 1] if accepted jf = j break @@ -81,11 +83,18 @@ function drhmc!( if accepted position_next .= qj momentum .= pj + acceptance_probability[] *= -avec[jf + 1] else position_next .= position momentum .*= -1 end + acceptance_probability[] = if nonreversible_update + (acceptance_probability[] + 1 + drift) % 2 - 1 + else + rand(rng, T) + end + return (; accepted, divergent, diff --git a/src/trace.jl b/src/trace.jl index 7ae8e4b..6417a40 100644 --- a/src/trace.jl +++ b/src/trace.jl @@ -246,11 +246,6 @@ function trace(sampler::DrMALA{T}, iterations) where {T} damping=zeros(T, iterations, dims, chains), noise=zeros(T, iterations, dims, chains), ld=zeros(T, iterations, chains), - # previousmomentum=zeros(T, dims, chains), - # momentum=zeros(T, dims, chains), - # position=zeros(T, dims, chains), - pca=zeros(T, iterations, dims, chains), - previousposition=zeros(T, dims, chains), retries=zeros(Int, 3, iterations, chains), reductionfactor=zeros(T, iterations) ) @@ -273,12 +268,7 @@ function record!(sampler::DrMALA{T}, trace::NamedTuple, info, iteration, chain) end end trace[:reductionfactor][iteration] = info[:reductionfactor] - # trace[:previousmomentum] .= trace[:momentum] trace[:noise][iteration, :, chain] .= info[:noise] trace[:damping][iteration, :, chain] .= info[:damping] - # trace[:previousposition][:, chain] .= info[:previousposition] - # trace[:momentum] .= info[:momentum] - # trace[:position] .= info[:position] - trace[:pca][iteration, :, chain] .= info[:pca] trace[:retries][info[:retries], iteration, chain] += 1 end