Skip to content

Commit

Permalink
Update load_rl_NNs.m
Browse files Browse the repository at this point in the history
"Fix" onnx loading error by loading mat files
  • Loading branch information
mldiego authored Apr 18, 2023
1 parent 51b0d02 commit 9e9a21a
Showing 1 changed file with 20 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,28 @@
t = tic;
for h = 1:length(listNN) % generlize NN loading options for all benchmarks
if endsWith(listNN(h).name, ".onnx")
net = importONNXNetwork(benchmarkFolder+string(listNN(h).name), InputDataFormats="BC");
% transform for matlab
if ~contains(listNN(h).name, "dubins")
Layers = net.Layers([1,4:end-1]);
net = dlnetwork(Layers);
if is_codeocean % error with swing (some internal operations not available when running -nodesktop)
modelname = split(listNN(h).name,'.');
net = load(benchmarkFolder+string(modelname{1})+".mat");
net = net.net;
else
Layers = net.Layers;
ils = [];
for k=1:length(Layers)-1
if isa(Layers(k), "nnet.onnx.layer.ElementwiseAffineLayer")
Layers(k-1).Bias = Layers(k).Offset;
else
ils = [ils k];
net = importONNXNetwork(benchmarkFolder+string(listNN(h).name), InputDataFormats="BC");
% transform for matlab
if ~contains(listNN(h).name, "dubins")
Layers = net.Layers([1,4:end-1]);
net = dlnetwork(Layers);
else
Layers = net.Layers;
ils = [];
for k=1:length(Layers)-1
if isa(Layers(k), "nnet.onnx.layer.ElementwiseAffineLayer")
Layers(k-1).Bias = Layers(k).Offset;
else
ils = [ils k];
end
end
net = dlnetwork(net.Layers(ils));
end
net = dlnetwork(net.Layers(ils));
end
nn = matlab2nnv(net);
% store networks
Expand All @@ -39,11 +45,5 @@
t = toc(t);
names2idxs = containers.Map(names,idxs);
disp("All networks are loaded in " + string(t) + " seconds");
% Remove extra files
try
rmdir +cartpole s
end
try
rmdir +lunarlander s
end

end

0 comments on commit 9e9a21a

Please sign in to comment.