Skip to content

Commit

Permalink
fix unittests
Browse files Browse the repository at this point in the history
  • Loading branch information
xadupre committed Nov 20, 2023
1 parent 7537ca7 commit e3e1001
Showing 1 changed file with 7 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@


def get_opsets_model(filename):
onx = onnx.load(filename)
if isinstance(filename, onnx.ModelProto):
onx = filename
else:
onx = onnx.load(filename)
return {d.domain: d.version for d in onx.opset_import}


Expand Down Expand Up @@ -1005,9 +1008,9 @@ def test_save_ort_format():
assert os.path.exists(os.path.join(temp_dir, "optimizer_model.onnx"))
assert os.path.exists(os.path.join(temp_dir, "optimizer_model.ort"))
base_opsets = get_opsets_model(base_model)
training_opsets = get_opsets_model("training_model.onnx")
eval_opsets = get_opsets_model("eval_model.onnx")
optimizer_opsets = get_opsets_model("optimizer_model.onnx")
training_opsets = get_opsets_model(os.path.join(temp_dir, "training_model.onnx"))
eval_opsets = get_opsets_model(os.path.join(temp_dir, "eval_model.onnx"))
optimizer_opsets = get_opsets_model(os.path.join(temp_dir, "optimizer_model.onnx"))
if base_opsets[""] != training_opsets[""]:
raise AssertionError(f"Opsets mismatch {base_opsets['']} != {training_opsets['']}.")
if base_opsets[""] != eval_opsets[""]:
Expand Down

0 comments on commit e3e1001

Please sign in to comment.