From e772b65b6f098a13245f7c4eb38fecaf11340280 Mon Sep 17 00:00:00 2001 From: thorek1 Date: Thu, 31 Oct 2024 16:01:29 +0100 Subject: [PATCH] width first network --- test/neural_net_solution.jl | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/test/neural_net_solution.jl b/test/neural_net_solution.jl index 80003d05..348795a8 100644 --- a/test/neural_net_solution.jl +++ b/test/neural_net_solution.jl @@ -197,12 +197,13 @@ outputs /= 6 inputs /= 6 # inputs .+= .5 -h5write("data.h5", "inputs", inputs) -h5write("data.h5", "outputs", outputs) +# h5write("data.h5", "inputs", inputs) +# h5write("data.h5", "outputs", outputs) ## Create Neural Network -n_hidden = max(256, n_vars * 2) +n_hidden = max(1024, n_vars * 2) +n_hidden_small = max(128, n_vars * 2) if recurrent neural_net = Chain( Dense(n_inputs, n_hidden, asinh), @@ -215,12 +216,12 @@ if recurrent else if normalise neural_net = Chain( Dense(n_inputs, n_hidden, tanh_fast), - Dense(n_hidden, n_hidden, leakyrelu), - Dense(n_hidden, n_hidden, tanh_fast), - Dense(n_hidden, n_hidden, leakyrelu), - Dense(n_hidden, n_hidden, tanh_fast), - Dense(n_hidden, n_hidden, leakyrelu), - Dense(n_hidden, n_vars, tanh_fast)) + Dense(n_hidden, n_hidden_small, leakyrelu), + Dense(n_hidden_small, n_hidden_small, tanh_fast), + Dense(n_hidden_small, n_hidden_small, leakyrelu), + Dense(n_hidden_small, n_hidden_small, tanh_fast), + Dense(n_hidden_small, n_hidden_small, leakyrelu), + Dense(n_hidden_small, n_vars, tanh_fast)) else neural_net = Chain( Dense(n_inputs, n_hidden, asinh), Dense(n_hidden, n_hidden, asinh), @@ -254,8 +255,12 @@ if pretrain # BSON.@load "post_LBFGS.bson" neural_net end + + # Setup optimiser +n_epochs = 300 + # optim = Flux.setup(Flux.Adam(), neural_net) optim = Flux.setup(Flux.Optimiser(Flux.ClipNorm(1), Flux.AdamW()), neural_net) @@ -268,11 +273,8 @@ eta_sched = ParameterSchedulers.Stateful(CosAnneal(.001, 1e-10, n_epochs)) # [scheduler_period ÷ 3, scheduler_period ÷ 3, scheduler_period ÷ 3])) - # Training loop -n_epochs = 300 - batchsize = 1024 train_loader = Flux.DataLoader((outputs, inputs), batchsize = batchsize, shuffle = true)