Skip to content

Commit

Permalink
n_par matters most, depth also but less so
Browse files Browse the repository at this point in the history
  • Loading branch information
thorek1 committed Oct 31, 2024
1 parent 2b7db0a commit be6cc5b
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions test/neural_net_solution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,8 @@ inputs /= 6


## Create Neural Network
-n_hidden = max(256, n_vars * 2)
-n_hidden_small = max(256, n_vars * 2)
+n_hidden = max(512, n_vars * 2)
+n_hidden_small = max(128, n_vars * 2)

if recurrent
neural_net = Chain( Dense(n_inputs, n_hidden, asinh),
Expand All @@ -213,8 +213,8 @@ else
Dense(n_hidden, n_hidden, leakyrelu), # going to 256 brings it down to .0016
Dense(n_hidden, n_hidden_small, tanh_fast), # without these i get to .0032 and relnorm .0192
Dense(n_hidden_small, n_hidden_small, leakyrelu), # without these i get to .0032 and relnorm .0192, with these it goes to .002 and .0123
-                        Dense(n_hidden_small, n_hidden_small, tanh_fast),
-                        Dense(n_hidden_small, n_hidden_small, leakyrelu),
+                        # Dense(n_hidden_small, n_hidden_small, tanh_fast),
+                        # Dense(n_hidden_small, n_hidden_small, leakyrelu),
Dense(n_hidden_small, n_vars, tanh_fast))
else
neural_net = Chain( Dense(n_inputs, n_hidden, asinh),
Expand Down Expand Up @@ -306,7 +306,6 @@ end

# BSON.@load "post_ADAM.bson" neural_net


plot(losses[500:end], yaxis=:log)
eta_sched_plot = ParameterSchedulers.Stateful(CosAnneal(.001, 1e-10, n_epochs*length(train_loader)))
lr = [ParameterSchedulers.next!(eta_sched_plot) for i in 1:n_epochs*length(train_loader)]
Expand Down

0 comments on commit be6cc5b

Please sign in to comment.