From d5418cae2ddaab7f21d4701da6fd5476196f41a6 Mon Sep 17 00:00:00 2001 From: carzh Date: Thu, 6 Jun 2024 18:26:00 +0000 Subject: [PATCH 01/13] added string option to generate artifacts --- .../orttraining/python/training/artifacts.py | 8 +++-- .../python/training/onnxblock/blocks.py | 5 +++- .../training/onnxblock/model_accessor.py | 29 +++++++++++++++++-- .../python/training/onnxblock/onnxblock.py | 5 +++- 4 files changed, 39 insertions(+), 8 deletions(-) diff --git a/orttraining/orttraining/python/training/artifacts.py b/orttraining/orttraining/python/training/artifacts.py index 624b30ffdab3b..40eb5cb8ec15f 100644 --- a/orttraining/orttraining/python/training/artifacts.py +++ b/orttraining/orttraining/python/training/artifacts.py @@ -37,7 +37,7 @@ class OptimType(Enum): def generate_artifacts( - model: onnx.ModelProto, + model: Union[onnx.ModelProto, str], requires_grad: Optional[List[str]] = None, frozen_params: Optional[List[str]] = None, loss: Optional[Union[LossType, onnxblock.Block]] = None, @@ -61,7 +61,8 @@ def generate_artifacts( All generated ModelProtos will use the same opsets defined by *model*. Args: - model: The base model to be used for gradient graph generation. + model: The base model or path to the base model to be used for gradient graph generation. For models >2GB, + use the path to the base model. requires_grad: List of names of model parameters that require gradient computation frozen_params: List of names of model parameters that should be frozen. loss: The loss function enum or onnxblock to be used for training. If None, no loss node is added to the graph. @@ -170,7 +171,8 @@ def build(self, *inputs_to_loss): if custom_op_library is not None else contextlib.nullcontext() ): - _ = training_block(*[output.name for output in model.graph.output]) + loaded_model = onnx.load(model) if isinstance(model, str) else model + _ = training_block(*[output.name for output in loaded_model.graph.output]) training_model, eval_model = training_block.to_model_proto() model_params = training_block.parameters() diff --git a/orttraining/orttraining/python/training/onnxblock/blocks.py b/orttraining/orttraining/python/training/onnxblock/blocks.py index 149d0a360f7d3..d92530d4b21c5 100644 --- a/orttraining/orttraining/python/training/onnxblock/blocks.py +++ b/orttraining/orttraining/python/training/onnxblock/blocks.py @@ -47,7 +47,10 @@ def __call__(self, *args, **kwargs): output = self.build(*args, **kwargs) - onnx.checker.check_model(self.base, True) + if accessor._GLOBAL_ACCESSOR.has_path: + onnx.checker.check_model(accessor._GLOBAL_ACCESSOR.path, True) + else: + onnx.checker.check_model(self.base, True) return output diff --git a/orttraining/orttraining/python/training/onnxblock/model_accessor.py b/orttraining/orttraining/python/training/onnxblock/model_accessor.py index ac7a53a554e0a..98d9be1e30fb1 100644 --- a/orttraining/orttraining/python/training/onnxblock/model_accessor.py +++ b/orttraining/orttraining/python/training/onnxblock/model_accessor.py @@ -6,6 +6,7 @@ import copy import os from contextlib import contextmanager +from typing import Union import onnx @@ -14,11 +15,17 @@ class ModelAccessor: """This class stores the onnx model that is manipulated by the onnx blocks. Attributes: - model: The onnx model that is manipulated by the onnx blocks. + model: The onnx model or path to the model that is manipulated by the onnx blocks. """ - def __init__(self, model: onnx.ModelProto): - self._model = model + def __init__(self, model: Union[onnx.ModelProto, str]): + if isinstance(model, onnx.ModelProto): + self._model = model + elif isinstance(model, str): + self._path = model + self._model = onnx.load(model) + else: + raise RuntimeError("Please pass in either a file path as string to the base model, or the ModelProto") @property def model(self) -> onnx.ModelProto: @@ -30,6 +37,22 @@ def model(self) -> onnx.ModelProto: ) return self._model + @property + def path(self) -> str: + """ModelAccessor property that gets the path to the base model.""" + + if self._path is None: + raise RuntimeError( + "The path to the onnx model was not set. Please use the context manager onnxblock.onnx_model to create the model and pass in a string." + ) + return self._path + + @property + def has_path(self) -> bool: + """Returns True if ModelAccessor has a path to a model, False otherwise.""" + + return self._path is not None + # These variable resides in the global namespace. # Different methods can access this global model and manipulate it. diff --git a/orttraining/orttraining/python/training/onnxblock/onnxblock.py b/orttraining/orttraining/python/training/onnxblock/onnxblock.py index a2922353ac70e..0445d873ca5cc 100644 --- a/orttraining/orttraining/python/training/onnxblock/onnxblock.py +++ b/orttraining/orttraining/python/training/onnxblock/onnxblock.py @@ -187,7 +187,10 @@ def __call__(self, *args, **kwargs): output = self.build(*args, **kwargs) - model = onnx.shape_inference.infer_shapes(accessor._GLOBAL_ACCESSOR.model) + if accessor._GLOBAL_ACCESSOR.has_path: + model = onnx.shape_inference.infer_shapes_path(accessor._GLOBAL_ACCESSOR.path) + else: + model = onnx.shape_inference.infer_shapes(accessor._GLOBAL_ACCESSOR.model) _graph_utils.register_graph_outputs(model, output) From 924cedae736df48e10a36dc46f333530a3985235 Mon Sep 17 00:00:00 2001 From: carzh Date: Thu, 6 Jun 2024 20:23:09 +0000 Subject: [PATCH 02/13] switched to passing both the modelproto + modelpath to the model accessor, so that onnx.load is called twice instead of three times --- .../orttraining/python/training/artifacts.py | 7 +++--- .../training/onnxblock/model_accessor.py | 22 ++++++++----------- .../python/training/onnxblock/onnxblock.py | 12 ++++++++-- 3 files changed, 23 insertions(+), 18 deletions(-) diff --git a/orttraining/orttraining/python/training/artifacts.py b/orttraining/orttraining/python/training/artifacts.py index 40eb5cb8ec15f..ddef0a035f7d6 100644 --- a/orttraining/orttraining/python/training/artifacts.py +++ b/orttraining/orttraining/python/training/artifacts.py @@ -160,18 +160,19 @@ def build(self, *inputs_to_loss): training_model = None eval_model = None model_params = None + loaded_model = onnx.load(model) if isinstance(model, str) else model + model_path = model if isinstance(model, str) else None custom_op_library_path = None if custom_op_library is not None: logging.info("Custom op library provided: %s", custom_op_library) custom_op_library_path = pathlib.Path(custom_op_library) - with onnxblock.base(model), ( + with onnxblock.base(loaded_model, model_path), ( onnxblock.custom_op_library(custom_op_library_path) if custom_op_library is not None else contextlib.nullcontext() ): - loaded_model = onnx.load(model) if isinstance(model, str) else model _ = training_block(*[output.name for output in loaded_model.graph.output]) training_model, eval_model = training_block.to_model_proto() model_params = training_block.parameters() @@ -222,7 +223,7 @@ def _export_to_ort_format(model_path, output_dir, ort_format, custom_op_library_ return opset_version = None - for domain in model.opset_import: + for domain in loaded_model.opset_import: if domain.domain == "" or domain.domain == "ai.onnx": opset_version = domain.version break diff --git a/orttraining/orttraining/python/training/onnxblock/model_accessor.py b/orttraining/orttraining/python/training/onnxblock/model_accessor.py index 98d9be1e30fb1..53a50acce134d 100644 --- a/orttraining/orttraining/python/training/onnxblock/model_accessor.py +++ b/orttraining/orttraining/python/training/onnxblock/model_accessor.py @@ -6,7 +6,6 @@ import copy import os from contextlib import contextmanager -from typing import Union import onnx @@ -15,17 +14,13 @@ class ModelAccessor: """This class stores the onnx model that is manipulated by the onnx blocks. Attributes: - model: The onnx model or path to the model that is manipulated by the onnx blocks. + model: The onnx model that is manipulated by the onnx blocks. + model_path: The path to the base model. Can be None. """ - def __init__(self, model: Union[onnx.ModelProto, str]): - if isinstance(model, onnx.ModelProto): - self._model = model - elif isinstance(model, str): - self._path = model - self._model = onnx.load(model) - else: - raise RuntimeError("Please pass in either a file path as string to the base model, or the ModelProto") + def __init__(self, model: onnx.ModelProto, model_path: str): + self._model = model + self._path = model_path @property def model(self) -> onnx.ModelProto: @@ -62,7 +57,7 @@ def has_path(self) -> bool: @contextmanager -def base(model: onnx.ModelProto): +def base(model: onnx.ModelProto, model_path: str): """Registers the base model to be manipulated by the onnx blocks. Example: @@ -76,6 +71,7 @@ def base(model: onnx.ModelProto): Args: model: The base model to be manipulated by the onnx blocks. + model_path: The path to the base model. None if there is no model path to pass in. Returns: ModelAccessor: The model accessor that contains the modified model. @@ -92,7 +88,7 @@ def base(model: onnx.ModelProto): "model from scratch." ) - _GLOBAL_ACCESSOR = ModelAccessor(model_clone) + _GLOBAL_ACCESSOR = ModelAccessor(model_clone, model_path) try: yield _GLOBAL_ACCESSOR finally: @@ -135,7 +131,7 @@ def empty_base(opset_version: int | None = None): ) ) - _GLOBAL_ACCESSOR = ModelAccessor(model) + _GLOBAL_ACCESSOR = ModelAccessor(model, None) try: yield _GLOBAL_ACCESSOR finally: diff --git a/orttraining/orttraining/python/training/onnxblock/onnxblock.py b/orttraining/orttraining/python/training/onnxblock/onnxblock.py index 0445d873ca5cc..5deddc2c04c96 100644 --- a/orttraining/orttraining/python/training/onnxblock/onnxblock.py +++ b/orttraining/orttraining/python/training/onnxblock/onnxblock.py @@ -70,7 +70,12 @@ def __call__(self, *args, **kwargs): output = self.build(*args, **kwargs) - self._model = onnx.shape_inference.infer_shapes(accessor._GLOBAL_ACCESSOR.model) + if accessor._GLOBAL_ACCESSOR.has_path: + onnx.shape_inference.infer_shapes_path(accessor._GLOBAL_ACCESSOR.path) + # shape inferenced model is saved to original path + self._model = onnx.load(accessor._GLOBAL_ACCESSOR.path) + else: + self._model = onnx.shape_inference.infer_shapes(accessor._GLOBAL_ACCESSOR.model) _graph_utils.register_graph_outputs(self._model, output) @@ -187,8 +192,11 @@ def __call__(self, *args, **kwargs): output = self.build(*args, **kwargs) + model = None if accessor._GLOBAL_ACCESSOR.has_path: - model = onnx.shape_inference.infer_shapes_path(accessor._GLOBAL_ACCESSOR.path) + onnx.shape_inference.infer_shapes_path(accessor._GLOBAL_ACCESSOR.path) + # shape inferenced model is saved to original path + model = onnx.load(accessor._GLOBAL_ACCESSOR.path) else: model = onnx.shape_inference.infer_shapes(accessor._GLOBAL_ACCESSOR.model) From e45849213c33cf83f8ceae1eac87211a04919758 Mon Sep 17 00:00:00 2001 From: carzh Date: Wed, 12 Jun 2024 14:31:00 -0700 Subject: [PATCH 03/13] wrote unit tests + added check for large files + fixed some bugs --- .../orttraining/python/training/artifacts.py | 23 +++- .../python/training/onnxblock/_graph_utils.py | 1 + .../python/training/onnxblock/loss/loss.py | 10 +- .../training/onnxblock/model_accessor.py | 5 +- .../python/training/onnxblock/onnxblock.py | 27 ++++- .../orttraining_test_ort_apis_onnxblock.py | 113 ++++++++++++++++++ 6 files changed, 165 insertions(+), 14 deletions(-) diff --git a/orttraining/orttraining/python/training/artifacts.py b/orttraining/orttraining/python/training/artifacts.py index ddef0a035f7d6..c98e5bcd97092 100644 --- a/orttraining/orttraining/python/training/artifacts.py +++ b/orttraining/orttraining/python/training/artifacts.py @@ -13,6 +13,9 @@ from onnxruntime.tools.convert_onnx_models_to_ort import OptimizationStyle, convert_onnx_models_to_ort from onnxruntime.training import onnxblock +# threshold for the size of the modelproto where you should use a path instead +USE_PATH_THRESHOLD = 2147483648 + class LossType(Enum): """Loss type to be added to the training model. @@ -61,7 +64,7 @@ def generate_artifacts( All generated ModelProtos will use the same opsets defined by *model*. Args: - model: The base model or path to the base model to be used for gradient graph generation. For models >2GB, + model: The base model or path to the base model to be used for gradient graph generation. For models >2GB, use the path to the base model. requires_grad: List of names of model parameters that require gradient computation frozen_params: List of names of model parameters that should be frozen. @@ -87,6 +90,22 @@ def generate_artifacts( RuntimeError: If the optimizer provided is not one of the supported optimizers. """ + loaded_model = None + model_path = None + + if isinstance(model, str): + loaded_model = onnx.load(model) + model_path = model + elif isinstance(model, onnx.ModelProto): + if model.ByteSize() > USE_PATH_THRESHOLD: + # infer_shapes and check_model from ONNX both require paths to be used for >2GB models. + raise RuntimeError("This model is > 2GB. Please pass in a path to the ONNX file instead.") + + loaded_model = model + model_path = None + else: + raise RuntimeError("Please pass in either a string or an ONNX ModelProto for the model.") + loss_blocks = { LossType.MSELoss: onnxblock.loss.MSELoss, LossType.CrossEntropyLoss: onnxblock.loss.CrossEntropyLoss, @@ -160,8 +179,6 @@ def build(self, *inputs_to_loss): training_model = None eval_model = None model_params = None - loaded_model = onnx.load(model) if isinstance(model, str) else model - model_path = model if isinstance(model, str) else None custom_op_library_path = None if custom_op_library is not None: diff --git a/orttraining/orttraining/python/training/onnxblock/_graph_utils.py b/orttraining/orttraining/python/training/onnxblock/_graph_utils.py index 42743a4200d17..3e9eb3812129f 100644 --- a/orttraining/orttraining/python/training/onnxblock/_graph_utils.py +++ b/orttraining/orttraining/python/training/onnxblock/_graph_utils.py @@ -85,3 +85,4 @@ def node_arg_exists(model: onnx.ModelProto, node_arg_name: str) -> bool: return True return False + diff --git a/orttraining/orttraining/python/training/onnxblock/loss/loss.py b/orttraining/orttraining/python/training/onnxblock/loss/loss.py index e719301e13f48..52194a7e0f06f 100644 --- a/orttraining/orttraining/python/training/onnxblock/loss/loss.py +++ b/orttraining/orttraining/python/training/onnxblock/loss/loss.py @@ -95,17 +95,17 @@ def build(self, scores_input_name: str, labels_name: str = "labels"): labels_input.type.tensor_type.elem_type = onnx.TensorProto.INT64 # If the predictions are (num_examples x num_classes) # labels should be (num_examples,) - del labels_input.type.tensor_type.shape.dim[1] + del labels_input.type.tensor_type.shape.dim[-1] self.base.graph.input.append(labels_input) loss_node_input_names = [scores_input_name, labels_name] if self._weight: loss_node_input_names.append(weight_name) + loss_node_output_name = _graph_utils.generate_graph_name("loss") - loss_node_output_names = [ - loss_node_output_name, - _graph_utils.generate_graph_name("log_prob"), - ] + log_prob_output_name = _graph_utils.generate_graph_name("log_prob") + + loss_node_output_names = [loss_node_output_name, log_prob_output_name] loss_node = onnx.helper.make_node( "SoftmaxCrossEntropyLoss", loss_node_input_names, diff --git a/orttraining/orttraining/python/training/onnxblock/model_accessor.py b/orttraining/orttraining/python/training/onnxblock/model_accessor.py index 53a50acce134d..7962e97e5a4bb 100644 --- a/orttraining/orttraining/python/training/onnxblock/model_accessor.py +++ b/orttraining/orttraining/python/training/onnxblock/model_accessor.py @@ -6,6 +6,7 @@ import copy import os from contextlib import contextmanager +from typing import Optional import onnx @@ -18,7 +19,7 @@ class ModelAccessor: model_path: The path to the base model. Can be None. """ - def __init__(self, model: onnx.ModelProto, model_path: str): + def __init__(self, model: onnx.ModelProto, model_path: Optional[str] = None): self._model = model self._path = model_path @@ -57,7 +58,7 @@ def has_path(self) -> bool: @contextmanager -def base(model: onnx.ModelProto, model_path: str): +def base(model: onnx.ModelProto, model_path: Optional[str] = None): """Registers the base model to be manipulated by the onnx blocks. Example: diff --git a/orttraining/orttraining/python/training/onnxblock/onnxblock.py b/orttraining/orttraining/python/training/onnxblock/onnxblock.py index 5deddc2c04c96..d90310a9f9f6b 100644 --- a/orttraining/orttraining/python/training/onnxblock/onnxblock.py +++ b/orttraining/orttraining/python/training/onnxblock/onnxblock.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. import logging +import os from abc import abstractmethod from typing import List, Tuple @@ -12,6 +13,8 @@ import onnxruntime.training.onnxblock.blocks as blocks import onnxruntime.training.onnxblock.model_accessor as accessor +TEMP_ONNX_PATH = "temp.onnx" + class ForwardBlock(blocks.Block): """Base class for all blocks that require forward model to be automatically built. @@ -71,9 +74,17 @@ def __call__(self, *args, **kwargs): output = self.build(*args, **kwargs) if accessor._GLOBAL_ACCESSOR.has_path: - onnx.shape_inference.infer_shapes_path(accessor._GLOBAL_ACCESSOR.path) + onnx.save(self.base, TEMP_ONNX_PATH, save_as_external_data=True, all_tensors_to_one_file=True) + + onnx.shape_inference.infer_shapes_path(TEMP_ONNX_PATH) # shape inferenced model is saved to original path - self._model = onnx.load(accessor._GLOBAL_ACCESSOR.path) + self._model = onnx.load(TEMP_ONNX_PATH) + + # clean-up temp files + if os.path.exists(TEMP_ONNX_PATH): + os.remove(TEMP_ONNX_PATH) + if os.path.exists(TEMP_ONNX_PATH): + os.remove(TEMP_ONNX_PATH) else: self._model = onnx.shape_inference.infer_shapes(accessor._GLOBAL_ACCESSOR.model) @@ -194,9 +205,17 @@ def __call__(self, *args, **kwargs): model = None if accessor._GLOBAL_ACCESSOR.has_path: - onnx.shape_inference.infer_shapes_path(accessor._GLOBAL_ACCESSOR.path) + onnx.save(self.base, TEMP_ONNX_PATH, save_as_external_data=True, all_tensors_to_one_file=True) + + onnx.shape_inference.infer_shapes_path(TEMP_ONNX_PATH) # shape inferenced model is saved to original path - model = onnx.load(accessor._GLOBAL_ACCESSOR.path) + model = onnx.load(TEMP_ONNX_PATH) + + # clean-up temp files + if os.path.exists(TEMP_ONNX_PATH): + os.remove(TEMP_ONNX_PATH) + if os.path.exists(TEMP_ONNX_PATH): + os.remove(TEMP_ONNX_PATH) else: model = onnx.shape_inference.infer_shapes(accessor._GLOBAL_ACCESSOR.model) diff --git a/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py b/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py index ac49c1c2834c7..1b559456a31d7 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py +++ b/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py @@ -1,6 +1,7 @@ import copy import io import os +import pathlib import random import tempfile @@ -206,6 +207,32 @@ def _get_training_ort_inputs(x, target, pt_model, onnx_model, target_type=None): return ort_inputs +def get_root_ort_path(): + curr_dir = pathlib.Path.cwd() + + if curr_dir.name == "onnxruntime": + return curr_dir + + for parent in list(curr_dir.parents): + if parent.name == "onnxruntime": + return parent + + raise RuntimeError("Cannot find ONNXRuntime root directory.") + + +def get_string_path_to_testdata_onnx_file(onnx_file_name): + from_build_dir = pathlib.Path(f"testdata/{onnx_file_name}") + + if from_build_dir.is_file(): + return str(from_build_dir) + else: + ort_root = get_root_ort_path() + path_to_testdata_onnx_file = ort_root / "onnxruntime" / "test" / "testdata" / f"{onnx_file_name}" + if path_to_testdata_onnx_file.is_file(): + return str(path_to_testdata_onnx_file) + raise RuntimeError(f"Cannot find the path for {onnx_file_name}") + + # All unit tests @@ -318,6 +345,8 @@ def crossentropy_loss(prediction, target): ort_outs = ort_session.run(ort_output_names, ort_inputs) torch_outs = crossentropy_loss(pt_model(x), target) + print("ort_ outs", ort_outs) + print("torch outs", torch_outs) # Then assert np.allclose(ort_outs[0], _to_numpy(torch_outs)) @@ -356,6 +385,8 @@ def bcewithlogits_loss(prediction, target): ort_outs = ort_session.run(ort_output_names, ort_inputs) torch_outs = bcewithlogits_loss(pt_model(x), target) + print("ort_ outs", ort_outs) + print("torch outs", torch_outs) # Then assert np.allclose(ort_outs[0], _to_numpy(torch_outs)) @@ -1099,3 +1130,85 @@ def test_custom_optimizer_block(): for attr in node.attribute: if attr.name == "weight_decay": assert attr.f == weight_decay + + +def test_generate_artifacts_path(): + + with tempfile.TemporaryDirectory() as temp_dir: + _, simple_net = _get_models("cpu", 32, 28, 10, 10) + + requires_grad_params = ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"] + + onnx.save_model( + simple_net, + os.path.join(temp_dir, "simple_net.onnx"), + ) + + artifacts.generate_artifacts( + os.path.join(temp_dir, "simple_net.onnx"), + requires_grad=requires_grad_params, + loss=artifacts.LossType.CrossEntropyLoss, + optimizer=artifacts.OptimType.AdamW, + artifact_directory=temp_dir, + ) + + assert os.path.exists(os.path.join(temp_dir, "training_model.onnx")) + assert os.path.exists(os.path.join(temp_dir, "eval_model.onnx")) + assert os.path.exists(os.path.join(temp_dir, "optimizer_model.onnx")) + assert os.path.exists(os.path.join(temp_dir, "checkpoint")) + + +def test_generate_artifacts_external_data_one_file(): + with tempfile.TemporaryDirectory() as temp_dir: + _, simple_net = _get_models("cpu", 32, 28, 10, 10) + + requires_grad_params = ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"] + + onnx.save_model( + simple_net, + os.path.join(temp_dir, "simple_net.onnx"), + save_as_external_data=True, + all_tensors_to_one_file=True, + size_threshold=0, + ) + + artifacts.generate_artifacts( + os.path.join(temp_dir, "simple_net.onnx"), + requires_grad=requires_grad_params, + loss=artifacts.LossType.CrossEntropyLoss, + optimizer=artifacts.OptimType.AdamW, + artifact_directory=temp_dir, + ) + + assert os.path.exists(os.path.join(temp_dir, "training_model.onnx")) + assert os.path.exists(os.path.join(temp_dir, "eval_model.onnx")) + assert os.path.exists(os.path.join(temp_dir, "optimizer_model.onnx")) + assert os.path.exists(os.path.join(temp_dir, "checkpoint")) + + +def test_generate_artifacts_external_data_separate_files(): + with tempfile.TemporaryDirectory() as temp_dir: + _, simple_net = _get_models("cpu", 32, 28, 10, 10) + + requires_grad_params = ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"] + + onnx.save_model( + simple_net, + os.path.join(temp_dir, "simple_net.onnx"), + save_as_external_data=True, + all_tensors_to_one_file=False, + size_threshold=0, + ) + + artifacts.generate_artifacts( + os.path.join(temp_dir, "simple_net.onnx"), + requires_grad=requires_grad_params, + loss=artifacts.LossType.CrossEntropyLoss, + optimizer=artifacts.OptimType.AdamW, + artifact_directory=temp_dir, + ) + + assert os.path.exists(os.path.join(temp_dir, "training_model.onnx")) + assert os.path.exists(os.path.join(temp_dir, "eval_model.onnx")) + assert os.path.exists(os.path.join(temp_dir, "optimizer_model.onnx")) + assert os.path.exists(os.path.join(temp_dir, "checkpoint")) From 8069700f7a4fcdfac468a7fa7f5c88ab6c400399 Mon Sep 17 00:00:00 2001 From: carzh Date: Wed, 12 Jun 2024 14:35:15 -0700 Subject: [PATCH 04/13] removed unnecessary helper functions + import from unit tests --- .../orttraining_test_ort_apis_onnxblock.py | 27 ------------------- 1 file changed, 27 deletions(-) diff --git a/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py b/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py index 1b559456a31d7..b374d231cf809 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py +++ b/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py @@ -1,7 +1,6 @@ import copy import io import os -import pathlib import random import tempfile @@ -207,32 +206,6 @@ def _get_training_ort_inputs(x, target, pt_model, onnx_model, target_type=None): return ort_inputs -def get_root_ort_path(): - curr_dir = pathlib.Path.cwd() - - if curr_dir.name == "onnxruntime": - return curr_dir - - for parent in list(curr_dir.parents): - if parent.name == "onnxruntime": - return parent - - raise RuntimeError("Cannot find ONNXRuntime root directory.") - - -def get_string_path_to_testdata_onnx_file(onnx_file_name): - from_build_dir = pathlib.Path(f"testdata/{onnx_file_name}") - - if from_build_dir.is_file(): - return str(from_build_dir) - else: - ort_root = get_root_ort_path() - path_to_testdata_onnx_file = ort_root / "onnxruntime" / "test" / "testdata" / f"{onnx_file_name}" - if path_to_testdata_onnx_file.is_file(): - return str(path_to_testdata_onnx_file) - raise RuntimeError(f"Cannot find the path for {onnx_file_name}") - - # All unit tests From 29ad02a084a0fbd17f3d83549823c65f89e8cc3b Mon Sep 17 00:00:00 2001 From: carzh Date: Wed, 12 Jun 2024 22:48:33 +0000 Subject: [PATCH 05/13] fine I'll use X|Y type annotations --- .../orttraining/python/training/onnxblock/model_accessor.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/orttraining/orttraining/python/training/onnxblock/model_accessor.py b/orttraining/orttraining/python/training/onnxblock/model_accessor.py index 7962e97e5a4bb..302573064be6e 100644 --- a/orttraining/orttraining/python/training/onnxblock/model_accessor.py +++ b/orttraining/orttraining/python/training/onnxblock/model_accessor.py @@ -6,7 +6,6 @@ import copy import os from contextlib import contextmanager -from typing import Optional import onnx @@ -19,7 +18,7 @@ class ModelAccessor: model_path: The path to the base model. Can be None. """ - def __init__(self, model: onnx.ModelProto, model_path: Optional[str] = None): + def __init__(self, model: onnx.ModelProto, model_path: str | None = None): self._model = model self._path = model_path @@ -58,7 +57,7 @@ def has_path(self) -> bool: @contextmanager -def base(model: onnx.ModelProto, model_path: Optional[str] = None): +def base(model: onnx.ModelProto, model_path: str | None = None): """Registers the base model to be manipulated by the onnx blocks. Example: From a35bc8049f6ab429e687da213890972a88ab7a2e Mon Sep 17 00:00:00 2001 From: carzh Date: Wed, 12 Jun 2024 23:04:09 +0000 Subject: [PATCH 06/13] lintrunner format --- .../orttraining/python/training/onnxblock/_graph_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/orttraining/orttraining/python/training/onnxblock/_graph_utils.py b/orttraining/orttraining/python/training/onnxblock/_graph_utils.py index 3e9eb3812129f..42743a4200d17 100644 --- a/orttraining/orttraining/python/training/onnxblock/_graph_utils.py +++ b/orttraining/orttraining/python/training/onnxblock/_graph_utils.py @@ -85,4 +85,3 @@ def node_arg_exists(model: onnx.ModelProto, node_arg_name: str) -> bool: return True return False - From 72a55c7a1a6d264aa5aa068dce26c9df95aa0264 Mon Sep 17 00:00:00 2001 From: carzh Date: Thu, 13 Jun 2024 07:03:12 +0000 Subject: [PATCH 07/13] removed some print statements + cleaned up some code --- .../python/training/onnxblock/blocks.py | 47 ++++++++++++++++++- .../python/training/onnxblock/onnxblock.py | 34 +------------- .../orttraining_test_ort_apis_onnxblock.py | 4 -- 3 files changed, 47 insertions(+), 38 deletions(-) diff --git a/orttraining/orttraining/python/training/onnxblock/blocks.py b/orttraining/orttraining/python/training/onnxblock/blocks.py index d92530d4b21c5..f1bb2cb84bea4 100644 --- a/orttraining/orttraining/python/training/onnxblock/blocks.py +++ b/orttraining/orttraining/python/training/onnxblock/blocks.py @@ -4,6 +4,7 @@ import contextlib import copy import logging +import os from abc import ABC, abstractmethod from typing import Any, List, Optional @@ -28,8 +29,9 @@ class Block(ABC): base (onnx.ModelProto): The base model that the subclass can manipulate. """ - def __init__(self): + def __init__(self, temp_file_name="temp.onnx"): self.base = None + self.temp_onnx_path = temp_file_name @abstractmethod def build(self, *args, **kwargs): @@ -48,12 +50,53 @@ def __call__(self, *args, **kwargs): output = self.build(*args, **kwargs) if accessor._GLOBAL_ACCESSOR.has_path: - onnx.checker.check_model(accessor._GLOBAL_ACCESSOR.path, True) + onnx.save( + accessor._GLOBAL_ACCESSOR.model, + self.temp_onnx_path, + save_as_external_data=True, + all_tensors_to_one_file=True, + ) + + onnx.checker.check_model(self.temp_onnx_path, True) + + # clean-up temp files + if os.path.exists(self.temp_onnx_path): + os.remove(self.temp_onnx_path) + if os.path.exists(self.temp_onnx_path): + os.remove(self.temp_onnx_path) else: onnx.checker.check_model(self.base, True) return output + def infer_shapes_on_base(self): + """ + Performs shape inference on the global model. If a path was used, then uses the + infer_shapes_path API to support models with external data. + + Returns the shape-inferenced ModelProto. + """ + if accessor._GLOBAL_ACCESSOR.has_path: + onnx.save( + accessor._GLOBAL_ACCESSOR.model, + self.temp_onnx_path, + save_as_external_data=True, + all_tensors_to_one_file=True, + ) + + onnx.shape_inference.infer_shapes_path(self.temp_onnx_path) + # shape inferenced model is saved to original path + model = onnx.load(self.temp_onnx_path) + + # clean-up temp files + if os.path.exists(self.temp_onnx_path): + os.remove(self.temp_onnx_path) + if os.path.exists(self.temp_onnx_path): + os.remove(self.temp_onnx_path) + return model + else: + return onnx.shape_inference.infer_shapes(accessor._GLOBAL_ACCESSOR.model) + class _BinaryOp(Block): def __init__(self, op_name): diff --git a/orttraining/orttraining/python/training/onnxblock/onnxblock.py b/orttraining/orttraining/python/training/onnxblock/onnxblock.py index d90310a9f9f6b..64f7acf4dc02c 100644 --- a/orttraining/orttraining/python/training/onnxblock/onnxblock.py +++ b/orttraining/orttraining/python/training/onnxblock/onnxblock.py @@ -2,7 +2,6 @@ # Licensed under the MIT License. import logging -import os from abc import abstractmethod from typing import List, Tuple @@ -13,8 +12,6 @@ import onnxruntime.training.onnxblock.blocks as blocks import onnxruntime.training.onnxblock.model_accessor as accessor -TEMP_ONNX_PATH = "temp.onnx" - class ForwardBlock(blocks.Block): """Base class for all blocks that require forward model to be automatically built. @@ -73,20 +70,7 @@ def __call__(self, *args, **kwargs): output = self.build(*args, **kwargs) - if accessor._GLOBAL_ACCESSOR.has_path: - onnx.save(self.base, TEMP_ONNX_PATH, save_as_external_data=True, all_tensors_to_one_file=True) - - onnx.shape_inference.infer_shapes_path(TEMP_ONNX_PATH) - # shape inferenced model is saved to original path - self._model = onnx.load(TEMP_ONNX_PATH) - - # clean-up temp files - if os.path.exists(TEMP_ONNX_PATH): - os.remove(TEMP_ONNX_PATH) - if os.path.exists(TEMP_ONNX_PATH): - os.remove(TEMP_ONNX_PATH) - else: - self._model = onnx.shape_inference.infer_shapes(accessor._GLOBAL_ACCESSOR.model) + self._model = self.infer_shapes_on_base() _graph_utils.register_graph_outputs(self._model, output) @@ -203,21 +187,7 @@ def __call__(self, *args, **kwargs): output = self.build(*args, **kwargs) - model = None - if accessor._GLOBAL_ACCESSOR.has_path: - onnx.save(self.base, TEMP_ONNX_PATH, save_as_external_data=True, all_tensors_to_one_file=True) - - onnx.shape_inference.infer_shapes_path(TEMP_ONNX_PATH) - # shape inferenced model is saved to original path - model = onnx.load(TEMP_ONNX_PATH) - - # clean-up temp files - if os.path.exists(TEMP_ONNX_PATH): - os.remove(TEMP_ONNX_PATH) - if os.path.exists(TEMP_ONNX_PATH): - os.remove(TEMP_ONNX_PATH) - else: - model = onnx.shape_inference.infer_shapes(accessor._GLOBAL_ACCESSOR.model) + model = self.infer_shapes_on_base() _graph_utils.register_graph_outputs(model, output) diff --git a/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py b/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py index b374d231cf809..4a4d669d97ff8 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py +++ b/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py @@ -318,8 +318,6 @@ def crossentropy_loss(prediction, target): ort_outs = ort_session.run(ort_output_names, ort_inputs) torch_outs = crossentropy_loss(pt_model(x), target) - print("ort_ outs", ort_outs) - print("torch outs", torch_outs) # Then assert np.allclose(ort_outs[0], _to_numpy(torch_outs)) @@ -358,8 +356,6 @@ def bcewithlogits_loss(prediction, target): ort_outs = ort_session.run(ort_output_names, ort_inputs) torch_outs = bcewithlogits_loss(pt_model(x), target) - print("ort_ outs", ort_outs) - print("torch outs", torch_outs) # Then assert np.allclose(ort_outs[0], _to_numpy(torch_outs)) From 378afa8a09f903b39268ed0ed17629e522c01700 Mon Sep 17 00:00:00 2001 From: carzh Date: Mon, 17 Jun 2024 21:01:53 +0000 Subject: [PATCH 08/13] updated handling of temp files + added comment on crossentropyloss update --- .../orttraining/python/training/artifacts.py | 2 ++ .../python/training/onnxblock/blocks.py | 36 ++++++++++--------- .../python/training/onnxblock/loss/loss.py | 5 +-- 3 files changed, 24 insertions(+), 19 deletions(-) diff --git a/orttraining/orttraining/python/training/artifacts.py b/orttraining/orttraining/python/training/artifacts.py index c98e5bcd97092..a9cad4e33209d 100644 --- a/orttraining/orttraining/python/training/artifacts.py +++ b/orttraining/orttraining/python/training/artifacts.py @@ -268,3 +268,5 @@ def _export_to_ort_format(model_path, output_dir, ort_format, custom_op_library_ onnx.save(optim_model, optimizer_model_path) _export_to_ort_format(optimizer_model_path, artifact_directory, ort_format, custom_op_library_path) logging.info("Saved optimizer model to %s", optimizer_model_path) + + training_block.release() diff --git a/orttraining/orttraining/python/training/onnxblock/blocks.py b/orttraining/orttraining/python/training/onnxblock/blocks.py index f1bb2cb84bea4..00a6278429ff7 100644 --- a/orttraining/orttraining/python/training/onnxblock/blocks.py +++ b/orttraining/orttraining/python/training/onnxblock/blocks.py @@ -31,7 +31,7 @@ class Block(ABC): def __init__(self, temp_file_name="temp.onnx"): self.base = None - self.temp_onnx_path = temp_file_name + self.temp_onnx_file_name = temp_file_name @abstractmethod def build(self, *args, **kwargs): @@ -50,20 +50,16 @@ def __call__(self, *args, **kwargs): output = self.build(*args, **kwargs) if accessor._GLOBAL_ACCESSOR.has_path: + temp_external_path = self.temp_onnx_file_name + ".data" onnx.save( accessor._GLOBAL_ACCESSOR.model, - self.temp_onnx_path, + self.temp_onnx_file_name, save_as_external_data=True, all_tensors_to_one_file=True, + location=temp_external_path, ) - onnx.checker.check_model(self.temp_onnx_path, True) - - # clean-up temp files - if os.path.exists(self.temp_onnx_path): - os.remove(self.temp_onnx_path) - if os.path.exists(self.temp_onnx_path): - os.remove(self.temp_onnx_path) + onnx.checker.check_model(self.temp_onnx_file_name, True) else: onnx.checker.check_model(self.base, True) @@ -77,26 +73,32 @@ def infer_shapes_on_base(self): Returns the shape-inferenced ModelProto. """ if accessor._GLOBAL_ACCESSOR.has_path: + temp_external_path = self.temp_onnx_file_name + ".data" onnx.save( accessor._GLOBAL_ACCESSOR.model, - self.temp_onnx_path, + self.temp_onnx_file_name, save_as_external_data=True, all_tensors_to_one_file=True, + location=temp_external_path, ) - onnx.shape_inference.infer_shapes_path(self.temp_onnx_path) + onnx.shape_inference.infer_shapes_path(self.temp_onnx_file_name) # shape inferenced model is saved to original path - model = onnx.load(self.temp_onnx_path) + model = onnx.load(self.temp_onnx_file_name) - # clean-up temp files - if os.path.exists(self.temp_onnx_path): - os.remove(self.temp_onnx_path) - if os.path.exists(self.temp_onnx_path): - os.remove(self.temp_onnx_path) return model else: return onnx.shape_inference.infer_shapes(accessor._GLOBAL_ACCESSOR.model) + def release(self): + # since the ModelProto does not store the external data parameters themselves, just the metadata + # for where the external data can be found, we retain the external data files for the intermediate + # calls until the Block no longer needs to be used. + if os.path.exists(self.temp_onnx_file_name): + os.remove(self.temp_onnx_file_name) + if os.path.exists(self.temp_onnx_file_name + ".data"): + os.remove(self.temp_onnx_file_name + ".data") + class _BinaryOp(Block): def __init__(self, op_name): diff --git a/orttraining/orttraining/python/training/onnxblock/loss/loss.py b/orttraining/orttraining/python/training/onnxblock/loss/loss.py index 52194a7e0f06f..09429dd844187 100644 --- a/orttraining/orttraining/python/training/onnxblock/loss/loss.py +++ b/orttraining/orttraining/python/training/onnxblock/loss/loss.py @@ -93,8 +93,9 @@ def build(self, scores_input_name: str, labels_name: str = "labels"): labels_input = copy.deepcopy(_graph_utils.get_output_from_output_name(self.base, scores_input_name)) labels_input.name = labels_name labels_input.type.tensor_type.elem_type = onnx.TensorProto.INT64 - # If the predictions are (num_examples x num_classes) - # labels should be (num_examples,) + # Assumes classes is the last dimension + # e.g., predictions: (num_examples, num_classes) -> labels: (num_examples,) + # or predictions: (batch_size, seq_len, vocab) -> labels: (batch_size, seq_len) del labels_input.type.tensor_type.shape.dim[-1] self.base.graph.input.append(labels_input) From 44acd08776bab8513b67f4bfd7df84853e9be758 Mon Sep 17 00:00:00 2001 From: carzh Date: Wed, 19 Jun 2024 18:31:36 +0000 Subject: [PATCH 09/13] changed release method to __del__ & added comments to explain that generate_artifacts typically will throw an error if unsuccessful --- orttraining/orttraining/python/training/artifacts.py | 2 -- orttraining/orttraining/python/training/onnxblock/blocks.py | 2 +- .../test/python/orttraining_test_ort_apis_onnxblock.py | 2 ++ 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/orttraining/orttraining/python/training/artifacts.py b/orttraining/orttraining/python/training/artifacts.py index a9cad4e33209d..c98e5bcd97092 100644 --- a/orttraining/orttraining/python/training/artifacts.py +++ b/orttraining/orttraining/python/training/artifacts.py @@ -268,5 +268,3 @@ def _export_to_ort_format(model_path, output_dir, ort_format, custom_op_library_ onnx.save(optim_model, optimizer_model_path) _export_to_ort_format(optimizer_model_path, artifact_directory, ort_format, custom_op_library_path) logging.info("Saved optimizer model to %s", optimizer_model_path) - - training_block.release() diff --git a/orttraining/orttraining/python/training/onnxblock/blocks.py b/orttraining/orttraining/python/training/onnxblock/blocks.py index 00a6278429ff7..6296f0b5a0c75 100644 --- a/orttraining/orttraining/python/training/onnxblock/blocks.py +++ b/orttraining/orttraining/python/training/onnxblock/blocks.py @@ -90,7 +90,7 @@ def infer_shapes_on_base(self): else: return onnx.shape_inference.infer_shapes(accessor._GLOBAL_ACCESSOR.model) - def release(self): + def __del__(self): # since the ModelProto does not store the external data parameters themselves, just the metadata # for where the external data can be found, we retain the external data files for the intermediate # calls until the Block no longer needs to be used. diff --git a/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py b/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py index 4a4d669d97ff8..52c1a9874d170 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py +++ b/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py @@ -1177,6 +1177,8 @@ def test_generate_artifacts_external_data_separate_files(): artifact_directory=temp_dir, ) + # generate_artifacts should have thrown if it didn't complete successfully. + # Below is a sanity check to validate that all the expected files were created. assert os.path.exists(os.path.join(temp_dir, "training_model.onnx")) assert os.path.exists(os.path.join(temp_dir, "eval_model.onnx")) assert os.path.exists(os.path.join(temp_dir, "optimizer_model.onnx")) From 0f939744fedb12e990407c278f9f391bff52cfaa Mon Sep 17 00:00:00 2001 From: carzh Date: Wed, 19 Jun 2024 18:33:42 +0000 Subject: [PATCH 10/13] comments --- .../test/python/orttraining_test_ort_apis_onnxblock.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py b/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py index 52c1a9874d170..5c63be92d2b2f 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py +++ b/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py @@ -1121,6 +1121,8 @@ def test_generate_artifacts_path(): artifact_directory=temp_dir, ) + # generate_artifacts should have thrown if it didn't complete successfully. + # Below is a sanity check to validate that all the expected files were created. assert os.path.exists(os.path.join(temp_dir, "training_model.onnx")) assert os.path.exists(os.path.join(temp_dir, "eval_model.onnx")) assert os.path.exists(os.path.join(temp_dir, "optimizer_model.onnx")) @@ -1149,6 +1151,8 @@ def test_generate_artifacts_external_data_one_file(): artifact_directory=temp_dir, ) + # generate_artifacts should have thrown if it didn't complete successfully. + # Below is a sanity check to validate that all the expected files were created. assert os.path.exists(os.path.join(temp_dir, "training_model.onnx")) assert os.path.exists(os.path.join(temp_dir, "eval_model.onnx")) assert os.path.exists(os.path.join(temp_dir, "optimizer_model.onnx")) From fff6723af0753514e21ab83f4fde955a974846d0 Mon Sep 17 00:00:00 2001 From: carzh Date: Thu, 20 Jun 2024 07:28:52 +0000 Subject: [PATCH 11/13] switched to setting absolute path in constructor for temp onnx file path & relative path for external data --- .../python/training/onnxblock/blocks.py | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/orttraining/orttraining/python/training/onnxblock/blocks.py b/orttraining/orttraining/python/training/onnxblock/blocks.py index 6296f0b5a0c75..643d4d8105481 100644 --- a/orttraining/orttraining/python/training/onnxblock/blocks.py +++ b/orttraining/orttraining/python/training/onnxblock/blocks.py @@ -31,7 +31,9 @@ class Block(ABC): def __init__(self, temp_file_name="temp.onnx"): self.base = None - self.temp_onnx_file_name = temp_file_name + self.temp_onnx_file_path = os.path.join(os.getcwd(), temp_file_name) + # onnx.save location parameter requires a relative path to the model path + self.temp_external_data_relative_path = temp_file_name + ".data" @abstractmethod def build(self, *args, **kwargs): @@ -50,16 +52,15 @@ def __call__(self, *args, **kwargs): output = self.build(*args, **kwargs) if accessor._GLOBAL_ACCESSOR.has_path: - temp_external_path = self.temp_onnx_file_name + ".data" onnx.save( accessor._GLOBAL_ACCESSOR.model, - self.temp_onnx_file_name, + self.temp_onnx_file_path, save_as_external_data=True, all_tensors_to_one_file=True, - location=temp_external_path, + location=self.temp_external_data_relative_path, ) - onnx.checker.check_model(self.temp_onnx_file_name, True) + onnx.checker.check_model(self.temp_onnx_file_path, True) else: onnx.checker.check_model(self.base, True) @@ -73,18 +74,17 @@ def infer_shapes_on_base(self): Returns the shape-inferenced ModelProto. """ if accessor._GLOBAL_ACCESSOR.has_path: - temp_external_path = self.temp_onnx_file_name + ".data" onnx.save( accessor._GLOBAL_ACCESSOR.model, - self.temp_onnx_file_name, + self.temp_onnx_file_path, save_as_external_data=True, all_tensors_to_one_file=True, - location=temp_external_path, + location=self.temp_external_data_relative_path, ) - onnx.shape_inference.infer_shapes_path(self.temp_onnx_file_name) + onnx.shape_inference.infer_shapes_path(self.temp_onnx_file_path) # shape inferenced model is saved to original path - model = onnx.load(self.temp_onnx_file_name) + model = onnx.load(self.temp_onnx_file_path) return model else: @@ -94,10 +94,10 @@ def __del__(self): # since the ModelProto does not store the external data parameters themselves, just the metadata # for where the external data can be found, we retain the external data files for the intermediate # calls until the Block no longer needs to be used. - if os.path.exists(self.temp_onnx_file_name): - os.remove(self.temp_onnx_file_name) - if os.path.exists(self.temp_onnx_file_name + ".data"): - os.remove(self.temp_onnx_file_name + ".data") + if os.path.exists(self.temp_onnx_file_path): + os.remove(self.temp_onnx_file_path) + if os.path.exists(self.temp_external_data_relative_path): + os.remove(self.temp_external_data_relative_path) class _BinaryOp(Block): From c5a5cc1157515c7c937a09ad29371f4b1267c966 Mon Sep 17 00:00:00 2001 From: carzh Date: Thu, 20 Jun 2024 18:13:04 +0000 Subject: [PATCH 12/13] updated variable name + finds absolute path for external data file for deletion --- .../python/training/onnxblock/blocks.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/orttraining/orttraining/python/training/onnxblock/blocks.py b/orttraining/orttraining/python/training/onnxblock/blocks.py index 643d4d8105481..4c7b4434b2c2b 100644 --- a/orttraining/orttraining/python/training/onnxblock/blocks.py +++ b/orttraining/orttraining/python/training/onnxblock/blocks.py @@ -30,10 +30,12 @@ class Block(ABC): """ def __init__(self, temp_file_name="temp.onnx"): + if (os.path.isabs(temp_file_name)): + raise RuntimeError("Please pass in a relative path for the temp_file_name.") self.base = None self.temp_onnx_file_path = os.path.join(os.getcwd(), temp_file_name) # onnx.save location parameter requires a relative path to the model path - self.temp_external_data_relative_path = temp_file_name + ".data" + self.temp_external_data_file_name = temp_file_name + ".data" @abstractmethod def build(self, *args, **kwargs): @@ -57,7 +59,7 @@ def __call__(self, *args, **kwargs): self.temp_onnx_file_path, save_as_external_data=True, all_tensors_to_one_file=True, - location=self.temp_external_data_relative_path, + location=self.temp_external_data_file_name, ) onnx.checker.check_model(self.temp_onnx_file_path, True) @@ -79,7 +81,7 @@ def infer_shapes_on_base(self): self.temp_onnx_file_path, save_as_external_data=True, all_tensors_to_one_file=True, - location=self.temp_external_data_relative_path, + location=self.temp_external_data_file_name, ) onnx.shape_inference.infer_shapes_path(self.temp_onnx_file_path) @@ -96,8 +98,10 @@ def __del__(self): # calls until the Block no longer needs to be used. if os.path.exists(self.temp_onnx_file_path): os.remove(self.temp_onnx_file_path) - if os.path.exists(self.temp_external_data_relative_path): - os.remove(self.temp_external_data_relative_path) + # get absolute path for the external data file + external_data_file_path = os.path.join(os.path.dirname(self.temp_onnx_file_path), self.temp_external_data_file_name) + if os.path.exists(external_data_file_path): + os.remove(external_data_file_path) class _BinaryOp(Block): From 8fcd74c2eac6ee8bd1f3ea1398d2f5c1aa7b2494 Mon Sep 17 00:00:00 2001 From: carzh Date: Thu, 20 Jun 2024 23:10:22 +0000 Subject: [PATCH 13/13] lintrunner sigh --- orttraining/orttraining/python/training/onnxblock/blocks.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/orttraining/orttraining/python/training/onnxblock/blocks.py b/orttraining/orttraining/python/training/onnxblock/blocks.py index 4c7b4434b2c2b..80f07c3738a7e 100644 --- a/orttraining/orttraining/python/training/onnxblock/blocks.py +++ b/orttraining/orttraining/python/training/onnxblock/blocks.py @@ -30,7 +30,7 @@ class Block(ABC): """ def __init__(self, temp_file_name="temp.onnx"): - if (os.path.isabs(temp_file_name)): + if os.path.isabs(temp_file_name): raise RuntimeError("Please pass in a relative path for the temp_file_name.") self.base = None self.temp_onnx_file_path = os.path.join(os.getcwd(), temp_file_name) @@ -99,7 +99,9 @@ def __del__(self): if os.path.exists(self.temp_onnx_file_path): os.remove(self.temp_onnx_file_path) # get absolute path for the external data file - external_data_file_path = os.path.join(os.path.dirname(self.temp_onnx_file_path), self.temp_external_data_file_name) + external_data_file_path = os.path.join( + os.path.dirname(self.temp_onnx_file_path), self.temp_external_data_file_name + ) if os.path.exists(external_data_file_path): os.remove(external_data_file_path)