Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable >2GB models + allow model paths to be passed for generate_artifacts API #20958

Merged
merged 14 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 25 additions & 5 deletions orttraining/orttraining/python/training/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
from onnxruntime.tools.convert_onnx_models_to_ort import OptimizationStyle, convert_onnx_models_to_ort
from onnxruntime.training import onnxblock

# threshold for the size of the modelproto where you should use a path instead
USE_PATH_THRESHOLD = 2147483648


class LossType(Enum):
"""Loss type to be added to the training model.
Expand All @@ -37,7 +40,7 @@ class OptimType(Enum):


def generate_artifacts(
model: onnx.ModelProto,
model: Union[onnx.ModelProto, str],
requires_grad: Optional[List[str]] = None,
frozen_params: Optional[List[str]] = None,
loss: Optional[Union[LossType, onnxblock.Block]] = None,
Expand All @@ -61,7 +64,8 @@ def generate_artifacts(
All generated ModelProtos will use the same opsets defined by *model*.

Args:
model: The base model to be used for gradient graph generation.
model: The base model or path to the base model to be used for gradient graph generation. For models >2GB,
use the path to the base model.
requires_grad: List of names of model parameters that require gradient computation
frozen_params: List of names of model parameters that should be frozen.
loss: The loss function enum or onnxblock to be used for training. If None, no loss node is added to the graph.
Expand All @@ -86,6 +90,22 @@ def generate_artifacts(
RuntimeError: If the optimizer provided is not one of the supported optimizers.
"""

loaded_model = None
model_path = None

if isinstance(model, str):
loaded_model = onnx.load(model)
model_path = model
elif isinstance(model, onnx.ModelProto):
if model.ByteSize() > USE_PATH_THRESHOLD:
# infer_shapes and check_model from ONNX both require paths to be used for >2GB models.
raise RuntimeError("This model is > 2GB. Please pass in a path to the ONNX file instead.")

loaded_model = model
model_path = None
else:
raise RuntimeError("Please pass in either a string or an ONNX ModelProto for the model.")

loss_blocks = {
LossType.MSELoss: onnxblock.loss.MSELoss,
LossType.CrossEntropyLoss: onnxblock.loss.CrossEntropyLoss,
Expand Down Expand Up @@ -165,12 +185,12 @@ def build(self, *inputs_to_loss):
logging.info("Custom op library provided: %s", custom_op_library)
custom_op_library_path = pathlib.Path(custom_op_library)

with onnxblock.base(model), (
with onnxblock.base(loaded_model, model_path), (
onnxblock.custom_op_library(custom_op_library_path)
if custom_op_library is not None
else contextlib.nullcontext()
):
_ = training_block(*[output.name for output in model.graph.output])
_ = training_block(*[output.name for output in loaded_model.graph.output])
training_model, eval_model = training_block.to_model_proto()
model_params = training_block.parameters()

Expand Down Expand Up @@ -220,7 +240,7 @@ def _export_to_ort_format(model_path, output_dir, ort_format, custom_op_library_
return

opset_version = None
for domain in model.opset_import:
for domain in loaded_model.opset_import:
if domain.domain == "" or domain.domain == "ai.onnx":
opset_version = domain.version
break
Expand Down
56 changes: 54 additions & 2 deletions orttraining/orttraining/python/training/onnxblock/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import contextlib
import copy
import logging
import os
from abc import ABC, abstractmethod
from typing import Any, List, Optional

Expand All @@ -28,8 +29,13 @@ class Block(ABC):
base (onnx.ModelProto): The base model that the subclass can manipulate.
"""

def __init__(self):
def __init__(self, temp_file_name="temp.onnx"):
if (os.path.isabs(temp_file_name)):
raise RuntimeError("Please pass in a relative path for the temp_file_name.")
self.base = None
self.temp_onnx_file_path = os.path.join(os.getcwd(), temp_file_name)
# onnx.save location parameter requires a relative path to the model path
self.temp_external_data_file_name = temp_file_name + ".data"

@abstractmethod
def build(self, *args, **kwargs):
Expand All @@ -47,10 +53,56 @@ def __call__(self, *args, **kwargs):

output = self.build(*args, **kwargs)

onnx.checker.check_model(self.base, True)
if accessor._GLOBAL_ACCESSOR.has_path:
onnx.save(
accessor._GLOBAL_ACCESSOR.model,
self.temp_onnx_file_path,
save_as_external_data=True,
all_tensors_to_one_file=True,
location=self.temp_external_data_file_name,
)

onnx.checker.check_model(self.temp_onnx_file_path, True)
else:
onnx.checker.check_model(self.base, True)

return output

def infer_shapes_on_base(self):
"""
Performs shape inference on the global model. If a path was used, then uses the
infer_shapes_path API to support models with external data.

Returns the shape-inferenced ModelProto.
"""
if accessor._GLOBAL_ACCESSOR.has_path:
onnx.save(
accessor._GLOBAL_ACCESSOR.model,
self.temp_onnx_file_path,
save_as_external_data=True,
all_tensors_to_one_file=True,
location=self.temp_external_data_file_name,
)

onnx.shape_inference.infer_shapes_path(self.temp_onnx_file_path)
# shape inferenced model is saved to original path
model = onnx.load(self.temp_onnx_file_path)

return model
else:
return onnx.shape_inference.infer_shapes(accessor._GLOBAL_ACCESSOR.model)

def __del__(self):
# since the ModelProto does not store the external data parameters themselves, just the metadata
# for where the external data can be found, we retain the external data files for the intermediate
# calls until the Block no longer needs to be used.
if os.path.exists(self.temp_onnx_file_path):
os.remove(self.temp_onnx_file_path)
# get absolute path for the external data file
external_data_file_path = os.path.join(os.path.dirname(self.temp_onnx_file_path), self.temp_external_data_file_name)
if os.path.exists(external_data_file_path):
os.remove(external_data_file_path)


class _BinaryOp(Block):
def __init__(self, op_name):
Expand Down
15 changes: 8 additions & 7 deletions orttraining/orttraining/python/training/onnxblock/loss/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,19 +93,20 @@ def build(self, scores_input_name: str, labels_name: str = "labels"):
labels_input = copy.deepcopy(_graph_utils.get_output_from_output_name(self.base, scores_input_name))
labels_input.name = labels_name
labels_input.type.tensor_type.elem_type = onnx.TensorProto.INT64
# If the predictions are (num_examples x num_classes)
# labels should be (num_examples,)
del labels_input.type.tensor_type.shape.dim[1]
# Assumes classes is the last dimension
# e.g., predictions: (num_examples, num_classes) -> labels: (num_examples,)
# or predictions: (batch_size, seq_len, vocab) -> labels: (batch_size, seq_len)
del labels_input.type.tensor_type.shape.dim[-1]
self.base.graph.input.append(labels_input)

loss_node_input_names = [scores_input_name, labels_name]
if self._weight:
loss_node_input_names.append(weight_name)

loss_node_output_name = _graph_utils.generate_graph_name("loss")
loss_node_output_names = [
loss_node_output_name,
_graph_utils.generate_graph_name("log_prob"),
]
log_prob_output_name = _graph_utils.generate_graph_name("log_prob")

loss_node_output_names = [loss_node_output_name, log_prob_output_name]
loss_node = onnx.helper.make_node(
"SoftmaxCrossEntropyLoss",
loss_node_input_names,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@ class ModelAccessor:

Attributes:
model: The onnx model that is manipulated by the onnx blocks.
model_path: The path to the base model. Can be None.
"""

def __init__(self, model: onnx.ModelProto):
def __init__(self, model: onnx.ModelProto, model_path: str | None = None):
self._model = model
self._path = model_path

@property
def model(self) -> onnx.ModelProto:
Expand All @@ -30,6 +32,22 @@ def model(self) -> onnx.ModelProto:
)
return self._model

@property
def path(self) -> str:
"""ModelAccessor property that gets the path to the base model."""

if self._path is None:
raise RuntimeError(
"The path to the onnx model was not set. Please use the context manager onnxblock.onnx_model to create the model and pass in a string."
)
return self._path

@property
def has_path(self) -> bool:
"""Returns True if ModelAccessor has a path to a model, False otherwise."""

return self._path is not None


# These variable resides in the global namespace.
# Different methods can access this global model and manipulate it.
Expand All @@ -39,7 +57,7 @@ def model(self) -> onnx.ModelProto:


@contextmanager
def base(model: onnx.ModelProto):
def base(model: onnx.ModelProto, model_path: str | None = None):
"""Registers the base model to be manipulated by the onnx blocks.

Example:
Expand All @@ -53,6 +71,7 @@ def base(model: onnx.ModelProto):

Args:
model: The base model to be manipulated by the onnx blocks.
model_path: The path to the base model. None if there is no model path to pass in.

Returns:
ModelAccessor: The model accessor that contains the modified model.
Expand All @@ -69,7 +88,7 @@ def base(model: onnx.ModelProto):
"model from scratch."
)

_GLOBAL_ACCESSOR = ModelAccessor(model_clone)
_GLOBAL_ACCESSOR = ModelAccessor(model_clone, model_path)
try:
yield _GLOBAL_ACCESSOR
finally:
Expand Down Expand Up @@ -112,7 +131,7 @@ def empty_base(opset_version: int | None = None):
)
)

_GLOBAL_ACCESSOR = ModelAccessor(model)
_GLOBAL_ACCESSOR = ModelAccessor(model, None)
try:
yield _GLOBAL_ACCESSOR
finally:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __call__(self, *args, **kwargs):

output = self.build(*args, **kwargs)

self._model = onnx.shape_inference.infer_shapes(accessor._GLOBAL_ACCESSOR.model)
self._model = self.infer_shapes_on_base()

_graph_utils.register_graph_outputs(self._model, output)

Expand Down Expand Up @@ -187,7 +187,7 @@ def __call__(self, *args, **kwargs):

output = self.build(*args, **kwargs)

model = onnx.shape_inference.infer_shapes(accessor._GLOBAL_ACCESSOR.model)
model = self.infer_shapes_on_base()

_graph_utils.register_graph_outputs(model, output)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1099,3 +1099,91 @@ def test_custom_optimizer_block():
for attr in node.attribute:
if attr.name == "weight_decay":
assert attr.f == weight_decay


def test_generate_artifacts_path():

with tempfile.TemporaryDirectory() as temp_dir:
_, simple_net = _get_models("cpu", 32, 28, 10, 10)

requires_grad_params = ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"]

onnx.save_model(
simple_net,
os.path.join(temp_dir, "simple_net.onnx"),
)

artifacts.generate_artifacts(
os.path.join(temp_dir, "simple_net.onnx"),
requires_grad=requires_grad_params,
loss=artifacts.LossType.CrossEntropyLoss,
optimizer=artifacts.OptimType.AdamW,
artifact_directory=temp_dir,
)

# generate_artifacts should have thrown if it didn't complete successfully.
# Below is a sanity check to validate that all the expected files were created.
assert os.path.exists(os.path.join(temp_dir, "training_model.onnx"))
assert os.path.exists(os.path.join(temp_dir, "eval_model.onnx"))
assert os.path.exists(os.path.join(temp_dir, "optimizer_model.onnx"))
assert os.path.exists(os.path.join(temp_dir, "checkpoint"))
carzh marked this conversation as resolved.
Show resolved Hide resolved


def test_generate_artifacts_external_data_one_file():
with tempfile.TemporaryDirectory() as temp_dir:
_, simple_net = _get_models("cpu", 32, 28, 10, 10)

requires_grad_params = ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"]

onnx.save_model(
simple_net,
os.path.join(temp_dir, "simple_net.onnx"),
save_as_external_data=True,
all_tensors_to_one_file=True,
size_threshold=0,
)

artifacts.generate_artifacts(
os.path.join(temp_dir, "simple_net.onnx"),
requires_grad=requires_grad_params,
loss=artifacts.LossType.CrossEntropyLoss,
optimizer=artifacts.OptimType.AdamW,
artifact_directory=temp_dir,
)

# generate_artifacts should have thrown if it didn't complete successfully.
# Below is a sanity check to validate that all the expected files were created.
assert os.path.exists(os.path.join(temp_dir, "training_model.onnx"))
assert os.path.exists(os.path.join(temp_dir, "eval_model.onnx"))
assert os.path.exists(os.path.join(temp_dir, "optimizer_model.onnx"))
assert os.path.exists(os.path.join(temp_dir, "checkpoint"))


def test_generate_artifacts_external_data_separate_files():
with tempfile.TemporaryDirectory() as temp_dir:
_, simple_net = _get_models("cpu", 32, 28, 10, 10)

requires_grad_params = ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"]

onnx.save_model(
simple_net,
os.path.join(temp_dir, "simple_net.onnx"),
save_as_external_data=True,
all_tensors_to_one_file=False,
size_threshold=0,
)

artifacts.generate_artifacts(
os.path.join(temp_dir, "simple_net.onnx"),
requires_grad=requires_grad_params,
loss=artifacts.LossType.CrossEntropyLoss,
optimizer=artifacts.OptimType.AdamW,
artifact_directory=temp_dir,
)

# generate_artifacts should have thrown if it didn't complete successfully.
# Below is a sanity check to validate that all the expected files were created.
assert os.path.exists(os.path.join(temp_dir, "training_model.onnx"))
assert os.path.exists(os.path.join(temp_dir, "eval_model.onnx"))
assert os.path.exists(os.path.join(temp_dir, "optimizer_model.onnx"))
assert os.path.exists(os.path.join(temp_dir, "checkpoint"))
Loading