From 9fe91d2ee8fae93fb524bec554b0e18e2deaa3f3 Mon Sep 17 00:00:00 2001 From: thorek1 Date: Fri, 27 Oct 2023 09:05:47 +0100 Subject: [PATCH] fix kalman filter pullback and estimation test --- src/MacroModelling.jl | 2 +- test/test_estimation.jl | 22 +++++++++++++++++++++- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/src/MacroModelling.jl b/src/MacroModelling.jl index 8bdc557d..bbe851db 100644 --- a/src/MacroModelling.jl +++ b/src/MacroModelling.jl @@ -5767,7 +5767,7 @@ function calculate_kalman_filter_loglikelihood(𝓂::ℳ, data::AbstractArray{Fl if Fdet < eps() return -Inf end - F̄ = RF.lu(F, check = false) + F̄ = ℒ.lu(F, check = false) if !ℒ.issuccess(F̄) return -Inf end diff --git a/test/test_estimation.jl b/test/test_estimation.jl index 6c9679ef..1d47ceaf 100644 --- a/test/test_estimation.jl +++ b/test/test_estimation.jl @@ -1,9 +1,11 @@ using MacroModelling import Turing -import Turing: NUTS, sample, logpdf +import Turing: NUTS, sample, logpdf#, SMC, PG, ESS +# import AdvancedPS import Optim, LineSearches using Random, CSV, DataFrames, MCMCChains, AxisKeys import DynamicPPL: logjoint +# using Pigeons include("models/FS2000.jl") @@ -41,11 +43,29 @@ FS2000_loglikelihood = FS2000_loglikelihood_function(data, FS2000, observables) n_samples = 1000 + +# pt = pigeons(target = TuringLogPotential(FS2000_loglikelihood_function(data, FS2000, observables)), +# # record = [traces; record_default()], +# record = [traces; round_trip; record_default()], +# n_rounds = 7, +# n_chains = 10, +# multithreaded = true, +# show_report = true)#,explorer = AAPS()); + +# samples = Chains(sample_array(pt), variable_names(pt)) + +# import StatsPlots +# StatsPlots.plot(samples) + # using Zygote # Turing.setadbackend(:zygote) +# sampsSMC = sample(FS2000_loglikelihood, SMC( AdvancedPS.resample_systematic), n_samples, progress = true)#, init_params = sol) +# sampsSMC = sample(FS2000_loglikelihood, SMC(1000), n_samples, progress = true)#, init_params = sol) samps = sample(FS2000_loglikelihood, NUTS(), n_samples, progress = true)#, init_params = sol) # println(mean(samps).nt.mean) +# using StatsPlots +# StatsPlots.plot(sampsSMC) Random.seed!(30)