Skip to content

Commit

Permalink
Simona/onnx (#161)
Browse files Browse the repository at this point in the history
Creates a copy of the policy for exporting to avoid the interference of
the training and inference loops
  • Loading branch information
spetravic authored Aug 24, 2023
1 parent b28e91b commit 68af9f2
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions emote/extra/onnx_exporter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import copy
import io
import logging
import time
Expand Down Expand Up @@ -152,6 +153,11 @@ def process_pending_exports(self):
item.process(self)

def _trace(self):
with self.scopes.scope("policycopy"):
policy = copy.deepcopy(self.policy)

policy.train(False)

with self.scopes.scope("trace"):
args = []

Expand All @@ -162,15 +168,13 @@ def _trace(self):

args.append(arg)

self.policy.train(False)

# NOTE: This might seem like a good use case for torch.jit.trace,
# but it unfortunately leaks a full copy of the neural network.
# See: https://github.com/pytorch/pytorch/issues/82532

with io.BytesIO() as f:
torch.onnx.export(
model=self.policy,
model=policy,
args=tuple(args),
f=f,
input_names=list(map(lambda pair: pair[0], self.inputs)),
Expand All @@ -182,8 +186,6 @@ def _trace(self):
opset_version=13,
)

self.policy.train(True)

f.seek(0)
model_proto = onnx.load_model(f, onnx.ModelProto)

Expand Down

0 comments on commit 68af9f2

Please sign in to comment.