loop over simulated data sets
thorek1 committed Oct 29, 2024
1 parent b4c5121 commit 58f3e1c
Showing 1 changed file with 13 additions and 18 deletions.
31 changes: 13 additions & 18 deletions test/neural_net_solution.jl
@@ -43,36 +43,31 @@ s = ParameterSchedulers.Stateful(CosAnneal(.001, 1e-6, 500))
optim = Flux.setup(Flux.Adam(), neural_net) # will store optimiser momentum, etc.


-# for i in 1:10
-n_simul = 1000
-n_burnin = 500
+for i in 1:10
+    n_simul = 1000
+    n_burnin = 500

-shcks = randn(n_shocks, n_burnin + n_simul)
+    shcks = randn(n_shocks, n_burnin + n_simul)

-sims = get_irf(Smets_Wouters_2007, shocks = shcks, periods = 0, levels = true)
+    sims = get_irf(Smets_Wouters_2007, shocks = shcks, periods = 0, levels = true)

-normalised_sims = Flux.normalise(collect(sims[:,n_burnin:end,1]), dims=1)
+    normalised_sims = Flux.normalise(collect(sims[:,n_burnin:end,1]), dims=1)

-normalised_sim_slices = Float32.(vcat(normalised_sims[:,1:end - 1], shcks[:,n_burnin + 1:n_burnin + n_simul]))
+    normalised_sim_slices = Float32.(vcat(normalised_sims[:,1:end - 1], shcks[:,n_burnin + 1:n_burnin + n_simul]))

-normalised_out_slices = Float32.(normalised_sims[:,2:end])
+    normalised_out_slices = Float32.(normalised_sims[:,2:end])

-# loss() = sqrt(sum(abs2, out_slices - neural_net(sim_slices)))
-# loss() = Flux.mse(neural_net(sim_slices), out_slices)

-# Training loop, using the whole data set 1000 times:
-losses = []
-for epoch in 1:5000
-    # for (x, y) in loader
+    # Training loop, using the whole data set 1000 times:
+    losses = []
+    for epoch in 1:1000
        lss, grads = Flux.withgradient(neural_net) do nn
            # Evaluate model and loss inside gradient context:
            sqrt(Flux.mse(nn(normalised_sim_slices), normalised_out_slices))
        end
        Flux.adjust!(optim, ParameterSchedulers.next!(s))
        Flux.update!(optim, neural_net, grads[1])
        push!(losses, lss) # logging, outside gradient context
-    if epoch % 100 == 0 println("Epoch: $epoch; Loss: $lss; Opt state: $(optim.layers[1].weight.rule)") end
-    # end
+        if epoch % 10 == 0 println("Epoch: $epoch; Loss: $lss; Opt state: $(optim.layers[1].weight.rule)") end
    end
+end
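
For context, the hunk above relies on objects defined earlier in test/neural_net_solution.jl that this page does not show: the Smets_Wouters_2007 model object, n_shocks, and neural_net. A minimal sketch of what that setup could look like follows; the include path, the layer widths, and the tanh activation are illustrative assumptions rather than the file's actual definitions, while the scheduler and optimiser lines are copied from the hunk.

# Hypothetical setup sketch -- only get_irf, Flux, and ParameterSchedulers calls
# that already appear on this page are reused; everything else is an assumption.
using MacroModelling, Flux, ParameterSchedulers

include("../models/Smets_Wouters_2007.jl")   # assumed path to the model definition

n_shocks = 7                                 # the Smets-Wouters (2007) model has 7 structural shocks

# infer the number of model variables from a one-period simulation
n_vars = size(get_irf(Smets_Wouters_2007, shocks = randn(n_shocks, 1), periods = 0, levels = true), 1)

# feed-forward net mapping [state_t; shocks_t+1] -> state_t+1; widths are placeholders
neural_net = Chain(Dense(n_vars + n_shocks, 256, tanh),
                   Dense(256, 256, tanh),
                   Dense(256, n_vars))

# cosine-annealed learning rate, stepped once per epoch as in the hunk above
s     = ParameterSchedulers.Stateful(CosAnneal(.001, 1e-6, 500))
optim = Flux.setup(Flux.Adam(), neural_net)  # stores optimiser momentum, etc.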



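As a usage note (not part of the commit), once the loop over the ten simulated data sets has finished, the fit of neural_net can be checked on a fresh simulation. The sketch below reuses only calls that already appear in the hunk; the test_-prefixed names and the RMSE printout are hypothetical.

# Hypothetical out-of-sample check: simulate one more data set, build inputs and
# targets exactly as in the training loop, and report the one-step-ahead RMSE.
n_simul  = 1000
n_burnin = 500

test_shcks = randn(n_shocks, n_burnin + n_simul)
test_sims  = get_irf(Smets_Wouters_2007, shocks = test_shcks, periods = 0, levels = true)

test_norm       = Flux.normalise(collect(test_sims[:, n_burnin:end, 1]), dims = 1)
test_in_slices  = Float32.(vcat(test_norm[:, 1:end - 1], test_shcks[:, n_burnin + 1:n_burnin + n_simul]))
test_out_slices = Float32.(test_norm[:, 2:end])

test_rmse = sqrt(Flux.mse(neural_net(test_in_slices), test_out_slices))
println("Out-of-sample one-step-ahead RMSE: $test_rmse")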