Skip to content

Commit

Permalink
width first network
Browse files Browse the repository at this point in the history
  • Loading branch information
thorek1 committed Oct 31, 2024
1 parent f5c9b13 commit e772b65
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions test/neural_net_solution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,12 +197,13 @@ outputs /= 6
inputs /= 6
# inputs .+= .5

h5write("data.h5", "inputs", inputs)
h5write("data.h5", "outputs", outputs)
# h5write("data.h5", "inputs", inputs)
# h5write("data.h5", "outputs", outputs)


## Create Neural Network
n_hidden = max(256, n_vars * 2)
n_hidden = max(1024, n_vars * 2)
n_hidden_small = max(128, n_vars * 2)

if recurrent
neural_net = Chain( Dense(n_inputs, n_hidden, asinh),
Expand All @@ -215,12 +216,12 @@ if recurrent
else
if normalise
neural_net = Chain( Dense(n_inputs, n_hidden, tanh_fast),
Dense(n_hidden, n_hidden, leakyrelu),
Dense(n_hidden, n_hidden, tanh_fast),
Dense(n_hidden, n_hidden, leakyrelu),
Dense(n_hidden, n_hidden, tanh_fast),
Dense(n_hidden, n_hidden, leakyrelu),
Dense(n_hidden, n_vars, tanh_fast))
Dense(n_hidden, n_hidden_small, leakyrelu),
Dense(n_hidden_small, n_hidden_small, tanh_fast),
Dense(n_hidden_small, n_hidden_small, leakyrelu),
Dense(n_hidden_small, n_hidden_small, tanh_fast),
Dense(n_hidden_small, n_hidden_small, leakyrelu),
Dense(n_hidden_small, n_vars, tanh_fast))
else
neural_net = Chain( Dense(n_inputs, n_hidden, asinh),
Dense(n_hidden, n_hidden, asinh),
Expand Down Expand Up @@ -254,8 +255,12 @@ if pretrain
# BSON.@load "post_LBFGS.bson" neural_net
end



# Setup optimiser

n_epochs = 300

# optim = Flux.setup(Flux.Adam(), neural_net)
optim = Flux.setup(Flux.Optimiser(Flux.ClipNorm(1), Flux.AdamW()), neural_net)

Expand All @@ -268,11 +273,8 @@ eta_sched = ParameterSchedulers.Stateful(CosAnneal(.001, 1e-10, n_epochs))
# [scheduler_period ÷ 3, scheduler_period ÷ 3, scheduler_period ÷ 3]))



# Training loop

n_epochs = 300

batchsize = 1024

train_loader = Flux.DataLoader((outputs, inputs), batchsize = batchsize, shuffle = true)
Expand Down

0 comments on commit e772b65

Please sign in to comment.