From e3e100178f284ab7a9deeade00ee09e7d2e2f8a7 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Mon, 20 Nov 2023 19:36:42 +0100 Subject: [PATCH] fix unittests --- .../test/python/orttraining_test_onnxblock.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/orttraining/orttraining/test/python/orttraining_test_onnxblock.py b/orttraining/orttraining/test/python/orttraining_test_onnxblock.py index 699341e056a0b..6e5d54cbb9427 100644 --- a/orttraining/orttraining/test/python/orttraining_test_onnxblock.py +++ b/orttraining/orttraining/test/python/orttraining_test_onnxblock.py @@ -18,7 +18,10 @@ def get_opsets_model(filename): - onx = onnx.load(filename) + if isinstance(filename, onnx.ModelProto): + onx = filename + else: + onx = onnx.load(filename) return {d.domain: d.version for d in onx.opset_import} @@ -1005,9 +1008,9 @@ def test_save_ort_format(): assert os.path.exists(os.path.join(temp_dir, "optimizer_model.onnx")) assert os.path.exists(os.path.join(temp_dir, "optimizer_model.ort")) base_opsets = get_opsets_model(base_model) - training_opsets = get_opsets_model("training_model.onnx") - eval_opsets = get_opsets_model("eval_model.onnx") - optimizer_opsets = get_opsets_model("optimizer_model.onnx") + training_opsets = get_opsets_model(os.path.join(temp_dir, "training_model.onnx")) + eval_opsets = get_opsets_model(os.path.join(temp_dir, "eval_model.onnx")) + optimizer_opsets = get_opsets_model(os.path.join(temp_dir, "optimizer_model.onnx")) if base_opsets[""] != training_opsets[""]: raise AssertionError(f"Opsets mismatch {base_opsets['']} != {training_opsets['']}.") if base_opsets[""] != eval_opsets[""]: