Skip to content

Commit

Permalink
add test to check opsets
Browse files Browse the repository at this point in the history
  • Loading branch information
xadupre committed Nov 20, 2023
1 parent d2257cf commit 7537ca7
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 2 deletions.
4 changes: 2 additions & 2 deletions orttraining/orttraining/python/training/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def generate_artifacts(
3. Checkpoint (directory): Contains the model parameters.
4. Optimizer model (onnx.ModelProto): Model containing the optimizer graph.
All generated ModelProto are using the same opsets defined by *model*.
All generated ModelProtos will use the same opsets defined by *model*.
Args:
model: The base model to be used for gradient graph generation.
Expand Down Expand Up @@ -211,7 +211,7 @@ def _export_to_ort_format(model_path, output_dir, extra_options):

opset_version = None
for domain in model.opset_import:
if domain.domain == "":
if domain.domain == "" or domain.domain == "ai.onnx":
opset_version = domain.version
break

Expand Down
15 changes: 15 additions & 0 deletions orttraining/orttraining/test/python/orttraining_test_onnxblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@
# PyTorch Module definitions


def get_opsets_model(filename):
onx = onnx.load(filename)
return {d.domain: d.version for d in onx.opset_import}


class SimpleNet(torch.nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
super().__init__()
Expand Down Expand Up @@ -999,3 +1004,13 @@ def test_save_ort_format():
assert os.path.exists(os.path.join(temp_dir, "eval_model.ort"))
assert os.path.exists(os.path.join(temp_dir, "optimizer_model.onnx"))
assert os.path.exists(os.path.join(temp_dir, "optimizer_model.ort"))
base_opsets = get_opsets_model(base_model)
training_opsets = get_opsets_model("training_model.onnx")
eval_opsets = get_opsets_model("eval_model.onnx")
optimizer_opsets = get_opsets_model("optimizer_model.onnx")
if base_opsets[""] != training_opsets[""]:
raise AssertionError(f"Opsets mismatch {base_opsets['']} != {training_opsets['']}.")
if base_opsets[""] != eval_opsets[""]:
raise AssertionError(f"Opsets mismatch {base_opsets['']} != {eval_opsets['']}.")
if base_opsets[""] != optimizer_opsets[""]:
raise AssertionError(f"Opsets mismatch {base_opsets['']} != {optimizer_opsets['']}.")

0 comments on commit 7537ca7

Please sign in to comment.