From 9e9a21ab7d9e99d9ec2a51dd66bdbd3eb2907013 Mon Sep 17 00:00:00 2001 From: Diego Manzanas Lopez Date: Mon, 17 Apr 2023 23:52:41 -0500 Subject: [PATCH] Update load_rl_NNs.m "Fix" onnx loading error by loading mat files --- .../NNV_vs_MATLAB/rl_benchmarks/load_rl_NNs.m | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/code/nnv/examples/NNV2.0/Submission/CAV2023/NNV_vs_MATLAB/rl_benchmarks/load_rl_NNs.m b/code/nnv/examples/NNV2.0/Submission/CAV2023/NNV_vs_MATLAB/rl_benchmarks/load_rl_NNs.m index 090ecea2ca..988020c996 100644 --- a/code/nnv/examples/NNV2.0/Submission/CAV2023/NNV_vs_MATLAB/rl_benchmarks/load_rl_NNs.m +++ b/code/nnv/examples/NNV2.0/Submission/CAV2023/NNV_vs_MATLAB/rl_benchmarks/load_rl_NNs.m @@ -9,22 +9,28 @@ t = tic; for h = 1:length(listNN) % generlize NN loading options for all benchmarks if endsWith(listNN(h).name, ".onnx") - net = importONNXNetwork(benchmarkFolder+string(listNN(h).name), InputDataFormats="BC"); - % transform for matlab - if ~contains(listNN(h).name, "dubins") - Layers = net.Layers([1,4:end-1]); - net = dlnetwork(Layers); + if is_codeocean % error with swing (some internal operations not available when running -nodesktop) + modelname = split(listNN(h).name,'.'); + net = load(benchmarkFolder+string(modelname{1})+".mat"); + net = net.net; else - Layers = net.Layers; - ils = []; - for k=1:length(Layers)-1 - if isa(Layers(k), "nnet.onnx.layer.ElementwiseAffineLayer") - Layers(k-1).Bias = Layers(k).Offset; - else - ils = [ils k]; + net = importONNXNetwork(benchmarkFolder+string(listNN(h).name), InputDataFormats="BC"); + % transform for matlab + if ~contains(listNN(h).name, "dubins") + Layers = net.Layers([1,4:end-1]); + net = dlnetwork(Layers); + else + Layers = net.Layers; + ils = []; + for k=1:length(Layers)-1 + if isa(Layers(k), "nnet.onnx.layer.ElementwiseAffineLayer") + Layers(k-1).Bias = Layers(k).Offset; + else + ils = [ils k]; + end end + net = dlnetwork(net.Layers(ils)); end - net = dlnetwork(net.Layers(ils)); end nn = matlab2nnv(net); % store networks @@ -39,11 +45,5 @@ t = toc(t); names2idxs = containers.Map(names,idxs); disp("All networks are loaded in " + string(t) + " seconds"); - % Remove extra files - try - rmdir +cartpole s - end - try - rmdir +lunarlander s - end + end