diff --git a/test/neural_net_solution.jl b/test/neural_net_solution.jl index 3aaa8839..e46cbcd6 100644 --- a/test/neural_net_solution.jl +++ b/test/neural_net_solution.jl @@ -43,36 +43,31 @@ s = ParameterSchedulers.Stateful(CosAnneal(.001, 1e-6, 500)) optim = Flux.setup(Flux.Adam(), neural_net) # will store optimiser momentum, etc. -# for i in 1:10 -n_simul = 1000 -n_burnin = 500 +for i in 1:10 + n_simul = 1000 + n_burnin = 500 -shcks = randn(n_shocks, n_burnin + n_simul) + shcks = randn(n_shocks, n_burnin + n_simul) -sims = get_irf(Smets_Wouters_2007, shocks = shcks, periods = 0, levels = true) + sims = get_irf(Smets_Wouters_2007, shocks = shcks, periods = 0, levels = true) -normalised_sims = Flux.normalise(collect(sims[:,n_burnin:end,1]), dims=1) + normalised_sims = Flux.normalise(collect(sims[:,n_burnin:end,1]), dims=1) -normalised_sim_slices = Float32.(vcat(normalised_sims[:,1:end - 1], shcks[:,n_burnin + 1:n_burnin + n_simul])) + normalised_sim_slices = Float32.(vcat(normalised_sims[:,1:end - 1], shcks[:,n_burnin + 1:n_burnin + n_simul])) -normalised_out_slices = Float32.(normalised_sims[:,2:end]) + normalised_out_slices = Float32.(normalised_sims[:,2:end]) -# loss() = sqrt(sum(abs2, out_slices - neural_net(sim_slices))) -# loss() = Flux.mse(neural_net(sim_slices), out_slices) - -# Training loop, using the whole data set 1000 times: -losses = [] -for epoch in 1:5000 - # for (x, y) in loader + # Training loop, using the whole data set 1000 times: + losses = [] + for epoch in 1:1000 lss, grads = Flux.withgradient(neural_net) do nn - # Evaluate model and loss inside gradient context: sqrt(Flux.mse(nn(normalised_sim_slices), normalised_out_slices)) end Flux.adjust!(optim, ParameterSchedulers.next!(s)) Flux.update!(optim, neural_net, grads[1]) push!(losses, loss) # logging, outside gradient context - if epoch % 100 == 0 println("Epoch: $epoch; Loss: $lss; Opt state: $(optim.layers[1].weight.rule)") end - # end + if epoch % 10 == 0 println("Epoch: $epoch; Loss: $lss; Opt state: $(optim.layers[1].weight.rule)") end + end end