Enable >2GB models + allow model paths to be passed for generate_artifacts API (#20958)

### Description
Alternative design from #20942 

Allow users to pass in a model path for the generate_artifacts API. 

### Motivation and Context
- ONNX API calls such as the ONNX checker and shape inference fail when given an in-memory model larger than 2GB, but succeed when a path to such a model is passed in instead. A usage sketch follows below.
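To make the new surface concrete, here is a minimal usage sketch under assumed names (the model file, parameter names, and directory are illustrative; the tests added below exercise the same flow):

```python
import os

from onnxruntime.training import artifacts

# Hypothetical model previously exported to disk (possibly >2GB, with external data).
model_path = os.path.join("model_dir", "my_model.onnx")

artifacts.generate_artifacts(
    model_path,  # a file path is now accepted in place of an onnx.ModelProto
    requires_grad=["fc1.weight", "fc1.bias"],  # illustrative parameter names
    loss=artifacts.LossType.CrossEntropyLoss,
    optimizer=artifacts.OptimType.AdamW,
    artifact_directory="model_dir",
)
```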
carzh authored Jun 21, 2024
1 parent 7cf9263 commit 6236707
Showing 6 changed files with 202 additions and 20 deletions.
30 changes: 25 additions & 5 deletions orttraining/orttraining/python/training/artifacts.py
@@ -13,6 +13,9 @@
from onnxruntime.tools.convert_onnx_models_to_ort import OptimizationStyle, convert_onnx_models_to_ort
from onnxruntime.training import onnxblock

# Threshold for the serialized size of a ModelProto, above which a path should be passed instead
USE_PATH_THRESHOLD = 2147483648


class LossType(Enum):
"""Loss type to be added to the training model.
@@ -37,7 +40,7 @@ class OptimType(Enum):


def generate_artifacts(
- model: onnx.ModelProto,
model: Union[onnx.ModelProto, str],
requires_grad: Optional[List[str]] = None,
frozen_params: Optional[List[str]] = None,
loss: Optional[Union[LossType, onnxblock.Block]] = None,
@@ -61,7 +64,8 @@ def generate_artifacts(
All generated ModelProtos will use the same opsets defined by *model*.
Args:
- model: The base model to be used for gradient graph generation.
model: The base model or path to the base model to be used for gradient graph generation. For models >2GB,
use the path to the base model.
requires_grad: List of names of model parameters that require gradient computation
frozen_params: List of names of model parameters that should be frozen.
loss: The loss function enum or onnxblock to be used for training. If None, no loss node is added to the graph.
@@ -86,6 +90,22 @@
RuntimeError: If the optimizer provided is not one of the supported optimizers.
"""

loaded_model = None
model_path = None

if isinstance(model, str):
loaded_model = onnx.load(model)
model_path = model
elif isinstance(model, onnx.ModelProto):
if model.ByteSize() > USE_PATH_THRESHOLD:
# infer_shapes and check_model from ONNX both require paths to be used for >2GB models.
raise RuntimeError("This model is > 2GB. Please pass in a path to the ONNX file instead.")

loaded_model = model
model_path = None
else:
raise RuntimeError("Please pass in either a string or an ONNX ModelProto for the model.")

loss_blocks = {
LossType.MSELoss: onnxblock.loss.MSELoss,
LossType.CrossEntropyLoss: onnxblock.loss.CrossEntropyLoss,
@@ -165,12 +185,12 @@ def build(self, *inputs_to_loss):
logging.info("Custom op library provided: %s", custom_op_library)
custom_op_library_path = pathlib.Path(custom_op_library)

- with onnxblock.base(model), (
with onnxblock.base(loaded_model, model_path), (
onnxblock.custom_op_library(custom_op_library_path)
if custom_op_library is not None
else contextlib.nullcontext()
):
- _ = training_block(*[output.name for output in model.graph.output])
_ = training_block(*[output.name for output in loaded_model.graph.output])
training_model, eval_model = training_block.to_model_proto()
model_params = training_block.parameters()

@@ -220,7 +240,7 @@ def _export_to_ort_format(model_path, output_dir, ort_format, custom_op_library_
return

opset_version = None
- for domain in model.opset_import:
for domain in loaded_model.opset_import:
if domain.domain == "" or domain.domain == "ai.onnx":
opset_version = domain.version
break
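As a caller-side sketch of the guard above: since protobuf cannot serialize a single message larger than 2 GiB, a caller can check ByteSize() before deciding what to pass (the helper below is hypothetical, mirroring USE_PATH_THRESHOLD):

```python
import onnx

_TWO_GIB = 2147483648  # same value as USE_PATH_THRESHOLD in artifacts.py

def model_or_path(model: onnx.ModelProto, saved_path: str):
    """Hypothetical helper: choose the argument to hand to generate_artifacts."""
    if model.ByteSize() > _TWO_GIB:
        # generate_artifacts raises for in-memory models this large
        return saved_path
    return model
```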
58 changes: 56 additions & 2 deletions orttraining/orttraining/python/training/onnxblock/blocks.py
@@ -4,6 +4,7 @@
import contextlib
import copy
import logging
import os
from abc import ABC, abstractmethod
from typing import Any, List, Optional

@@ -28,8 +29,13 @@ class Block(ABC):
base (onnx.ModelProto): The base model that the subclass can manipulate.
"""

- def __init__(self):
def __init__(self, temp_file_name="temp.onnx"):
if os.path.isabs(temp_file_name):
raise RuntimeError("Please pass in a relative path for the temp_file_name.")
self.base = None
self.temp_onnx_file_path = os.path.join(os.getcwd(), temp_file_name)
# the `location` argument of onnx.save must be a path relative to the model file
self.temp_external_data_file_name = temp_file_name + ".data"

@abstractmethod
def build(self, *args, **kwargs):
@@ -47,10 +53,58 @@ def __call__(self, *args, **kwargs):

output = self.build(*args, **kwargs)

- onnx.checker.check_model(self.base, True)
if accessor._GLOBAL_ACCESSOR.has_path:
onnx.save(
accessor._GLOBAL_ACCESSOR.model,
self.temp_onnx_file_path,
save_as_external_data=True,
all_tensors_to_one_file=True,
location=self.temp_external_data_file_name,
)

onnx.checker.check_model(self.temp_onnx_file_path, True)
else:
onnx.checker.check_model(self.base, True)

return output

def infer_shapes_on_base(self):
"""
Performs shape inference on the global model. If a path was used, then uses the
infer_shapes_path API to support models with external data.
Returns the shape-inferred ModelProto.
"""
if accessor._GLOBAL_ACCESSOR.has_path:
onnx.save(
accessor._GLOBAL_ACCESSOR.model,
self.temp_onnx_file_path,
save_as_external_data=True,
all_tensors_to_one_file=True,
location=self.temp_external_data_file_name,
)

onnx.shape_inference.infer_shapes_path(self.temp_onnx_file_path)
# the shape-inferred model is saved back to the original path
model = onnx.load(self.temp_onnx_file_path)

return model
else:
return onnx.shape_inference.infer_shapes(accessor._GLOBAL_ACCESSOR.model)

def __del__(self):
# Since the ModelProto stores only the metadata describing where the external data can be
# found, not the external tensor data itself, we retain the external data files used by the
# intermediate calls until the Block itself is deleted.
if os.path.exists(self.temp_onnx_file_path):
os.remove(self.temp_onnx_file_path)
# get absolute path for the external data file
external_data_file_path = os.path.join(
os.path.dirname(self.temp_onnx_file_path), self.temp_external_data_file_name
)
if os.path.exists(external_data_file_path):
os.remove(external_data_file_path)


class _BinaryOp(Block):
def __init__(self, op_name):
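The `__call__` and `infer_shapes_on_base` changes above share one pattern: save the model with external data, then use the path-based checker and shape-inference entry points that tolerate >2GB models. A standalone sketch of that pattern, assuming a relative temp path:

```python
import onnx

def check_and_infer_large(model: onnx.ModelProto, tmp_path: str = "temp.onnx") -> onnx.ModelProto:
    # Write tensor data to a side file so the .onnx file stays under the 2 GiB protobuf limit.
    onnx.save(
        model,
        tmp_path,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=tmp_path + ".data",  # must be relative to the model file, hence a relative tmp_path
    )
    onnx.checker.check_model(tmp_path, True)  # the path overload handles >2GB models
    onnx.shape_inference.infer_shapes_path(tmp_path)  # the result is written back to tmp_path
    return onnx.load(tmp_path)
```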
15 changes: 8 additions & 7 deletions orttraining/orttraining/python/training/onnxblock/loss/loss.py
@@ -93,19 +93,20 @@ def build(self, scores_input_name: str, labels_name: str = "labels"):
labels_input = copy.deepcopy(_graph_utils.get_output_from_output_name(self.base, scores_input_name))
labels_input.name = labels_name
labels_input.type.tensor_type.elem_type = onnx.TensorProto.INT64
- # If the predictions are (num_examples x num_classes)
- # labels should be (num_examples,)
- del labels_input.type.tensor_type.shape.dim[1]
# Assumes classes is the last dimension
# e.g., predictions: (num_examples, num_classes) -> labels: (num_examples,)
# or predictions: (batch_size, seq_len, vocab) -> labels: (batch_size, seq_len)
del labels_input.type.tensor_type.shape.dim[-1]
self.base.graph.input.append(labels_input)

loss_node_input_names = [scores_input_name, labels_name]
if self._weight:
loss_node_input_names.append(weight_name)

loss_node_output_name = _graph_utils.generate_graph_name("loss")
- loss_node_output_names = [
-     loss_node_output_name,
-     _graph_utils.generate_graph_name("log_prob"),
- ]
log_prob_output_name = _graph_utils.generate_graph_name("log_prob")

loss_node_output_names = [loss_node_output_name, log_prob_output_name]
loss_node = onnx.helper.make_node(
"SoftmaxCrossEntropyLoss",
loss_node_input_names,
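The revised comment and the switch from dim[1] to dim[-1] generalize the labels shape beyond rank-2 scores. A small illustration of the deepcopy-and-trim step (tensor names and dimensions are made up):

```python
import copy

import onnx
from onnx import helper

scores = helper.make_tensor_value_info(
    "scores", onnx.TensorProto.FLOAT, ["batch_size", "seq_len", "vocab"]
)
labels = copy.deepcopy(scores)
labels.name = "labels"
labels.type.tensor_type.elem_type = onnx.TensorProto.INT64
# Trim the trailing class dimension: (batch_size, seq_len, vocab) -> (batch_size, seq_len)
del labels.type.tensor_type.shape.dim[-1]
```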
@@ -15,10 +15,12 @@ class ModelAccessor:
Attributes:
model: The onnx model that is manipulated by the onnx blocks.
model_path: The path to the base model. Can be None.
"""

- def __init__(self, model: onnx.ModelProto):
def __init__(self, model: onnx.ModelProto, model_path: str | None = None):
self._model = model
self._path = model_path

@property
def model(self) -> onnx.ModelProto:
@@ -30,6 +32,22 @@ def model(self) -> onnx.ModelProto:
)
return self._model

@property
def path(self) -> str:
"""ModelAccessor property that gets the path to the base model."""

if self._path is None:
raise RuntimeError(
"The path to the onnx model was not set. Please use the context manager onnxblock.onnx_model to create the model and pass in a string."
)
return self._path

@property
def has_path(self) -> bool:
"""Returns True if ModelAccessor has a path to a model, False otherwise."""

return self._path is not None


# These variables reside in the global namespace.
# Different methods can access this global model and manipulate it.
@@ -39,7 +57,7 @@ def model(self) -> onnx.ModelProto:


@contextmanager
- def base(model: onnx.ModelProto):
def base(model: onnx.ModelProto, model_path: str | None = None):
"""Registers the base model to be manipulated by the onnx blocks.
Example:
@@ -53,6 +71,7 @@ def base(model: onnx.ModelProto):
Args:
model: The base model to be manipulated by the onnx blocks.
model_path: The path to the base model. None if there is no model path to pass in.
Returns:
ModelAccessor: The model accessor that contains the modified model.
@@ -69,7 +88,7 @@ def base(model: onnx.ModelProto):
"model from scratch."
)

- _GLOBAL_ACCESSOR = ModelAccessor(model_clone)
_GLOBAL_ACCESSOR = ModelAccessor(model_clone, model_path)
try:
yield _GLOBAL_ACCESSOR
finally:
@@ -112,7 +131,7 @@ def empty_base(opset_version: int | None = None):
)
)

- _GLOBAL_ACCESSOR = ModelAccessor(model)
_GLOBAL_ACCESSOR = ModelAccessor(model, None)
try:
yield _GLOBAL_ACCESSOR
finally:
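A hedged sketch of the extended accessor surface: the optional model_path argument threads through onnxblock.base to the new path/has_path properties (the file name is illustrative):

```python
import onnx

from onnxruntime.training import onnxblock

model_path = "large_model.onnx"  # hypothetical model on disk
loaded_model = onnx.load(model_path)

with onnxblock.base(loaded_model, model_path) as accessor:
    assert accessor.has_path
    print(accessor.path)  # "large_model.onnx"
```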
@@ -70,7 +70,7 @@ def __call__(self, *args, **kwargs):

output = self.build(*args, **kwargs)

- self._model = onnx.shape_inference.infer_shapes(accessor._GLOBAL_ACCESSOR.model)
self._model = self.infer_shapes_on_base()

_graph_utils.register_graph_outputs(self._model, output)

@@ -187,7 +187,7 @@ def __call__(self, *args, **kwargs):

output = self.build(*args, **kwargs)

- model = onnx.shape_inference.infer_shapes(accessor._GLOBAL_ACCESSOR.model)
model = self.infer_shapes_on_base()

_graph_utils.register_graph_outputs(model, output)

@@ -1099,3 +1099,91 @@ def test_custom_optimizer_block():
for attr in node.attribute:
if attr.name == "weight_decay":
assert attr.f == weight_decay


def test_generate_artifacts_path():

with tempfile.TemporaryDirectory() as temp_dir:
_, simple_net = _get_models("cpu", 32, 28, 10, 10)

requires_grad_params = ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"]

onnx.save_model(
simple_net,
os.path.join(temp_dir, "simple_net.onnx"),
)

artifacts.generate_artifacts(
os.path.join(temp_dir, "simple_net.onnx"),
requires_grad=requires_grad_params,
loss=artifacts.LossType.CrossEntropyLoss,
optimizer=artifacts.OptimType.AdamW,
artifact_directory=temp_dir,
)

# generate_artifacts should have thrown if it didn't complete successfully.
# Below is a sanity check to validate that all the expected files were created.
assert os.path.exists(os.path.join(temp_dir, "training_model.onnx"))
assert os.path.exists(os.path.join(temp_dir, "eval_model.onnx"))
assert os.path.exists(os.path.join(temp_dir, "optimizer_model.onnx"))
assert os.path.exists(os.path.join(temp_dir, "checkpoint"))


def test_generate_artifacts_external_data_one_file():
with tempfile.TemporaryDirectory() as temp_dir:
_, simple_net = _get_models("cpu", 32, 28, 10, 10)

requires_grad_params = ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"]

onnx.save_model(
simple_net,
os.path.join(temp_dir, "simple_net.onnx"),
save_as_external_data=True,
all_tensors_to_one_file=True,
size_threshold=0,
)

artifacts.generate_artifacts(
os.path.join(temp_dir, "simple_net.onnx"),
requires_grad=requires_grad_params,
loss=artifacts.LossType.CrossEntropyLoss,
optimizer=artifacts.OptimType.AdamW,
artifact_directory=temp_dir,
)

# generate_artifacts should have thrown if it didn't complete successfully.
# Below is a sanity check to validate that all the expected files were created.
assert os.path.exists(os.path.join(temp_dir, "training_model.onnx"))
assert os.path.exists(os.path.join(temp_dir, "eval_model.onnx"))
assert os.path.exists(os.path.join(temp_dir, "optimizer_model.onnx"))
assert os.path.exists(os.path.join(temp_dir, "checkpoint"))


def test_generate_artifacts_external_data_separate_files():
with tempfile.TemporaryDirectory() as temp_dir:
_, simple_net = _get_models("cpu", 32, 28, 10, 10)

requires_grad_params = ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"]

onnx.save_model(
simple_net,
os.path.join(temp_dir, "simple_net.onnx"),
save_as_external_data=True,
all_tensors_to_one_file=False,
size_threshold=0,
)

artifacts.generate_artifacts(
os.path.join(temp_dir, "simple_net.onnx"),
requires_grad=requires_grad_params,
loss=artifacts.LossType.CrossEntropyLoss,
optimizer=artifacts.OptimType.AdamW,
artifact_directory=temp_dir,
)

# generate_artifacts should have thrown if it didn't complete successfully.
# Below is a sanity check to validate that all the expected files were created.
assert os.path.exists(os.path.join(temp_dir, "training_model.onnx"))
assert os.path.exists(os.path.join(temp_dir, "eval_model.onnx"))
assert os.path.exists(os.path.join(temp_dir, "optimizer_model.onnx"))
assert os.path.exists(os.path.join(temp_dir, "checkpoint"))
