diff --git a/test/neural_net_solution.jl b/test/neural_net_solution.jl
index c99e090d..0f8355ec 100644
--- a/test/neural_net_solution.jl
+++ b/test/neural_net_solution.jl
@@ -59,8 +59,13 @@
 n_simul = n_batches * nn_params ÷ (n_vars * 10)
 n_burnin = 500
 scheduler_period = 15000
-s = ParameterSchedulers.Stateful(CosAnneal(.001, 1e-8, scheduler_period))
-# s = ParameterSchedulers.Stateful(SinDecay2(.001, 1e-6, 500))
+# s = ParameterSchedulers.Stateful(CosAnneal(.001, 1e-8, scheduler_period))
+s = ParameterSchedulers.Stateful(Sequence([ CosAnneal(.001, 1e-5, 5000),
+                                            Exp(start = 1e-5, decay = .9995),
+                                            Exp(start = 1e-6, decay = .999)],
+                                          [scheduler_period ÷ 3, scheduler_period ÷ 3, scheduler_period ÷ 3]))
+
+
 shcks = randn(n_shocks, n_burnin + n_simul)
 
 
@@ -87,6 +92,32 @@ train_loader = Flux.DataLoader((outputs, inputs), batchsize = n_simul ÷ n_batch
 losses = []
 # Training loop
 for epoch in 1:scheduler_period
+    # if (epoch ≥ 5000 &&
+    #     epoch ≤ 10000 &&
+    if epoch % 500 == 0#)
+        shcks = randn(n_shocks, n_burnin + n_simul)
+
+        sims = get_irf(Smets_Wouters_2007, shocks = shcks, periods = 0, levels = true)
+
+        if normalise
+            mn = get_mean(Smets_Wouters_2007, derivatives = false)
+
+            stddev = get_std(Smets_Wouters_2007, derivatives = false)
+
+            normalised_sims = collect((sims[:,n_burnin:end,1] .- mn) ./ stddev)
+
+            inputs = Float32.(vcat(normalised_sims[:,1:end - 1], shcks[:,n_burnin + 1:n_burnin + n_simul]))
+
+            outputs = Float32.(normalised_sims[:,2:end])
+        else
+            inputs = Float32.(vcat(collect(sims[:,n_burnin:n_burnin + n_simul - 1,1]), shcks[:,n_burnin + 1:n_burnin + n_simul]))
+
+            outputs = Float32.(collect(sims[:,n_burnin+1:n_burnin + n_simul,1]))
+        end
+
+        train_loader = Flux.DataLoader((outputs, inputs), batchsize = n_simul ÷ n_batches, shuffle = true)
+    end
+
     for (out,in) in train_loader
         lss, grads = Flux.withgradient(neural_net) do nn
             sqrt(Flux.mse(out, nn(in)))
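
For reference, a minimal standalone sketch of how the new composite schedule behaves, assuming ParameterSchedulers.jl's exported Sequence, CosAnneal, and Exp constructors and that the training loop advances the schedule once per epoch via ParameterSchedulers.next!; the Flux.adjust!(opt_state, eta) line and opt_state are hypothetical placeholders, since the diff does not show where s is consumed:

using ParameterSchedulers

scheduler_period = 15000

# Same composite schedule as in the diff: one full cosine-annealing cycle from
# 1e-3 down to 1e-5 over the first third (the CosAnneal period of 5000 equals
# scheduler_period ÷ 3), then two progressively flatter exponential decays.
s = ParameterSchedulers.Stateful(Sequence([CosAnneal(.001, 1e-5, 5000),
                                           Exp(start = 1e-5, decay = .9995),
                                           Exp(start = 1e-6, decay = .999)],
                                          [scheduler_period ÷ 3, scheduler_period ÷ 3, scheduler_period ÷ 3]))

for epoch in 1:scheduler_period
    eta = ParameterSchedulers.next!(s)  # advance the schedule by one step
    # Flux.adjust!(opt_state, eta)      # hypothetical: push the new learning
                                        # rate into the optimiser state
    if epoch % 5000 == 0
        @show epoch, eta                # inspect the rate at segment boundaries
    end
end

The second hunk complements the slower schedule: every 500 epochs it redraws the shocks, re-simulates the model, and rebuilds train_loader, so the network trains on fresh trajectories instead of repeatedly fitting one fixed draw.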