Skip to content

Commit

Permalink
previous test_estimation script
Browse files Browse the repository at this point in the history
  • Loading branch information
thorek1 committed Dec 18, 2024
1 parent 06d2e1e commit 33e3829
Showing 1 changed file with 8 additions and 24 deletions.
32 changes: 8 additions & 24 deletions test/test_estimation.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using Revise
using MacroModelling
import Turing, Pigeons, Zygote
import Turing: NUTS, sample, logpdf, AutoZygote
Expand All @@ -9,7 +8,7 @@ import DynamicPPL
include("../models/FS2000.jl")

# load data
dat = CSV.read("test/data/FS2000_data.csv", DataFrame)
dat = CSV.read("data/FS2000_data.csv", DataFrame)
data = KeyedArray(Array(dat)',Variable = Symbol.("log_".*names(dat)),Time = 1:size(dat)[1])
data = log.(data)

Expand Down Expand Up @@ -37,7 +36,7 @@ Turing.@model function FS2000_loglikelihood_function(data, m)
all_params ~ Turing.arraydist(dists)

if DynamicPPL.leafcontext(__context__) !== DynamicPPL.PriorContext()
Turing.@addlogprob! get_loglikelihood(m, data, all_params, verbose = false)
Turing.@addlogprob! get_loglikelihood(m, data, all_params)
end
end

Expand All @@ -48,15 +47,15 @@ n_samples = 1000

# using Zygote
# Turing.setadbackend(:zygote)
# samps = @time sample(FS2000_loglikelihood, NUTS(), n_samples, progress = true, initial_params = FS2000.parameter_values)
samps = @time sample(FS2000_loglikelihood, NUTS(), n_samples, progress = true, initial_params = FS2000.parameter_values)

# println("Mean variable values (ForwardDiff): $(mean(samps).nt.mean)")
println("Mean variable values (ForwardDiff): $(mean(samps).nt.mean)")

# samps = @time sample(FS2000_loglikelihood, NUTS(adtype = Turing.AutoZygote()), n_samples, progress = true, initial_params = FS2000.parameter_values)
samps = @time sample(FS2000_loglikelihood, NUTS(adtype = Turing.AutoZygote()), n_samples, progress = true, initial_params = FS2000.parameter_values)

# println("Mean variable values (Zygote): $(mean(samps).nt.mean)")
println("Mean variable values (Zygote): $(mean(samps).nt.mean)")

# sample_nuts = mean(samps).nt.mean
sample_nuts = mean(samps).nt.mean


# generate a Pigeons log potential
Expand Down Expand Up @@ -84,27 +83,12 @@ end
# define a specific initialization for this model
Pigeons.initialization(::Pigeons.TuringLogPotential{typeof(FS2000_loglikelihood_function)}, ::AbstractRNG, ::Int64) = deepcopy(XMAX)

Pigeons.pigeons(target = FS2000_lp,
pt = @time Pigeons.pigeons(target = FS2000_lp,
record = [Pigeons.traces; Pigeons.round_trip; Pigeons.record_default()],
n_chains = 1,
n_rounds = 10,
multithreaded = true)

pt = @profview Pigeons.pigeons(target = FS2000_lp,
record = [Pigeons.traces; Pigeons.round_trip; Pigeons.record_default()],
n_chains = 1,
n_rounds = 1,
multithreaded = false)
# ────────────────────────────────────────────────────────────────────────────
# scans restarts time(s) allc(B) log(Z₁/Z₀) min(αₑ) mean(αₑ)
# ────────── ────────── ────────── ────────── ────────── ────────── ──────────
# 2 0 4.8 3.21e+09 0 0.982 0.982
# 4 0 7.17 4.99e+09 0 1 1
# 8 0 12.5 8.73e+09 0 1 1
# 16 0 28.5 1.97e+10 0 1 1
# 32 0 56.2 3.89e+10 0 1 1
# 64 0 109 7.47e+10 0 1 1
# 128 0 230 1.57e+11 0 1 1
samps = MCMCChains.Chains(pt)

println("Mean variable values (Pigeons): $(mean(samps).nt.mean)")
Expand Down

0 comments on commit 33e3829

Please sign in to comment.