From be6cc5b41f7bdbde4e25b740718b316bbcce9fb5 Mon Sep 17 00:00:00 2001 From: thorek1 Date: Fri, 1 Nov 2024 00:23:33 +0100 Subject: [PATCH] n_par matters most, depth also but less so --- test/neural_net_solution.jl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/test/neural_net_solution.jl b/test/neural_net_solution.jl index 74f0b707..b52967cc 100644 --- a/test/neural_net_solution.jl +++ b/test/neural_net_solution.jl @@ -196,8 +196,8 @@ inputs /= 6 ## Create Neural Network -n_hidden = max(256, n_vars * 2) -n_hidden_small = max(256, n_vars * 2) +n_hidden = max(512, n_vars * 2) +n_hidden_small = max(128, n_vars * 2) if recurrent neural_net = Chain( Dense(n_inputs, n_hidden, asinh), @@ -213,8 +213,8 @@ else Dense(n_hidden, n_hidden, leakyrelu), # going to 256 brings it down to .0016 Dense(n_hidden, n_hidden_small, tanh_fast), # without these i get to .0032 and relnorm .0192 Dense(n_hidden_small, n_hidden_small, leakyrelu), # without these i get to .0032 and relnorm .0192, with these it goes to .002 and .0123 - Dense(n_hidden_small, n_hidden_small, tanh_fast), - Dense(n_hidden_small, n_hidden_small, leakyrelu), + # Dense(n_hidden_small, n_hidden_small, tanh_fast), + # Dense(n_hidden_small, n_hidden_small, leakyrelu), Dense(n_hidden_small, n_vars, tanh_fast)) else neural_net = Chain( Dense(n_inputs, n_hidden, asinh), @@ -306,7 +306,6 @@ end # BSON.@load "post_ADAM.bson" neural_net - plot(losses[500:end], yaxis=:log) eta_sched_plot = ParameterSchedulers.Stateful(CosAnneal(.001, 1e-10, n_epochs*length(train_loader))) lr = [ParameterSchedulers.next!(eta_sched_plot) for i in 1:n_epochs*length(train_loader)]