Skip to content

Commit

Permalink
prep DrMALA for run on Rusty
Browse files Browse the repository at this point in the history
  • Loading branch information
roualdes committed Jul 24, 2023
1 parent 20b3e7c commit c605912
Show file tree
Hide file tree
Showing 13 changed files with 78 additions and 59 deletions.
6 changes: 3 additions & 3 deletions src/damping_adapter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ struct DampingECA{T<:AbstractFloat} <: AbstractDampingAdapter{T}
end

"""
    DampingECA(initial_damping::AbstractVector; kwargs...)

Construct a `DampingECA` adapter seeded with `initial_damping`.

Both arguments passed to the two-argument `DampingECA` method are independent
copies of `initial_damping`, so the adapter does not share storage with the
caller's vector. `kwargs` are accepted and ignored.
"""
function DampingECA(initial_damping::AbstractVector{T}; kwargs...) where {T}
    # Copy twice: each slot of the adapter gets its own buffer.
    return DampingECA(copy(initial_damping), copy(initial_damping))
end

function update!(deca::DampingECA, m, zpositions, stepsize, idx, args...; kwargs...)
Expand All @@ -41,7 +41,7 @@ struct DampingConstant{T<:AbstractFloat} <: AbstractDampingAdapter{T}
end

"""
    DampingConstant(initial_damping::AbstractVector; kwargs...)

Construct a `DampingConstant` adapter seeded with `initial_damping`.

The two arguments forwarded to the two-argument `DampingConstant` method are
separate copies of `initial_damping`; nothing is shared with the caller or between
the two slots. `kwargs` are accepted and ignored.
"""
function DampingConstant(initial_damping::AbstractVector; kwargs...)
    return DampingConstant(copy(initial_damping), copy(initial_damping))
end

"""
    update!(dc::DampingConstant, args...; kwargs...)

No-op: every argument is ignored and `nothing` is returned.
"""
function update!(dc::DampingConstant, args...; kwargs...)
    return nothing
end
Expand All @@ -54,7 +54,7 @@ struct DampingMALT{T<:AbstractFloat} <: AbstractDampingAdapter{T}
end

"""
    DampingMALT(initial_damping::AbstractVector, args...; kwargs...)

Construct a `DampingMALT` adapter seeded with `initial_damping`.

Two independent copies of `initial_damping` are forwarded to the two-argument
`DampingMALT` method, so the adapter never aliases the caller's vector. Extra
positional `args` and `kwargs` are accepted and ignored.
"""
function DampingMALT(initial_damping::AbstractVector, args...; kwargs...)
    return DampingMALT(copy(initial_damping), copy(initial_damping))
end

function update!(dmalt::DampingMALT, m, gamma, stepsize, args...; damping_coefficient=1, kwargs...)
Expand Down
15 changes: 12 additions & 3 deletions src/draws_initializer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ struct DrawsInitializerStan end

function initialize_draws!(initializer::DrawsInitializerStan, draws, rngs, ldg!; kwargs...)
for chain in axes(draws, 3)
@views draws[1, :, chain] = stan_initialize_draw(
draws[1, :, chain] .= stan_initialize_draw(
draws[1, :, chain], ldg!, rngs[chain]; kwargs...
)
end
Expand All @@ -47,6 +47,7 @@ function stan_initialize_draw(position, ldg!, rng; radius=2, attempts=100, kwarg
dims = length(position)
q = copy(position)
gradient = similar(q)
momentum = similar(q)

while a < attempts && !initialized
q .= radius .* (2 .* rand(rng, T, dims) .- 1)
Expand All @@ -57,10 +58,18 @@ function stan_initialize_draw(position, ldg!, rng; radius=2, attempts=100, kwarg
end

g = sum(gradient)
if isfinite(g) && !isnan(g)
initialized &= true
if !isfinite(g) || isnan(g)
initialized = false
continue
end

# randn!(momentum)
# ld = leapfrog!(q, momentum, ldg!, gradient, 1, 1)

# if !isfinite(ld) || isnan(ld) || any(isnan.(q)) || any(isnan.(momentum))
# initialized = false
# end

a += 1
end

Expand Down
34 changes: 24 additions & 10 deletions src/drmala.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ struct DrMALA{T} <: AbstractDrMALA{T}
reductionfactor::Vector{T}
damping::Vector{T}
noise::Vector{T}
drift::Vector{T}
acceptanceprob::Vector{T}
dims::Int
chains::Int
end
Expand All @@ -33,8 +35,10 @@ function DrMALA(
pca ./= mapslices(norm, pca, dims = 1)
damping = ones(T, 1)
noise = exp.(-2 .* damping .* stepsize)
drift = (1 .- noise .^ 2) ./ 2
acceptanceprob = 2 .* rand(T, chains) .- 1
return DrMALA(
momentum, metric, pca, steps, stepsize, reductionfactor, damping, noise, D, chains
momentum, metric, pca, steps, stepsize, reductionfactor, damping, noise, drift, acceptanceprob, D, chains
)
end

Expand All @@ -45,7 +49,7 @@ function sample!(
warmup=iterations,
draws_initializer=DrawsInitializerStan(),
stepsize_initializer=StepsizeInitializerStan(),
steps_adapter=StepsPCA(sampler.steps),
steps_adapter=StepsPCA(copy(sampler.steps)),
stepsize_adapter=StepsizeDualAverage(sampler.stepsize; δ=0.5),
reductionfactor_adapter=ReductionFactorDualAverage(sampler.reductionfactor; reductionfactor_δ = 0.9),
metric_adapter=MetricOnlineMoments(sampler.metric),
Expand Down Expand Up @@ -75,19 +79,25 @@ function sample!(
end

function transition!(sampler::DrMALA, m, ldg!, draws, rngs, trace;
J = 3, kwargs...)
J = 3, nonreversible_update = false, malt = true,
kwargs...)
nt = get(kwargs, :threads, Threads.nthreads())
chains = size(draws, 3)
pca = sampler.pca ./ mapslices(norm, sampler.pca, dims = 1)
reduction_factor = sampler.reductionfactor[1]
if malt
randn!(sampler.momentum)
end
damping = sampler.damping[1]
metric = sampler.metric[:, 1]
metric ./= maximum(metric)

Threads.@threads for it in 1:nt
for chain in it:nt:chains
stepsize = sampler.stepsize[chain]
steps = sampler.steps[chain]
metric = sampler.metric[:, 1]
metric ./= maximum(metric)
noise = sampler.noise[chain]
damping = sampler.damping[1]
drift = sampler.drift[chain]
acceptanceprob = sampler.acceptanceprob[chain:chain]

local info
acceptstats = zeros(steps)
Expand All @@ -106,8 +116,11 @@ function transition!(sampler::DrMALA, m, ldg!, draws, rngs, trace;
stepsize,
1,
noise,
drift,
acceptanceprob,
J,
reduction_factor,
nonreversible_update,
1000;
kwargs...,
)
Expand All @@ -119,8 +132,6 @@ function transition!(sampler::DrMALA, m, ldg!, draws, rngs, trace;
info = (; info...,
damping,
reductionfactor = reduction_factor,
pca = pca[:, 1],
# previousposition = draws[m, :, chain],
acceptstat = mean(acceptstats),
finalacceptstat = maybe_mean(finalacceptstats),
steps = sum(retried))
Expand All @@ -147,6 +158,7 @@ function adapt!(
damping_adapter,
noise_adapter,
drift_adapter;
steps_coefficient = 1,
kwargs...,
)

Expand All @@ -158,10 +170,12 @@ function adapt!(

update!(noise_adapter, sampler.damping, sampler.stepsize; kwargs...)
set!(sampler, noise_adapter; kwargs...)
sampler.drift .= (1 .- sampler.noise .^ 2) ./ 2

if schedule.firstwindow <= m <= schedule.lastwindow
final_accept_stats = trace.finalacceptstat[m + 1, :]
update!(reductionfactor_adapter, maybe_mean(final_accept_stats); kwargs...)
set!(sampler, reductionfactor_adapter; kwargs...)

positions = draws[m + 1, :, :]

Expand All @@ -175,7 +189,7 @@ function adapt!(
)

lambda = sqrt.(lambda_max(pca_adapter))
update!(steps_adapter, m + 1, lambda, sampler.stepsize, pca_adapter.opca.n[1]; kwargs...)
update!(steps_adapter, m + 1, steps_coefficient * lambda, sampler.stepsize, pca_adapter.opca.n[1]; kwargs...)

update!(damping_adapter, m + 1, lambda, sampler.stepsize; kwargs...)
end
Expand Down
8 changes: 4 additions & 4 deletions src/metric_adapter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ function MetricOnlineMoments(
) where {T}
dims, metrics = size(initial_metric)
om = OnlineMoments(T, dims, metrics)
return MetricOnlineMoments(om, initial_metric)
return MetricOnlineMoments(om, copy(initial_metric))
end

function update!(
Expand Down Expand Up @@ -51,7 +51,7 @@ struct MetricConstant{T<:AbstractFloat} <: AbstractMetricAdapter{T}
end

"""
    MetricConstant(initial_metric::AbstractMatrix, args...; kwargs...)

Construct a `MetricConstant` adapter holding a copy of `initial_metric`.

The matrix is copied before being forwarded to the one-argument `MetricConstant`
method, so the adapter does not share storage with the caller. Extra positional
`args` and `kwargs` are accepted and ignored.
"""
function MetricConstant(initial_metric::AbstractMatrix, args...; kwargs...)
    return MetricConstant(copy(initial_metric))
end

function metric_mean(mc::MetricConstant, args...; kwargs...)
Expand All @@ -69,7 +69,7 @@ struct MetricECA{T<:AbstractFloat} <: AbstractMetricAdapter{T}
end

"""
    MetricECA(initial_metric::AbstractMatrix, args...; kwargs...)

Construct a `MetricECA` adapter holding a copy of `initial_metric`.

The matrix is copied before being forwarded to the one-argument `MetricECA`
method, so the adapter does not share storage with the caller. Extra positional
`args` and `kwargs` are accepted and ignored.
"""
function MetricECA(initial_metric::AbstractMatrix, args...; kwargs...)
    return MetricECA(copy(initial_metric))
end

function update!(meca::MetricECA, sigma, idx, args...; kwargs...)
Expand All @@ -92,7 +92,7 @@ function MetricFisherDivergence(
dims, metrics = size(initial_metric)
om = OnlineMoments(T, dims, metrics)
og = OnlineMoments(T, dims, metrics)
return MetricFisherDivergence(om, og, initial_metric)
return MetricFisherDivergence(om, og, copy(initial_metric))
end

function update!(
Expand Down
6 changes: 3 additions & 3 deletions src/noise_adapter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ struct NoiseECA{T<:AbstractFloat} <: AbstractNoiseAdapter{T}
end

"""
    NoiseECA(initial_noise::AbstractVector; kwargs...)

Construct a `NoiseECA` adapter seeded with `initial_noise`.

Both arguments passed to the two-argument `NoiseECA` method are independent copies
of `initial_noise`, so the adapter does not share storage with the caller's vector.
`kwargs` are accepted and ignored.
"""
function NoiseECA(initial_noise::AbstractVector{T}; kwargs...) where {T}
    # Copy twice: each slot of the adapter gets its own buffer.
    return NoiseECA(copy(initial_noise), copy(initial_noise))
end

function update!(neca::NoiseECA, damping, args...; kwargs...)
Expand Down Expand Up @@ -46,7 +46,7 @@ struct NoiseConstant{T<:AbstractFloat} <: AbstractNoiseAdapter{T}
end

"""
    NoiseConstant(initial_noise::AbstractVector; kwargs...)

Construct a `NoiseConstant` adapter seeded with `initial_noise`.

Two separate copies of `initial_noise` are forwarded to the two-argument
`NoiseConstant` method; nothing is shared with the caller or between the two slots.
`kwargs` are accepted and ignored.
"""
function NoiseConstant(initial_noise::AbstractVector; kwargs...)
    return NoiseConstant(copy(initial_noise), copy(initial_noise))
end

# function set!(sampler, nc::NoiseConstant, args...; kwargs...) end
Expand All @@ -61,7 +61,7 @@ struct NoiseMALT{T<:AbstractFloat} <: AbstractNoiseAdapter{T}
end

"""
    NoiseMALT(initial_noise::AbstractVector; kwargs...)

Construct a `NoiseMALT` adapter seeded with `initial_noise`.

Two independent copies of `initial_noise` are forwarded to the two-argument
`NoiseMALT` method, so the adapter never aliases the caller's vector. `kwargs` are
accepted and ignored.
"""
function NoiseMALT(initial_noise::AbstractVector; kwargs...)
    return NoiseMALT(copy(initial_noise), copy(initial_noise))
end

function update!(nmalt::NoiseMALT, damping, stepsize, args...; noise_reduction_factor = 1e-2, kwargs...)
Expand Down
5 changes: 1 addition & 4 deletions src/onlinepca.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,17 +62,14 @@ function update!(opca::OnlinePCA, x::AbstractMatrix, location::AbstractMatrix, s

T = eltype(x)
u = Vector{T}(undef, dims)
f = Vector{T}(undef, dims)

n = opca.n[1]
l = opca.l
@views for (pca, metric, chain) in zip(Iterators.cycle(1:pcas), Iterators.cycle(1:metrics), 1:chains)
n += 1
u .= (x[:, chain] .- location[:, metric]) ./ scale[:, metric]
w = 1 / n
f .= (n - 1 - l) .* w .* opca.pc[:, pca]
f .+= (1 + l) .* w .* u .* (u' * opca.pc[:, pca]) ./ norm(opca.pc[:, pca])
opca.pc[:, pca] .= f
opca.pc[:, pca] .= w .* ((n - 1 - l) .* opca.pc[:, pca] .+ (1 + l) .* u .* (u' * opca.pc[:, pca]) ./ norm(opca.pc[:, pca]))
end

opca.n[1] = n
Expand Down
4 changes: 2 additions & 2 deletions src/pca_adapter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ end
"""
    PCAOnline(initial_pca::AbstractMatrix{T}; l = 2, kwargs...)

Construct a `PCAOnline` adapter from the matrix `initial_pca`.

An `OnlinePCA(T, dims, pcas, l)` state is created with `dims, pcas =
size(initial_pca)` and `l` converted to the element type `T`. The adapter receives
a copy of `initial_pca`, so it does not share storage with the caller. Remaining
`kwargs` are accepted and ignored.
"""
function PCAOnline(initial_pca::AbstractMatrix{T}; l = 2, kwargs...) where {T}
    dims, pcas = size(initial_pca)
    # `convert(T, l)::T` pins the element type of the OnlinePCA parameter.
    opca = OnlinePCA(T, dims, pcas, convert(T, l)::T)
    return PCAOnline(opca, copy(initial_pca))
end

"""
    PCAOnline(dims; kwargs...)

Convenience method: forward to `PCAOnline(Float64, dims; kwargs...)`, i.e. use
`Float64` when no element type is given.
"""
function PCAOnline(dims; kwargs...)
    return PCAOnline(Float64, dims; kwargs...)
end
Expand Down Expand Up @@ -53,7 +53,7 @@ struct PCAConstant{T<:AbstractFloat} <: AbstractPCAAdapter{T}
end

"""
    PCAConstant(initial_pca::AbstractMatrix, args...; kwargs...)

Construct a `PCAConstant` adapter holding a copy of `initial_pca`.

The matrix is copied before being forwarded to the one-argument `PCAConstant`
method, so the adapter does not share storage with the caller. Extra positional
`args` and `kwargs` are accepted and ignored.
"""
function PCAConstant(initial_pca::AbstractMatrix, args...; kwargs...)
    return PCAConstant(copy(initial_pca))
end

"""
    update!(pca::PCAConstant, args...; kwargs...)

No-op: every argument is ignored and `nothing` is returned.
"""
function update!(pca::PCAConstant, args...; kwargs...)
    return nothing
end
Expand Down
4 changes: 2 additions & 2 deletions src/reductionfactor_adapter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ end

"""
    ReductionFactorDualAverage(initial_reductionfactor::AbstractVector{T}, args...;
                               reductionfactor_δ = 0.95, kwargs...)

Construct a `ReductionFactorDualAverage` adapter seeded with
`initial_reductionfactor`.

A `DualAverage(1, T; μ = -4)` state is created, and two independent copies of
`initial_reductionfactor` are forwarded alongside it so the adapter never aliases
the caller's vector. `reductionfactor_δ` is converted to `T` and stored in a
one-element vector (presumably the dual-averaging target; confirm against
`update!`). Extra positional `args` and remaining `kwargs` are ignored.
"""
function ReductionFactorDualAverage(
    initial_reductionfactor::AbstractVector{T},
    args...;
    reductionfactor_δ = 0.95,
    kwargs...,
) where {T}
    da = DualAverage(1, T; μ = -4)
    δ = fill(convert(T, reductionfactor_δ)::T, 1)
    return ReductionFactorDualAverage(
        da, copy(initial_reductionfactor), copy(initial_reductionfactor), δ
    )
end

function update!(sa::ReductionFactorDualAverage, α, args...; kwargs...)
Expand All @@ -41,7 +41,7 @@ struct ReductionFactorConstant{T<:AbstractFloat} <: AbstractReductionFactorAdapt
end

"""
    ReductionFactorConstant(initial_reductionfactor::AbstractVector, args...; kwargs...)

Construct a `ReductionFactorConstant` adapter seeded with
`initial_reductionfactor`.

Two separate copies of the input are forwarded to the two-argument
`ReductionFactorConstant` method; nothing is shared with the caller or between the
two slots. Extra positional `args` and `kwargs` are accepted and ignored.
"""
function ReductionFactorConstant(
    initial_reductionfactor::AbstractVector, args...; kwargs...
)
    return ReductionFactorConstant(
        copy(initial_reductionfactor), copy(initial_reductionfactor)
    )
end

function update!(sc::ReductionFactorConstant, args...; kwargs...)
Expand Down
14 changes: 7 additions & 7 deletions src/steps_adapter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ function update!(sa::StepsPCA, m, lambda_max, stepsize::AbstractVector, n, args.
L = Iterators.cycle(1:length(lambda_max))
for (l, i) in zip(L, 1:length(stepsize))
step = lambda_max[l] / stepsize[i]
step = ifelse(isfinite(step), step, 10)
step = w * step + (1 - w) * 10
step = ifelse(isfinite(step), step, sa.steps[i])
# step = w * step + (1 - w) * 10
step = clamp(step, 1, max_steps)
step = round(Int, min(m, step))
sa.steps[i] = step
step = min(m, step)
sa.steps[i] = round(Int, step)
end
end

Expand All @@ -40,7 +40,7 @@ struct StepsConstant{T<:Integer} <: AbstractStepsAdapter{T}
end

"""
    StepsConstant(initial_steps::AbstractVector, args...; kwargs...)

Construct a `StepsConstant` adapter holding a copy of `initial_steps`.

The vector is copied before being forwarded to the one-argument `StepsConstant`
method, so the adapter does not share storage with the caller. Extra positional
`args` and `kwargs` are accepted and ignored.
"""
function StepsConstant(initial_steps::AbstractVector, args...; kwargs...)
    return StepsConstant(copy(initial_steps))
end

function update!(sc::StepsConstant, args...; kwargs...)
Expand All @@ -55,8 +55,8 @@ struct StepsTrajectorylengthDualAverage{T<:AbstractFloat} <: AbstractStepsAdapte
end

"""
    StepsTrajectorylengthDualAverage(initial_steps::AbstractVector{Int}, stepsize,
                                     args...; kwargs...)

Construct a `StepsTrajectorylengthDualAverage` adapter.

A `TrajectorylengthDualAverageLDG` is built from a copy of `stepsize`, and the
adapter receives a copy of `initial_steps`, so neither input is aliased by the
returned object. Extra positional `args` and `kwargs` are accepted and ignored.
"""
function StepsTrajectorylengthDualAverage(
    initial_steps::AbstractVector{Int}, stepsize, args...; kwargs...
)
    tla = TrajectorylengthDualAverageLDG(copy(stepsize))
    return StepsTrajectorylengthDualAverage(copy(initial_steps), tla)
end

function update!(sa::StepsTrajectorylengthDualAverage, m, αs, previouspositions, proposedpositions, proposedmomentum, stepsize, ldg!, args...; max_steps = 1000, kwargs...)
Expand Down
16 changes: 8 additions & 8 deletions src/stepsize_adapter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ function StepsizeAdam(
adam = Adam(chains, warmup, T; kwargs...)
return StepsizeAdam(
adam,
initial_stepsize,
initial_stepsize,
copy(initial_stepsize),
copy(initial_stepsize),
convert(T, δ)::T,
convert(T, stepsize_smoothing_factor)::T,
)
Expand All @@ -49,7 +49,7 @@ function update!(ssa::StepsizeAdam, abar, m, args...; stepsize_smooth=true, kwar
end

"""
    reset!(ssa::StepsizeAdam, args...; kwargs...)

Reset the wrapped `ssa.adam` state exactly once.

A copy of `ssa.stepsize` is passed as `initial_stepsize`, so the optimizer cannot
end up sharing storage with the adapter's step-size vector. `kwargs` are forwarded
to the inner `reset!`; extra positional `args` are ignored. Returns whatever the
inner `reset!` returns.
"""
function reset!(ssa::StepsizeAdam, args...; kwargs...)
    return reset!(ssa.adam; initial_stepsize = copy(ssa.stepsize), kwargs...)
end

struct StepsizeDualAverage{T<:AbstractFloat} <: AbstractStepsizeAdapter{T}
Expand All @@ -70,8 +70,8 @@ function StepsizeDualAverage(
chains = length(initial_stepsize)
da = DualAverage(chains, T; kwargs...)
return StepsizeDualAverage(da,
initial_stepsize,
initial_stepsize,
copy(initial_stepsize),
copy(initial_stepsize),
fill(convert(T, δ)::T, 1)
)
end
Expand Down Expand Up @@ -99,7 +99,7 @@ function StepsizeConstant(
initial_stepsize::AbstractVector{T}; kwargs...
) where {T<:AbstractFloat}
chains = length(initial_stepsize)
return StepsizeConstant(initial_stepsize, initial_stepsize)
return StepsizeConstant(copy(initial_stepsize), copy(initial_stepsize))
end

"""
    update!(ssc::StepsizeConstant, args...; kwargs...)

No-op: every argument is ignored and `nothing` is returned.
"""
function update!(ssc::StepsizeConstant, args...; kwargs...)
    return nothing
end
Expand All @@ -119,7 +119,7 @@ end
"""
    StepsizeECA(initial_stepsize::AbstractVector{<:AbstractFloat}; kwargs...)

Construct a `StepsizeECA` adapter seeded with `initial_stepsize`.

Two independent copies of `initial_stepsize` are forwarded to the two-argument
`StepsizeECA` method, so the adapter never aliases the caller's vector. `kwargs`
are accepted and ignored.
"""
function StepsizeECA(
    initial_stepsize::AbstractVector{T}; kwargs...
) where {T<:AbstractFloat}
    return StepsizeECA(copy(initial_stepsize), copy(initial_stepsize))
end

function update!(seca::StepsizeECA, ldg!, positions, scale, idx, args...; kwargs...)
Expand Down Expand Up @@ -156,7 +156,7 @@ function StepsizeGradientPCA(
opca = OnlinePCA(T, dims, 1, convert(T, l)::T)
om = OnlineMoments(T, dims, 1)
ssda = StepsizeDualAverage(0.5 * ones(T, 1); δ = 0.8)
return StepsizeGradientPCA(initial_stepsize, initial_stepsize, opca, om, ssda, stepsize_smoothing_factor)
return StepsizeGradientPCA(copy(initial_stepsize), copy(initial_stepsize), opca, om, ssda, stepsize_smoothing_factor)
end

function update!(ssg::StepsizeGradientPCA, αs, positions, ldg!, scale, args...; stepsize_factor = 0.5, stepsize_smooth=true, kwargs...)
Expand Down
2 changes: 1 addition & 1 deletion src/stepsize_initializer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ function stan_init_stepsize(stepsize, metric, rng, ldg!, position; kwargs...)
elseif direction == -1 && !(ΔH < dh)
break
else
stepsize = direction == 1 ? 2 * stepsize : stepsize / 2
stepsize = direction == 1 ? 2 * stepsize : 0.5 * stepsize
end

if stepsize > 1e7
Expand Down
Loading

0 comments on commit c605912

Please sign in to comment.