Skip to content

Commit

Permalink
diff schedule and update sample during training
Browse files Browse the repository at this point in the history
  • Loading branch information
thorek1 committed Oct 29, 2024
1 parent 2f96b9b commit 86672e8
Showing 1 changed file with 33 additions and 2 deletions.
35 changes: 33 additions & 2 deletions test/neural_net_solution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,13 @@ n_simul = n_batches * nn_params ÷ (n_vars * 10)
# Number of initial simulation periods discarded as burn-in before the
# training sample starts
n_burnin = 500
# Total number of training epochs; also the horizon over which the
# learning-rate schedule below is laid out
scheduler_period = 15000

s = ParameterSchedulers.Stateful(CosAnneal(.001, 1e-8, scheduler_period))
# s = ParameterSchedulers.Stateful(SinDecay2(.001, 1e-6, 500))
# s = ParameterSchedulers.Stateful(CosAnneal(.001, 1e-8, scheduler_period))
# Stateful learning-rate schedule: cosine annealing from 1e-3 down to 1e-5
# for the first third of training, then two successively slower exponential
# decays for the remaining two thirds.
# NOTE(review): the uncommented `s = … CosAnneal …` assignment two lines
# above appears to be the pre-commit line retained by the diff render, not
# live code — confirm only this Sequence-based schedule exists in the file.
s = ParameterSchedulers.Stateful(Sequence([ CosAnneal(.001, 1e-5, 5000),
Exp(start = 1e-5, decay = .9995),
Exp(start = 1e-6, decay = .999)],
[scheduler_period ÷ 3, scheduler_period ÷ 3, scheduler_period ÷ 3]))



# Draw a panel of standard-normal shocks: one row per shock process,
# one column per simulated period (burn-in plus training sample)
shcks = randn(n_shocks, n_burnin + n_simul)

Expand All @@ -87,6 +92,32 @@ train_loader = Flux.DataLoader((outputs, inputs), batchsize = n_simul ÷ n_batch
# Per-batch training losses (RMSE values).  Typed as Float32[] — the data
# fed to the network is Float32, so the losses pushed here are Float32;
# an untyped `[]` would create a Vector{Any} and box every element.
losses = Float32[]
# Training loop
# Training loop: one pass per scheduler step.  Every 500 epochs the
# simulated training sample and its data loader are regenerated so the
# network trains on fresh draws rather than one fixed simulation.
for epoch in 1:scheduler_period
# if (epoch ≥ 5000 &&
# epoch ≤ 10000 &&
# The trailing "#)" below is the remnant of the commented-out
# epoch-window condition above.
if epoch % 500 == 0#)
# Fresh standard-normal shocks for burn-in plus training sample
shcks = randn(n_shocks, n_burnin + n_simul)

sims = get_irf(Smets_Wouters_2007, shocks = shcks, periods = 0, levels = true)

if normalise
# Standardise the simulated series with the model-implied
# mean and standard deviation before building the sample
mn = get_mean(Smets_Wouters_2007, derivatives = false)

stddev = get_std(Smets_Wouters_2007, derivatives = false)

normalised_sims = collect((sims[:,n_burnin:end,1] .- mn) ./ stddev)

# Inputs: current-period variables stacked on next-period shocks;
# outputs: next-period variables (one-step-ahead target)
inputs = Float32.(vcat(normalised_sims[:,1:end - 1], shcks[:,n_burnin + 1:n_burnin + n_simul]))

outputs = Float32.(normalised_sims[:,2:end])
else
# Same input/output construction on the raw (unnormalised) levels
inputs = Float32.(vcat(collect(sims[:,n_burnin:n_burnin + n_simul - 1,1]), shcks[:,n_burnin + 1:n_burnin + n_simul]))

outputs = Float32.(collect(sims[:,n_burnin+1:n_burnin + n_simul,1]))
end

# NOTE(review): these assignments (shcks/inputs/outputs/train_loader)
# sit in the soft scope of a top-level `for`; run non-interactively,
# Julia may treat them as new loop locals rather than updates to the
# globals defined above — confirm, or mark them `global` explicitly.
train_loader = Flux.DataLoader((outputs, inputs), batchsize = n_simul ÷ n_batches, shuffle = true)
end

for (out,in) in train_loader
lss, grads = Flux.withgradient(neural_net) do nn
sqrt(Flux.mse(out, nn(in)))
Expand Down

0 comments on commit 86672e8

Please sign in to comment.