diff --git a/test/neural_net_solution.jl b/test/neural_net_solution.jl
index 2a27fd70..b1f5bb63 100644
--- a/test/neural_net_solution.jl
+++ b/test/neural_net_solution.jl
@@ -7,6 +7,10 @@ using Optim
 using FluxOptTools
 using StatsPlots
 using Sobol
+# using FileIO
+# using ParquetFiles
+# using FeatherFiles
+# using Arrow
 using Parquet
 using LinearAlgebra
 
@@ -194,6 +198,12 @@ inputs /= 6
 
 inout = (inputs_17_200000 = vec(inputs), outputs_6_200000 = vec(outputs))
 
 write_parquet("test_data.parquet",inout)
+# save("test_input.feather",inputs)
+# save("test_output.feather",outputs)
+
+# save("test_input.arrow",inputs)
+# save("test_output.arrow",outputs)
+
 n_epochs = 300
 
@@ -205,7 +215,7 @@ n_batches = length(train_loader)
 
 ## Create Neural Network
 
-n_hidden = max(64, n_vars * 2)
+n_hidden = max(128, n_vars * 2)
 
 if recurrent
     neural_net = Chain( Dense(n_inputs, n_hidden, asinh),
@@ -217,7 +227,7 @@ if recurrent
                         Dense(n_hidden, n_vars))
 else
     if normalise
-        neural_net = Chain( Dense(n_inputs, n_hidden, tanh),
+        neural_net = Chain( Dense(n_inputs, n_hidden, celu),
                             Dense(n_hidden, n_hidden, celu),
                             Dense(n_hidden, n_hidden, celu),
                             Dense(n_hidden, n_hidden, celu),
@@ -235,6 +245,19 @@ else
     end
 end
 
+# Pretrain with L-BFGS
+
+pretrain_loader = Flux.DataLoader((outputs, inputs), batchsize = (n_time_steps * n_parameter_draws) ÷ 8, shuffle = true)
+
+for (out,inp) in pretrain_loader
+    loss_func() = sqrt(Flux.mse(out, neural_net(inp)))
+    pars = Flux.params(neural_net)
+    lossfun, gradfun, fg!, p0 = optfuns(loss_func, pars)
+    res = Optim.optimize(Optim.only_fg!(fg!), p0, Optim.Options(iterations=100, show_trace=true))
+end
+# end
+
+
 # Setup optimiser
 # optim = Flux.setup(Flux.Adam(), neural_net)
 optim = Flux.setup(Flux.Optimiser(Flux.ClipNorm(1), Flux.AdamW()), neural_net)
@@ -266,7 +289,7 @@ for epoch in 1:n_epochs
 
     push!(losses, lss)  # logging, outside gradient context
 
-    if length(losses) % print_every == 0 println("Epoch: $epoch; Loss: $(sum(losses[end-print_every+1:end])/print_every); η: $(optim.layers[1].weight.rule.opts[2].opts[1].eta); λ: $(optim.layers[1].weight.rule.opts[2].opts[2].lambda)") end
+    if length(losses) % print_every == 0 println("Epoch: $epoch; Loss: $(sum(losses[end-print_every+1:end])/print_every)") end #; η: $(optim.layers[1].weight.rule.opts[2].opts[1].eta); λ: $(optim.layers[1].weight.rule.opts[2].opts[2].lambda)") end
 end
 
 # if epoch % 10 == 0 println("Epoch: $epoch; Loss: $(sum(losses[end-99:end])/100); Opt state: $(optim.layers[1].weight.rule)") end
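
Note on the pretraining block added above: it follows the FluxOptTools pattern of flattening the model's `Flux.params` into a single vector so that Optim can drive the optimisation. The diff's call omits an explicit method and relies on Optim picking a gradient-based default; the FluxOptTools README uses the same form. Below is a minimal, self-contained sketch of that pattern with the L-BFGS method spelled out. The toy model, data, and sizes are illustrative stand-ins, not the repo's actual inputs.

using Flux, Optim, FluxOptTools

# Toy stand-ins for the script's data and network (illustrative only)
x = rand(Float32, 17, 1024)   # inputs:  n_inputs × n_samples
y = rand(Float32, 6, 1024)    # targets: n_vars   × n_samples
m = Chain(Dense(17, 32, celu), Dense(32, 6))

loss() = sqrt(Flux.mse(y, m(x)))                  # RMSE, as in the diff
pars = Flux.params(m)                             # implicit-parameter collection FluxOptTools expects
lossfun, gradfun, fg!, p0 = optfuns(loss, pars)   # p0 is all weights flattened into one vector
res = Optim.optimize(Optim.only_fg!(fg!), p0,
                     Optim.LBFGS(),
                     Optim.Options(iterations = 100, show_trace = true))
# fg! copies each candidate vector back into `pars` as it is evaluated,
# so `m` ends up holding the optimised weights once optimize returns.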