diff --git a/orttraining/orttraining/python/training/artifacts.py b/orttraining/orttraining/python/training/artifacts.py
index 624b30ffdab3b..c98e5bcd97092 100644
--- a/orttraining/orttraining/python/training/artifacts.py
+++ b/orttraining/orttraining/python/training/artifacts.py
@@ -13,6 +13,9 @@
 from onnxruntime.tools.convert_onnx_models_to_ort import OptimizationStyle, convert_onnx_models_to_ort
 from onnxruntime.training import onnxblock
 
+# Threshold for the size of the ModelProto above which a path should be used instead.
+USE_PATH_THRESHOLD = 2147483648
+
 
 class LossType(Enum):
     """Loss type to be added to the training model.
@@ -37,7 +40,7 @@ class OptimType(Enum):
 
 
 def generate_artifacts(
-    model: onnx.ModelProto,
+    model: Union[onnx.ModelProto, str],
     requires_grad: Optional[List[str]] = None,
     frozen_params: Optional[List[str]] = None,
    loss: Optional[Union[LossType, onnxblock.Block]] = None,
@@ -61,7 +64,8 @@
     All generated ModelProtos will use the same opsets defined by *model*.
 
     Args:
-        model: The base model to be used for gradient graph generation.
+        model: The base model or path to the base model to be used for gradient graph generation. For models >2GB,
+            use the path to the base model.
         requires_grad: List of names of model parameters that require gradient computation
         frozen_params: List of names of model parameters that should be frozen.
         loss: The loss function enum or onnxblock to be used for training. If None, no loss node is added to the graph.
@@ -86,6 +90,22 @@
         RuntimeError: If the optimizer provided is not one of the supported optimizers.
     """
 
+    loaded_model = None
+    model_path = None
+
+    if isinstance(model, str):
+        loaded_model = onnx.load(model)
+        model_path = model
+    elif isinstance(model, onnx.ModelProto):
+        if model.ByteSize() > USE_PATH_THRESHOLD:
+            # infer_shapes and check_model from ONNX both require paths to be used for >2GB models.
+            raise RuntimeError("This model is > 2GB. Please pass in a path to the ONNX file instead.")
+
+        loaded_model = model
+        model_path = None
+    else:
+        raise RuntimeError("Please pass in either a string or an ONNX ModelProto for the model.")
+
     loss_blocks = {
         LossType.MSELoss: onnxblock.loss.MSELoss,
         LossType.CrossEntropyLoss: onnxblock.loss.CrossEntropyLoss,
@@ -165,12 +185,12 @@ def build(self, *inputs_to_loss):
         logging.info("Custom op library provided: %s", custom_op_library)
         custom_op_library_path = pathlib.Path(custom_op_library)
 
-    with onnxblock.base(model), (
+    with onnxblock.base(loaded_model, model_path), (
         onnxblock.custom_op_library(custom_op_library_path)
         if custom_op_library is not None
        else contextlib.nullcontext()
     ):
-        _ = training_block(*[output.name for output in model.graph.output])
+        _ = training_block(*[output.name for output in loaded_model.graph.output])
         training_model, eval_model = training_block.to_model_proto()
         model_params = training_block.parameters()
 
@@ -220,7 +240,7 @@ def _export_to_ort_format(model_path, output_dir, ort_format, custom_op_library_
             return
 
     opset_version = None
-    for domain in model.opset_import:
+    for domain in loaded_model.opset_import:
         if domain.domain == "" or domain.domain == "ai.onnx":
             opset_version = domain.version
             break
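
For illustration, a minimal usage sketch of the new path-based entry point (not part of the patch; the file name, parameter names, and output directory below are hypothetical):

    from onnxruntime.training import artifacts

    # Passing a path instead of a ModelProto lets generate_artifacts run the
    # ONNX checker and shape inference against the file on disk, which is the
    # only supported route once the serialized proto exceeds 2 GB.
    artifacts.generate_artifacts(
        "large_model.onnx",  # hypothetical model saved with external data
        requires_grad=["fc1.weight", "fc1.bias"],
        loss=artifacts.LossType.CrossEntropyLoss,
        optimizer=artifacts.OptimType.AdamW,
        artifact_directory="artifacts_out",
    )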
diff --git a/orttraining/orttraining/python/training/onnxblock/blocks.py b/orttraining/orttraining/python/training/onnxblock/blocks.py
index 149d0a360f7d3..80f07c3738a7e 100644
--- a/orttraining/orttraining/python/training/onnxblock/blocks.py
+++ b/orttraining/orttraining/python/training/onnxblock/blocks.py
@@ -4,6 +4,7 @@
 import contextlib
 import copy
 import logging
+import os
 from abc import ABC, abstractmethod
 from typing import Any, List, Optional
 
@@ -28,8 +29,13 @@ class Block(ABC):
         base (onnx.ModelProto): The base model that the subclass can manipulate.
     """
 
-    def __init__(self):
+    def __init__(self, temp_file_name="temp.onnx"):
+        if os.path.isabs(temp_file_name):
+            raise RuntimeError("Please pass in a relative path for the temp_file_name.")
         self.base = None
+        self.temp_onnx_file_path = os.path.join(os.getcwd(), temp_file_name)
+        # onnx.save's location parameter requires a path relative to the model file.
+        self.temp_external_data_file_name = temp_file_name + ".data"
 
     @abstractmethod
     def build(self, *args, **kwargs):
@@ -47,10 +53,58 @@ def __call__(self, *args, **kwargs):
 
         output = self.build(*args, **kwargs)
 
-        onnx.checker.check_model(self.base, True)
+        if accessor._GLOBAL_ACCESSOR.has_path:
+            onnx.save(
+                accessor._GLOBAL_ACCESSOR.model,
+                self.temp_onnx_file_path,
+                save_as_external_data=True,
+                all_tensors_to_one_file=True,
+                location=self.temp_external_data_file_name,
+            )
+
+            onnx.checker.check_model(self.temp_onnx_file_path, True)
+        else:
+            onnx.checker.check_model(self.base, True)
 
         return output
 
+    def infer_shapes_on_base(self):
+        """
+        Performs shape inference on the global model. If a path was used, then uses the
+        infer_shapes_path API to support models with external data.
+
+        Returns the shape-inferred ModelProto.
+        """
+        if accessor._GLOBAL_ACCESSOR.has_path:
+            onnx.save(
+                accessor._GLOBAL_ACCESSOR.model,
+                self.temp_onnx_file_path,
+                save_as_external_data=True,
+                all_tensors_to_one_file=True,
+                location=self.temp_external_data_file_name,
+            )
+
+            onnx.shape_inference.infer_shapes_path(self.temp_onnx_file_path)
+            # The shape-inferred model is saved back to the original path.
+            model = onnx.load(self.temp_onnx_file_path)
+
+            return model
+        else:
+            return onnx.shape_inference.infer_shapes(accessor._GLOBAL_ACCESSOR.model)
+
+    def __del__(self):
+        # Since the ModelProto stores only the metadata for where the external data can be
+        # found, not the tensor contents themselves, we retain the external data files for
+        # the intermediate calls until the Block no longer needs to be used.
+        if os.path.exists(self.temp_onnx_file_path):
+            os.remove(self.temp_onnx_file_path)
+        # Get the absolute path of the external data file.
+        external_data_file_path = os.path.join(
+            os.path.dirname(self.temp_onnx_file_path), self.temp_external_data_file_name
+        )
+        if os.path.exists(external_data_file_path):
+            os.remove(external_data_file_path)
+
 
 class _BinaryOp(Block):
     def __init__(self, op_name):
diff --git a/orttraining/orttraining/python/training/onnxblock/loss/loss.py b/orttraining/orttraining/python/training/onnxblock/loss/loss.py
index e719301e13f48..09429dd844187 100644
--- a/orttraining/orttraining/python/training/onnxblock/loss/loss.py
+++ b/orttraining/orttraining/python/training/onnxblock/loss/loss.py
@@ -93,19 +93,20 @@ def build(self, scores_input_name: str, labels_name: str = "labels"):
         labels_input = copy.deepcopy(_graph_utils.get_output_from_output_name(self.base, scores_input_name))
         labels_input.name = labels_name
         labels_input.type.tensor_type.elem_type = onnx.TensorProto.INT64
-        # If the predictions are (num_examples x num_classes)
-        # labels should be (num_examples,)
-        del labels_input.type.tensor_type.shape.dim[1]
+        # Assumes the class dimension is the last dimension, e.g.,
+        # predictions: (num_examples, num_classes) -> labels: (num_examples,)
+        # or predictions: (batch_size, seq_len, vocab) -> labels: (batch_size, seq_len)
+        del labels_input.type.tensor_type.shape.dim[-1]
         self.base.graph.input.append(labels_input)
 
         loss_node_input_names = [scores_input_name, labels_name]
         if self._weight:
             loss_node_input_names.append(weight_name)
+
         loss_node_output_name = _graph_utils.generate_graph_name("loss")
-        loss_node_output_names = [
-            loss_node_output_name,
-            _graph_utils.generate_graph_name("log_prob"),
-        ]
+        log_prob_output_name = _graph_utils.generate_graph_name("log_prob")
+
+        loss_node_output_names = [loss_node_output_name, log_prob_output_name]
         loss_node = onnx.helper.make_node(
             "SoftmaxCrossEntropyLoss",
             loss_node_input_names,
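
The temp-file handling above exists because the in-memory onnx.checker.check_model and onnx.shape_inference.infer_shapes APIs reject protos larger than 2 GB, while their path-based counterparts do not. A standalone sketch of the same save-then-check pattern, assuming a loaded ModelProto named model:

    import onnx

    # Write the proto with its tensors in a single external data file. The
    # `location` argument is resolved relative to the .onnx file's directory,
    # which is why Block stores it as a relative path.
    onnx.save(
        model,
        "temp.onnx",
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location="temp.onnx.data",
    )

    # check_model accepts a path, so validation works even for >2GB models;
    # the second argument enables full checking.
    onnx.checker.check_model("temp.onnx", True)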
""" - def __init__(self, model: onnx.ModelProto): + def __init__(self, model: onnx.ModelProto, model_path: str | None = None): self._model = model + self._path = model_path @property def model(self) -> onnx.ModelProto: @@ -30,6 +32,22 @@ def model(self) -> onnx.ModelProto: ) return self._model + @property + def path(self) -> str: + """ModelAccessor property that gets the path to the base model.""" + + if self._path is None: + raise RuntimeError( + "The path to the onnx model was not set. Please use the context manager onnxblock.onnx_model to create the model and pass in a string." + ) + return self._path + + @property + def has_path(self) -> bool: + """Returns True if ModelAccessor has a path to a model, False otherwise.""" + + return self._path is not None + # These variable resides in the global namespace. # Different methods can access this global model and manipulate it. @@ -39,7 +57,7 @@ def model(self) -> onnx.ModelProto: @contextmanager -def base(model: onnx.ModelProto): +def base(model: onnx.ModelProto, model_path: str | None = None): """Registers the base model to be manipulated by the onnx blocks. Example: @@ -53,6 +71,7 @@ def base(model: onnx.ModelProto): Args: model: The base model to be manipulated by the onnx blocks. + model_path: The path to the base model. None if there is no model path to pass in. Returns: ModelAccessor: The model accessor that contains the modified model. @@ -69,7 +88,7 @@ def base(model: onnx.ModelProto): "model from scratch." ) - _GLOBAL_ACCESSOR = ModelAccessor(model_clone) + _GLOBAL_ACCESSOR = ModelAccessor(model_clone, model_path) try: yield _GLOBAL_ACCESSOR finally: @@ -112,7 +131,7 @@ def empty_base(opset_version: int | None = None): ) ) - _GLOBAL_ACCESSOR = ModelAccessor(model) + _GLOBAL_ACCESSOR = ModelAccessor(model, None) try: yield _GLOBAL_ACCESSOR finally: diff --git a/orttraining/orttraining/python/training/onnxblock/onnxblock.py b/orttraining/orttraining/python/training/onnxblock/onnxblock.py index a2922353ac70e..64f7acf4dc02c 100644 --- a/orttraining/orttraining/python/training/onnxblock/onnxblock.py +++ b/orttraining/orttraining/python/training/onnxblock/onnxblock.py @@ -70,7 +70,7 @@ def __call__(self, *args, **kwargs): output = self.build(*args, **kwargs) - self._model = onnx.shape_inference.infer_shapes(accessor._GLOBAL_ACCESSOR.model) + self._model = self.infer_shapes_on_base() _graph_utils.register_graph_outputs(self._model, output) @@ -187,7 +187,7 @@ def __call__(self, *args, **kwargs): output = self.build(*args, **kwargs) - model = onnx.shape_inference.infer_shapes(accessor._GLOBAL_ACCESSOR.model) + model = self.infer_shapes_on_base() _graph_utils.register_graph_outputs(model, output) diff --git a/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py b/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py index ac49c1c2834c7..5c63be92d2b2f 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py +++ b/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py @@ -1099,3 +1099,91 @@ def test_custom_optimizer_block(): for attr in node.attribute: if attr.name == "weight_decay": assert attr.f == weight_decay + + +def test_generate_artifacts_path(): + + with tempfile.TemporaryDirectory() as temp_dir: + _, simple_net = _get_models("cpu", 32, 28, 10, 10) + + requires_grad_params = ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"] + + onnx.save_model( + simple_net, + os.path.join(temp_dir, "simple_net.onnx"), + ) + + artifacts.generate_artifacts( 
diff --git a/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py b/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py
index ac49c1c2834c7..5c63be92d2b2f 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py
@@ -1099,3 +1099,91 @@ def test_custom_optimizer_block():
         for attr in node.attribute:
             if attr.name == "weight_decay":
                 assert attr.f == weight_decay
+
+
+def test_generate_artifacts_path():
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        _, simple_net = _get_models("cpu", 32, 28, 10, 10)
+
+        requires_grad_params = ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"]
+
+        onnx.save_model(
+            simple_net,
+            os.path.join(temp_dir, "simple_net.onnx"),
+        )
+
+        artifacts.generate_artifacts(
+            os.path.join(temp_dir, "simple_net.onnx"),
+            requires_grad=requires_grad_params,
+            loss=artifacts.LossType.CrossEntropyLoss,
+            optimizer=artifacts.OptimType.AdamW,
+            artifact_directory=temp_dir,
+        )
+
+        # generate_artifacts should have thrown if it didn't complete successfully.
+        # Below is a sanity check to validate that all the expected files were created.
+        assert os.path.exists(os.path.join(temp_dir, "training_model.onnx"))
+        assert os.path.exists(os.path.join(temp_dir, "eval_model.onnx"))
+        assert os.path.exists(os.path.join(temp_dir, "optimizer_model.onnx"))
+        assert os.path.exists(os.path.join(temp_dir, "checkpoint"))
+
+
+def test_generate_artifacts_external_data_one_file():
+    with tempfile.TemporaryDirectory() as temp_dir:
+        _, simple_net = _get_models("cpu", 32, 28, 10, 10)
+
+        requires_grad_params = ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"]
+
+        onnx.save_model(
+            simple_net,
+            os.path.join(temp_dir, "simple_net.onnx"),
+            save_as_external_data=True,
+            all_tensors_to_one_file=True,
+            size_threshold=0,
+        )
+
+        artifacts.generate_artifacts(
+            os.path.join(temp_dir, "simple_net.onnx"),
+            requires_grad=requires_grad_params,
+            loss=artifacts.LossType.CrossEntropyLoss,
+            optimizer=artifacts.OptimType.AdamW,
+            artifact_directory=temp_dir,
+        )
+
+        # generate_artifacts should have thrown if it didn't complete successfully.
+        # Below is a sanity check to validate that all the expected files were created.
+        assert os.path.exists(os.path.join(temp_dir, "training_model.onnx"))
+        assert os.path.exists(os.path.join(temp_dir, "eval_model.onnx"))
+        assert os.path.exists(os.path.join(temp_dir, "optimizer_model.onnx"))
+        assert os.path.exists(os.path.join(temp_dir, "checkpoint"))
+
+
+def test_generate_artifacts_external_data_separate_files():
+    with tempfile.TemporaryDirectory() as temp_dir:
+        _, simple_net = _get_models("cpu", 32, 28, 10, 10)
+
+        requires_grad_params = ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"]
+
+        onnx.save_model(
+            simple_net,
+            os.path.join(temp_dir, "simple_net.onnx"),
+            save_as_external_data=True,
+            all_tensors_to_one_file=False,
+            size_threshold=0,
+        )
+
+        artifacts.generate_artifacts(
+            os.path.join(temp_dir, "simple_net.onnx"),
+            requires_grad=requires_grad_params,
+            loss=artifacts.LossType.CrossEntropyLoss,
+            optimizer=artifacts.OptimType.AdamW,
+            artifact_directory=temp_dir,
+        )
+
+        # generate_artifacts should have thrown if it didn't complete successfully.
+        # Below is a sanity check to validate that all the expected files were created.
+        assert os.path.exists(os.path.join(temp_dir, "training_model.onnx"))
+        assert os.path.exists(os.path.join(temp_dir, "eval_model.onnx"))
+        assert os.path.exists(os.path.join(temp_dir, "optimizer_model.onnx"))
+        assert os.path.exists(os.path.join(temp_dir, "checkpoint"))
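
As a follow-on to these tests, the generated artifacts are typically consumed with the onnxruntime on-device training API; a rough sketch, assuming onnxruntime-training is installed and the artifacts were written to a directory named artifacts_out:

    from onnxruntime.training.api import CheckpointState, Module, Optimizer

    # Load the trainable parameter state produced by generate_artifacts.
    state = CheckpointState.load_checkpoint("artifacts_out/checkpoint")

    # The generated training and eval graphs drive a Module instance.
    module = Module("artifacts_out/training_model.onnx", state, "artifacts_out/eval_model.onnx")

    # The optimizer graph (AdamW in the tests above) updates the module's parameters.
    optimizer = Optimizer("artifacts_out/optimizer_model.onnx", module)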