set up test suite for different neural net specs
thorek1 committed Nov 2, 2024
1 parent 259550d commit 214560b
Showing 1 changed file with 160 additions and 118 deletions.
test/neural_net_solution.jl
@@ -118,7 +118,7 @@ bounds_range = upper_bounds_par .- lower_bounds_par
outputs = zeros(Float32, n_vars, n_time_steps * n_parameter_draws)
inputs = zeros(Float32, n_inputs, n_time_steps * n_parameter_draws)

-Rnadom.seed!(14124)
+Random.seed!(14124)

for i in 1:n_parameter_draws
draw = next!(sob)
@@ -198,126 +198,168 @@ inputs /= 6
# h5write("data.h5", "outputs", outputs)


## Create Neural Network
n_hidden = max(256, n_vars * 2)
n_hidden_small = max(256, n_vars * 2)

Random.seed!(6794)

if recurrent
neural_net = Chain( Dense(n_inputs, n_hidden, asinh),
Flux.LSTM(n_hidden, n_hidden ÷ 2),
Flux.GRU(n_hidden ÷ 2, n_hidden ÷ 2), # optional
Dense(n_hidden ÷ 2, n_hidden ÷ 2, celu),
Dense(n_hidden ÷ 2, n_hidden, celu),
Dense(n_hidden, n_hidden, celu), # optional
Dense(n_hidden, n_vars))
else
if normalise
neural_net = Chain( Dense(n_inputs, n_hidden),
Dense(n_hidden, n_hidden, leakyrelu), # going to 256 brings it down to .0016
Dense(n_hidden, n_hidden_small, tanh_fast), # without these i get to .0032 and relnorm .0192
Dense(n_hidden_small, n_hidden_small, leakyrelu), # without these i get to .0032 and relnorm .0192, with these it goes to .002 and .0123
Dense(n_hidden_small, n_hidden_small, tanh_fast),
Dense(n_hidden_small, n_hidden_small, leakyrelu),
Dense(n_hidden_small, n_vars))
else
neural_net = Chain( Dense(n_inputs, n_hidden, asinh),
Dense(n_hidden, n_hidden, asinh),
Dense(n_hidden, n_hidden, tanh),
Dense(n_hidden, n_hidden, celu),
Dense(n_hidden, n_hidden, celu),
Dense(n_hidden, n_hidden, celu),
Dense(n_hidden, n_vars))
end
end
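
# Hedged sanity check (illustrative addition, not part of the original script): a quick
# forward pass on a few columns to confirm the chain maps an (n_inputs × batch) Float32
# matrix to an (n_vars × batch) output. For the recurrent variant this also advances the
# LSTM/GRU state, so it is reset afterwards.
@assert size(neural_net(inputs[:, 1:8])) == (n_vars, 8)
if recurrent Flux.reset!(neural_net) end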

# Pretrain with L-BFGS
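# Note (assumption, not stated in this diff): `optfuns` below is presumably provided by
# FluxOptTools.jl, which wraps the Flux parameters so Optim.jl's L-BFGS can drive the fit
# via the returned (lossfun, gradfun, fg!, p0) tuple; the corresponding `using` lines are
# not shown in this hunk.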
if pretrain
n_pretrain_epochs = 2000
n_pretrain_batches = 128

pretrain_loader = Flux.DataLoader((outputs, inputs), batchsize = (n_time_steps * n_parameter_draws) ÷ n_pretrain_batches, shuffle = true)

for (out,inp) in pretrain_loader
loss_func() = sqrt(Flux.mse(out, neural_net(inp)))
pars = Flux.params(neural_net)
lossfun, gradfun, fg!, p0 = optfuns(loss_func, pars)
res = Optim.optimize(Optim.only_fg!(fg!), p0, Optim.Options(iterations=n_pretrain_epochs, show_trace=true))
if loss_func() < 1e-2 break end
end

# Save and load model

BSON.@save "post_LBFGS.bson" neural_net

# BSON.@load "post_LBFGS.bson" neural_net
end



# Setup optimiser

n_epochs = 100 # 1000 goes to .0016; 300 goes to .0023

# optim = Flux.setup(Flux.Adam(), neural_net)
optim = Flux.setup(Flux.Optimiser(Flux.ClipNorm(1), Flux.AdamW()), neural_net)

# lr_start = 1e-3
# lr_end = 1e-10

# eta_sched = ParameterSchedulers.Stateful(CosAnneal(lr_start, lr_end, n_epochs * n_batches))

lr_start = 1e-3
lr_end = 1e-10

# degree = (log(lr_start) - log(lr_end)) / log((1 - (n_epochs - 1) / n_epochs))

eta_sched = ParameterSchedulers.Stateful(CosAnneal(lr_start, lr_end, n_epochs))
# eta_sched = ParameterSchedulers.Stateful(Exp(start = lr_start, decay = (lr_end / lr_start) ^ (1 / n_epochs)))
# eta_sched = ParameterSchedulers.Stateful(Poly(start = lr_start, degree = 3, max_iter = n_epochs))

# decay_sched = ParameterSchedulers.Stateful(CosAnneal(.00001, 1e-10, n_epochs * n_batches))
# s = ParameterSchedulers.Stateful(Sequence([ CosAnneal(.001, 1e-5, 5000),
# Exp(start = 1e-5, decay = .9995),
# Exp(start = 1e-6, decay = .999)],
# [scheduler_period ÷ 3, scheduler_period ÷ 3, scheduler_period ÷ 3]))
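
# Hedged illustration (not part of the original script): the chosen schedule can be
# previewed before training by stepping a throwaway Stateful copy with next!, exactly as
# the training loop does below. With CosAnneal the rate starts near lr_start and anneals
# toward lr_end over n_epochs steps.
eta_preview = ParameterSchedulers.Stateful(CosAnneal(lr_start, lr_end, n_epochs))
println("first three η values: ", [ParameterSchedulers.next!(eta_preview) for _ in 1:3])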


# Training loop

n_hidden = 256
batchsize = 512

train_loader = Flux.DataLoader((outputs, inputs), batchsize = batchsize, shuffle = true)

n_batches = length(train_loader)

print_every = 10
# print_every = 100000 ÷ batchsize

losses = []
for epoch in 1:n_epochs
for (out,inp) in train_loader
lss, grads = Flux.withgradient(neural_net) do nn
sqrt(Flux.mse(out, nn(inp)))
end

Flux.update!(optim, neural_net, grads[1])

push!(losses, lss) # logging, outside gradient context

# if length(losses) % print_every == 0 println("Epoch: $epoch; Loss: $(sum(losses[end-print_every+1:end])/print_every); η: $(optim.layers[1].weight.rule.opts[2].opts[1].eta); λ: $(optim.layers[1].weight.rule.opts[2].opts[2].lambda)") end
end

if epoch % print_every == 0 println("Epoch: $epoch; Loss: $(sum(losses[end-n_batches * print_every+1:end])/(n_batches*print_every)); η: $(optim.layers[1].weight.rule.opts[2].opts[1].eta); λ: $(optim.layers[1].weight.rule.opts[2].opts[2].lambda)") end

sched_update = ParameterSchedulers.next!(eta_sched)

Flux.adjust!(optim; eta = sched_update)
Flux.adjust!(optim; lambda = sched_update * 0.01)
end
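
# Hedged fit check (illustrative addition, not part of the original commit): the same
# relative-error metric used in the sweep below, applied to the net trained above.
# `norm` is assumed to come from LinearAlgebra, as in the relnorm line further down.
relnorm_check = norm(outputs - neural_net(inputs)) / norm(outputs)
println("final batch RMSE: $(losses[end]); relative norm over full data: $relnorm_check")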


n_epochs = 100

activation = :tanh
schedule = :cos
optimiser = :adam

results = []
for activation in [:tanh, :relu, :gelu]
for n_hidden in [256,128,384]
for schedule in [:cos,:poly,:exp]
# for lr_start in [1e-3,5e-3,5e-4,1e-4]
# for lr_end in [1e-8,1e-9,1e-10,1e-11]
for batchsize in [128,256,512]#,1024]#,2048]
# for optimiser in [:adam,:adamw]
# for n_epochs in [100,150]#,300]
## Create Neural Network
# n_hidden = max(256, n_vars * 2)
# n_hidden_small = max(256, n_vars * 2)

Random.seed!(6794)

if activation == :relu
act = leakyrelu
elseif activation == :tanh
act = tanh_fast
elseif activation == :celu
act = celu
elseif activation == :gelu
act = gelu
end

neural_net = Chain( Dense(n_inputs, n_hidden),
Dense(n_hidden, n_hidden, act), # going to 256 brings it down to .0016
Dense(n_hidden, n_hidden, act), # without these i get to .0032 and relnorm .0192
Dense(n_hidden, n_hidden, act), # without these i get to .0032 and relnorm .0192, with these it goes to .002 and .0123
Dense(n_hidden, n_hidden, act),
Dense(n_hidden, n_hidden, act),
Dense(n_hidden, n_vars))


# Setup optimiser

# n_epochs = 100 # 1000 goes to .0016; 300 goes to .0023

if optimiser == :adam
optim = Flux.setup(Flux.Adam(), neural_net)
elseif optimiser == :adamw
optim = Flux.setup(Flux.Optimiser(Flux.ClipNorm(1), Flux.AdamW()), neural_net)
end

# lr_start = 1e-3
# lr_end = 1e-10
if schedule == :cos
eta_sched = ParameterSchedulers.Stateful(CosAnneal(lr_start, lr_end, n_epochs))
elseif schedule == :exp
eta_sched = ParameterSchedulers.Stateful(Exp(start = lr_start, decay = (lr_end / lr_start) ^ (1 / n_epochs)))
elseif schedule == :poly
eta_sched = ParameterSchedulers.Stateful(Poly(start = lr_start, degree = 3, max_iter = n_epochs))
end

# Training loop

# batchsize = 512

train_loader = Flux.DataLoader((outputs, inputs), batchsize = batchsize, shuffle = true)

n_batches = length(train_loader)

start_time = time()
losses = []
for epoch in 1:n_epochs
for (out,inp) in train_loader
lss, grads = Flux.withgradient(neural_net) do nn
sqrt(Flux.mse(out, nn(inp)))
end

Flux.update!(optim, neural_net, grads[1])

push!(losses, lss) # logging, outside gradient context

end

sched_update = ParameterSchedulers.next!(eta_sched)

Flux.adjust!(optim; eta = sched_update)
Flux.adjust!(optim; lambda = sched_update * 0.01)
end
end_time = time() # Record end time
elapsed_time = end_time - start_time

relnorm = norm(outputs - neural_net(inputs)) / norm(outputs)

push!(results,[lr_start,lr_end,activation,batchsize,n_epochs,n_hidden, schedule, optimiser, elapsed_time, sum(losses[end-500:end])/(500), relnorm])
println("Finished $(results[end])")
end
end
end
end
# end
# end
# end
# end

# Finished [1.0e-8, 256.0, 100.0, 758.8324751853943, 0.0017534949583932757]
# Finished [1.0e-8, 256.0, 150.0, 1088.4795179367065, 0.0014873802429065108]
# Finished [1.0e-8, 256.0, 300.0, 2198.8943860530853, 0.0011740062618628144]
# Finished [1.0e-8, 512.0, 100.0, 517.70987200737, 0.0019748075865209103]
# Finished [1.0e-8, 512.0, 150.0, 752.7117650508881, 0.001620362396351993]

# Finished Any[0.001, 1.0e-10, :tanh, 256, 100, 64, :cos, 87.80793190002441, 0.0077648805f0]
# Finished Any[0.001, 1.0e-10, :tanh, 512, 100, 64, :cos, 118.03787803649902, 0.009695039f0]
# Finished Any[0.001, 1.0e-10, :tanh, 1024, 100, 64, :cos, 126.16368412971497, 0.013013127f0]
# Finished Any[0.001, 1.0e-10, :relu, 256, 100, 64, :cos, 226.33820700645447, 0.0055024438f0]
# Finished Any[0.001, 1.0e-10, :relu, 512, 100, 64, :cos, 553.1387948989868, 0.006574779f0]
# Finished Any[0.001, 1.0e-10, :relu, 1024, 100, 64, :cos, 1023.6392850875854, 0.0074812747f0]
# Finished Any[0.001, 1.0e-10, :celu, 256, 100, 64, :cos, 117.92614006996155, 0.0063016475f0]
# Finished Any[0.001, 1.0e-10, :celu, 512, 100, 64, :cos, 1207.0850257873535, 0.008859483f0]
# Finished Any[0.001, 1.0e-10, :celu, 1024, 100, 64, :cos, 152.76296281814575, 0.01197234f0]
# Finished Any[0.001, 1.0e-10, :tanh, 256, 100, 128, :cos, 2269.02383184433, 0.0046920013f0]
# Finished Any[0.001, 1.0e-10, :tanh, 512, 100, 128, :cos, 216.41274404525757, 0.00684004f0]
# Finished Any[0.001, 1.0e-10, :tanh, 1024, 100, 128, :cos, 178.79252886772156, 0.008927653f0]
# Finished Any[0.001, 1.0e-10, :relu, 256, 100, 128, :cos, 304.0812249183655, 0.0028577521f0]
# Finished Any[0.001, 1.0e-10, :relu, 512, 100, 128, :cos, 216.76893019676208, 0.0034144435f0]
# Finished Any[0.001, 1.0e-10, :relu, 1024, 100, 128, :cos, 172.68909406661987, 0.004311412f0]
# Finished Any[0.001, 1.0e-10, :celu, 256, 100, 128, :cos, 1578.1288549900055, 0.0035333529f0]
# Finished Any[0.001, 1.0e-10, :celu, 512, 100, 128, :cos, 1232.3901619911194, 0.0057994383f0]
# Finished Any[0.001, 1.0e-10, :celu, 1024, 100, 128, :cos, 276.539302110672, 0.008073823f0]
# Finished Any[0.001, 1.0e-10, :tanh, 256, 100, 256, :cos, 760.0306468009949, 0.0045519243f0]
# Finished Any[0.001, 1.0e-10, :tanh, 512, 100, 256, :cos, 755.0039529800415, 0.005255837f0]
# Finished Any[0.001, 1.0e-10, :tanh, 1024, 100, 256, :cos, 453.7042829990387, 0.007580776f0]
# Finished Any[0.001, 1.0e-10, :relu, 256, 100, 256, :cos, 736.4661679267883, 0.0016993907f0]
# Finished Any[0.001, 1.0e-10, :relu, 512, 100, 256, :cos, 532.9272980690002, 0.0019593309f0]
# Finished Any[0.001, 1.0e-10, :relu, 1024, 100, 256, :cos, 417.4027919769287, 0.0025009874f0]
# Finished Any[0.001, 1.0e-10, :celu, 256, 100, 256, :cos, 964.1439228057861, 0.002458275f0]
# Finished Any[0.001, 1.0e-10, :celu, 512, 100, 256, :cos, 1650.427062034607, 0.0031529295f0]
# Finished Any[0.001, 1.0e-10, :celu, 1024, 100, 256, :cos, 1615.785187959671, 0.0053527183f0]
# Finished Any[0.001, 1.0e-10, :tanh, 256, 100, 64, :poly, 1009.7159638404846, 0.010393971f0]
# Finished Any[0.001, 1.0e-10, :tanh, 512, 100, 64, :poly, 167.65252709388733, 0.012821631f0]
# Finished Any[0.001, 1.0e-10, :tanh, 1024, 100, 64, :poly, 121.7584228515625, 0.015461803f0]
# Finished Any[0.001, 1.0e-10, :relu, 256, 100, 64, :poly, 90.5167019367218, 0.0073260977f0]
# Finished Any[0.001, 1.0e-10, :relu, 512, 100, 64, :poly, 177.72605991363525, 0.00796308f0]
# Finished Any[0.001, 1.0e-10, :relu, 1024, 100, 64, :poly, 132.531240940094, 0.009469025f0]
# Finished Any[0.001, 1.0e-10, :celu, 256, 100, 64, :poly, 123.05211400985718, 0.008425333f0]
# Finished Any[0.001, 1.0e-10, :celu, 512, 100, 64, :poly, 255.57281684875488, 0.011553445f0]
# Finished Any[0.001, 1.0e-10, :celu, 1024, 100, 64, :poly, 184.74529004096985, 0.01403696f0]
# Finished Any[0.001, 1.0e-10, :tanh, 256, 100, 128, :poly, 485.4806730747223, 0.006409546f0]
# Finished Any[0.001, 1.0e-10, :tanh, 512, 100, 128, :poly, 335.3510570526123, 0.009440255f0]
# Finished Any[0.001, 1.0e-10, :tanh, 1024, 100, 128, :poly, 255.51969480514526, 0.011993765f0]
# Finished Any[0.001, 1.0e-10, :relu, 256, 100, 128, :poly, 436.82732701301575, 0.003490948f0]
# Finished Any[0.001, 1.0e-10, :relu, 512, 100, 128, :poly, 305.7830169200897, 0.0043039513f0]
# Finished Any[0.001, 1.0e-10, :relu, 1024, 100, 128, :poly, 244.33404994010925, 0.00542185f0]
# Finished Any[0.001, 1.0e-10, :celu, 256, 100, 128, :poly, 722.4821009635925, 0.0054242457f0]
# Finished Any[0.001, 1.0e-10, :celu, 512, 100, 128, :poly, 311.80046796798706, 0.007692121f0]
# Finished Any[0.001, 1.0e-10, :celu, 1024, 100, 128, :poly, 278.63029313087463, 0.010396528f0]
# Finished Any[0.001, 1.0e-10, :tanh, 256, 100, 256, :poly, 1656.0463230609894, 0.0056912606f0]
# Finished Any[0.001, 1.0e-10, :tanh, 512, 100, 256, :poly, 500.66837191581726, 0.00709356f0]
# Finished Any[0.001, 1.0e-10, :tanh, 1024, 100, 256, :poly, 400.3649890422821, 0.009719013f0]
# Finished Any[0.001, 1.0e-10, :relu, 256, 100, 256, :poly, 649.3027341365814, 0.0020020679f0]
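
# Hedged post-processing sketch (illustrative addition, not part of the original commit):
# rank the sweep runs collected in `results` by their trailing-loss entry to pick a spec.
# Column order follows the push! above:
# [lr_start, lr_end, activation, batchsize, n_epochs, n_hidden, schedule, optimiser, elapsed_time, loss, relnorm]
ranked = sort(results, by = r -> r[end-1])
for r in first(ranked, 5)
    println("loss = $(r[end-1]), relnorm = $(r[end]), time = $(round(r[end-2], digits = 1))s, spec = $(r[3:8])")
end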
# BSON.@save "post_ADAM.bson" neural_net

# BSON.@load "post_ADAM.bson" neural_net
