Enable >2GB models + allow model paths to be passed for generate_artifacts API #20958

Merged
merged 14 commits into from Jun 21, 2024
Changes from 4 commits
30 changes: 25 additions & 5 deletions orttraining/orttraining/python/training/artifacts.py
@@ -13,6 +13,9 @@
 from onnxruntime.tools.convert_onnx_models_to_ort import OptimizationStyle, convert_onnx_models_to_ort
 from onnxruntime.training import onnxblock
 
+# Threshold for the ModelProto size beyond which a path should be used instead (2 GB protobuf limit).
+USE_PATH_THRESHOLD = 2147483648
+
 
 class LossType(Enum):
     """Loss type to be added to the training model.
@@ -37,7 +40,7 @@ class OptimType(Enum):
 
 
 def generate_artifacts(
-    model: onnx.ModelProto,
+    model: Union[onnx.ModelProto, str],
     requires_grad: Optional[List[str]] = None,
     frozen_params: Optional[List[str]] = None,
     loss: Optional[Union[LossType, onnxblock.Block]] = None,
@@ -61,7 +64,8 @@ def generate_artifacts(
     All generated ModelProtos will use the same opsets defined by *model*.
 
     Args:
-        model: The base model to be used for gradient graph generation.
+        model: The base model, or a path to the base model, to be used for gradient graph generation.
+            For models >2GB, pass the path to the base model.
         requires_grad: List of names of model parameters that require gradient computation.
         frozen_params: List of names of model parameters that should be frozen.
         loss: The loss function enum or onnxblock to be used for training. If None, no loss node is added to the graph.
@@ -86,6 +90,22 @@
         RuntimeError: If the optimizer provided is not one of the supported optimizers.
     """
 
+    loaded_model = None
+    model_path = None
+
+    if isinstance(model, str):
+        loaded_model = onnx.load(model)
+        model_path = model
+    elif isinstance(model, onnx.ModelProto):
+        if model.ByteSize() > USE_PATH_THRESHOLD:
+            # infer_shapes and check_model from ONNX both require paths to be used for >2GB models.
+            raise RuntimeError("This model is > 2GB. Please pass in a path to the ONNX file instead.")
+
+        loaded_model = model
+        model_path = None
+    else:
+        raise RuntimeError("Please pass in either a string or an ONNX ModelProto for the model.")
+
     loss_blocks = {
         LossType.MSELoss: onnxblock.loss.MSELoss,
         LossType.CrossEntropyLoss: onnxblock.loss.CrossEntropyLoss,
@@ -165,12 +185,12 @@ def build(self, *inputs_to_loss):
         logging.info("Custom op library provided: %s", custom_op_library)
         custom_op_library_path = pathlib.Path(custom_op_library)
 
-    with onnxblock.base(model), (
+    with onnxblock.base(loaded_model, model_path), (
         onnxblock.custom_op_library(custom_op_library_path)
         if custom_op_library is not None
         else contextlib.nullcontext()
     ):
-        _ = training_block(*[output.name for output in model.graph.output])
+        _ = training_block(*[output.name for output in loaded_model.graph.output])
         training_model, eval_model = training_block.to_model_proto()
         model_params = training_block.parameters()
 
@@ -220,7 +240,7 @@ def _export_to_ort_format(model_path, output_dir, ort_format, custom_op_library_path):
         return
 
     opset_version = None
-    for domain in model.opset_import:
+    for domain in loaded_model.opset_import:
         if domain.domain == "" or domain.domain == "ai.onnx":
             opset_version = domain.version
             break
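As a usage reference, here is a minimal sketch of the new path-based call; the model file, parameter names, and output directory are illustrative, not taken from this PR:

import onnx
from onnxruntime.training import artifacts

# Protobuf cannot serialize a message larger than 2^31 - 1 bytes, so a >2GB
# model has to stay on disk and be referenced by path.
artifacts.generate_artifacts(
    "big_model.onnx",  # a str path now works in addition to an onnx.ModelProto
    requires_grad=["fc1.weight", "fc1.bias"],  # illustrative parameter names
    loss=artifacts.LossType.CrossEntropyLoss,
    optimizer=artifacts.OptimType.AdamW,
    artifact_directory="artifacts_out",
)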
@@ -85,3 +85,4 @@ def node_arg_exists(model: onnx.ModelProto, node_arg_name: str) -> bool:
         return True
 
     return False
+
5 changes: 4 additions & 1 deletion orttraining/orttraining/python/training/onnxblock/blocks.py
@@ -47,7 +47,10 @@ def __call__(self, *args, **kwargs):
 
         output = self.build(*args, **kwargs)
 
-        onnx.checker.check_model(self.base, True)
+        if accessor._GLOBAL_ACCESSOR.has_path:
+            onnx.checker.check_model(accessor._GLOBAL_ACCESSOR.path, True)
+        else:
+            onnx.checker.check_model(self.base, True)
 
         return output
 
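For context, onnx.checker.check_model accepts either a ModelProto or a file path, and only the path form works once the serialized model crosses the 2GB protobuf limit; that is what the has_path branch above exploits. A small sketch, with illustrative file names:

import onnx

# Path form: the checker reads the model and its external data from disk,
# so it also works for models serialized past the 2GB limit.
onnx.checker.check_model("big_model.onnx", True)  # second argument enables the full check

# In-memory form, fine for smaller models:
small_model = onnx.load("small_model.onnx")
onnx.checker.check_model(small_model, True)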
10 changes: 5 additions & 5 deletions orttraining/orttraining/python/training/onnxblock/loss/loss.py
@@ -95,17 +95,17 @@ def build(self, scores_input_name: str, labels_name: str = "labels"):
         labels_input.type.tensor_type.elem_type = onnx.TensorProto.INT64
         # If the predictions are (num_examples x num_classes)
         # labels should be (num_examples,)
-        del labels_input.type.tensor_type.shape.dim[1]
+        del labels_input.type.tensor_type.shape.dim[-1]
         self.base.graph.input.append(labels_input)
 
         loss_node_input_names = [scores_input_name, labels_name]
         if self._weight:
             loss_node_input_names.append(weight_name)
 
         loss_node_output_name = _graph_utils.generate_graph_name("loss")
-        loss_node_output_names = [
-            loss_node_output_name,
-            _graph_utils.generate_graph_name("log_prob"),
-        ]
+        log_prob_output_name = _graph_utils.generate_graph_name("log_prob")
+
+        loss_node_output_names = [loss_node_output_name, log_prob_output_name]
         loss_node = onnx.helper.make_node(
             "SoftmaxCrossEntropyLoss",
             loss_node_input_names,
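For orientation, a standalone sketch of the node this block emits; the tensor names here are illustrative (the real ones come from _graph_utils.generate_graph_name):

import onnx

# SoftmaxCrossEntropyLoss consumes scores of shape (N, C) and int64 labels of
# shape (N,); the optional second output carries the log-probabilities.
loss_node = onnx.helper.make_node(
    "SoftmaxCrossEntropyLoss",
    inputs=["scores", "labels"],
    outputs=["loss", "log_prob"],
    reduction="mean",
)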
@@ -6,6 +6,7 @@
 import copy
 import os
 from contextlib import contextmanager
+from typing import Optional
 
 import onnx
 
@@ -15,10 +16,12 @@
 
     Attributes:
         model: The onnx model that is manipulated by the onnx blocks.
+        model_path: The path to the base model. Can be None.
     """
 
-    def __init__(self, model: onnx.ModelProto):
+    def __init__(self, model: onnx.ModelProto, model_path: Optional[str] = None):
         self._model = model
+        self._path = model_path
 
     @property
     def model(self) -> onnx.ModelProto:
@@ -30,6 +33,22 @@
             )
         return self._model
 
+    @property
+    def path(self) -> str:
+        """ModelAccessor property that gets the path to the base model."""
+
+        if self._path is None:
+            raise RuntimeError(
+                "The path to the onnx model was not set. Please use the context manager onnxblock.onnx_model to create the model and pass in a string."
+            )
+        return self._path
+
+    @property
+    def has_path(self) -> bool:
+        """Returns True if the ModelAccessor has a path to a model, False otherwise."""
+
+        return self._path is not None
+
 
 # These variables reside in the global namespace.
 # Different methods can access this global model and manipulate it.
@@ -39,7 +58,7 @@
 
 
 @contextmanager
-def base(model: onnx.ModelProto):
+def base(model: onnx.ModelProto, model_path: Optional[str] = None):
     """Registers the base model to be manipulated by the onnx blocks.
 
     Example:
@@ -53,6 +72,7 @@
 
     Args:
         model: The base model to be manipulated by the onnx blocks.
+        model_path: The path to the base model. None if there is no model path to pass in.
 
     Returns:
         ModelAccessor: The model accessor that contains the modified model.
@@ -69,7 +89,7 @@
             "model from scratch."
         )
 
-    _GLOBAL_ACCESSOR = ModelAccessor(model_clone)
+    _GLOBAL_ACCESSOR = ModelAccessor(model_clone, model_path)
    try:
        yield _GLOBAL_ACCESSOR
    finally:
@@ -112,7 +132,7 @@
             )
         )
 
-    _GLOBAL_ACCESSOR = ModelAccessor(model)
+    _GLOBAL_ACCESSOR = ModelAccessor(model, None)
     try:
         yield _GLOBAL_ACCESSOR
     finally:
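Putting the accessor plumbing together, the intended call pattern looks roughly like this sketch (assuming the module is imported as onnxblock; loaded_model and model_path mirror the dispatch added in artifacts.py):

import onnx
from onnxruntime.training import onnxblock

loaded_model = onnx.load("model.onnx")  # illustrative file name
model_path = "model.onnx"               # None when the caller passed a ModelProto

with onnxblock.base(loaded_model, model_path) as accessor:
    # Blocks built inside this context can consult accessor.has_path to choose
    # between in-memory APIs and the path-based, >2GB-safe APIs.
    assert accessor.has_path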
34 changes: 32 additions & 2 deletions orttraining/orttraining/python/training/onnxblock/onnxblock.py
@@ -2,6 +2,7 @@
 # Licensed under the MIT License.
 
 import logging
+import os
 from abc import abstractmethod
 from typing import List, Tuple
 
@@ -12,6 +13,8 @@
 import onnxruntime.training.onnxblock.blocks as blocks
 import onnxruntime.training.onnxblock.model_accessor as accessor
 
+TEMP_ONNX_PATH = "temp.onnx"
+
 
 class ForwardBlock(blocks.Block):
     """Base class for all blocks that require forward model to be automatically built.
@@ -70,7 +73,20 @@ def __call__(self, *args, **kwargs):
 
         output = self.build(*args, **kwargs)
 
-        self._model = onnx.shape_inference.infer_shapes(accessor._GLOBAL_ACCESSOR.model)
+        if accessor._GLOBAL_ACCESSOR.has_path:
+            onnx.save(self.base, TEMP_ONNX_PATH, save_as_external_data=True, all_tensors_to_one_file=True)
+
+            onnx.shape_inference.infer_shapes_path(TEMP_ONNX_PATH)
+            # The shape-inferenced model is saved back to the original path.
+            self._model = onnx.load(TEMP_ONNX_PATH)
+
+            # Clean up the temporary file.
+            if os.path.exists(TEMP_ONNX_PATH):
+                os.remove(TEMP_ONNX_PATH)
+        else:
+            self._model = onnx.shape_inference.infer_shapes(accessor._GLOBAL_ACCESSOR.model)
 
         _graph_utils.register_graph_outputs(self._model, output)
 
@@ -187,7 +203,21 @@ def __call__(self, *args, **kwargs):
 
         output = self.build(*args, **kwargs)
 
-        model = onnx.shape_inference.infer_shapes(accessor._GLOBAL_ACCESSOR.model)
+        model = None
+        if accessor._GLOBAL_ACCESSOR.has_path:
+            onnx.save(self.base, TEMP_ONNX_PATH, save_as_external_data=True, all_tensors_to_one_file=True)
+
+            onnx.shape_inference.infer_shapes_path(TEMP_ONNX_PATH)
+            # The shape-inferenced model is saved back to the original path.
+            model = onnx.load(TEMP_ONNX_PATH)
+
+            # Clean up the temporary file.
+            if os.path.exists(TEMP_ONNX_PATH):
+                os.remove(TEMP_ONNX_PATH)
+        else:
+            model = onnx.shape_inference.infer_shapes(accessor._GLOBAL_ACCESSOR.model)
 
         _graph_utils.register_graph_outputs(model, output)
 
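The round-trip through disk is needed because onnx.shape_inference.infer_shapes hits the same 2GB protobuf limit, while infer_shapes_path operates on files and, when no output path is given, writes the inferred model back over the input. A condensed sketch of the pattern used above:

import os
import onnx

TEMP_ONNX_PATH = "temp.onnx"  # same scratch name as above

def infer_shapes_large(model: onnx.ModelProto) -> onnx.ModelProto:
    """Sketch: run shape inference through disk so >2GB models work."""
    onnx.save(model, TEMP_ONNX_PATH, save_as_external_data=True, all_tensors_to_one_file=True)
    onnx.shape_inference.infer_shapes_path(TEMP_ONNX_PATH)  # result is written back in place
    inferred = onnx.load(TEMP_ONNX_PATH)
    if os.path.exists(TEMP_ONNX_PATH):
        # Note: the external-data sidecar written next to temp.onnx may also need cleanup.
        os.remove(TEMP_ONNX_PATH)
    return inferred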
@@ -318,6 +318,8 @@ def crossentropy_loss(prediction, target):
 
         ort_outs = ort_session.run(ort_output_names, ort_inputs)
         torch_outs = crossentropy_loss(pt_model(x), target)
+        print("ort outs", ort_outs)
+        print("torch outs", torch_outs)
 
         # Then
         assert np.allclose(ort_outs[0], _to_numpy(torch_outs))
@@ -356,6 +358,8 @@ def bcewithlogits_loss(prediction, target):
 
         ort_outs = ort_session.run(ort_output_names, ort_inputs)
         torch_outs = bcewithlogits_loss(pt_model(x), target)
+        print("ort outs", ort_outs)
+        print("torch outs", torch_outs)
 
         # Then
         assert np.allclose(ort_outs[0], _to_numpy(torch_outs))
@@ -1099,3 +1103,85 @@ def test_custom_optimizer_block():
         for attr in node.attribute:
             if attr.name == "weight_decay":
                 assert attr.f == weight_decay
+
+
+def test_generate_artifacts_path():
+    with tempfile.TemporaryDirectory() as temp_dir:
+        _, simple_net = _get_models("cpu", 32, 28, 10, 10)
+
+        requires_grad_params = ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"]
+
+        onnx.save_model(
+            simple_net,
+            os.path.join(temp_dir, "simple_net.onnx"),
+        )
+
+        artifacts.generate_artifacts(
+            os.path.join(temp_dir, "simple_net.onnx"),
+            requires_grad=requires_grad_params,
+            loss=artifacts.LossType.CrossEntropyLoss,
+            optimizer=artifacts.OptimType.AdamW,
+            artifact_directory=temp_dir,
+        )
+
+        assert os.path.exists(os.path.join(temp_dir, "training_model.onnx"))
+        assert os.path.exists(os.path.join(temp_dir, "eval_model.onnx"))
+        assert os.path.exists(os.path.join(temp_dir, "optimizer_model.onnx"))
+        assert os.path.exists(os.path.join(temp_dir, "checkpoint"))
+
+
+def test_generate_artifacts_external_data_one_file():
+    with tempfile.TemporaryDirectory() as temp_dir:
+        _, simple_net = _get_models("cpu", 32, 28, 10, 10)
+
+        requires_grad_params = ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"]
+
+        onnx.save_model(
+            simple_net,
+            os.path.join(temp_dir, "simple_net.onnx"),
+            save_as_external_data=True,
+            all_tensors_to_one_file=True,
+            size_threshold=0,
+        )
+
+        artifacts.generate_artifacts(
+            os.path.join(temp_dir, "simple_net.onnx"),
+            requires_grad=requires_grad_params,
+            loss=artifacts.LossType.CrossEntropyLoss,
+            optimizer=artifacts.OptimType.AdamW,
+            artifact_directory=temp_dir,
+        )
+
+        assert os.path.exists(os.path.join(temp_dir, "training_model.onnx"))
+        assert os.path.exists(os.path.join(temp_dir, "eval_model.onnx"))
+        assert os.path.exists(os.path.join(temp_dir, "optimizer_model.onnx"))
+        assert os.path.exists(os.path.join(temp_dir, "checkpoint"))
+
+
+def test_generate_artifacts_external_data_separate_files():
+    with tempfile.TemporaryDirectory() as temp_dir:
+        _, simple_net = _get_models("cpu", 32, 28, 10, 10)
+
+        requires_grad_params = ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"]
+
+        onnx.save_model(
+            simple_net,
+            os.path.join(temp_dir, "simple_net.onnx"),
+            save_as_external_data=True,
+            all_tensors_to_one_file=False,
+            size_threshold=0,
+        )
+
+        artifacts.generate_artifacts(
+            os.path.join(temp_dir, "simple_net.onnx"),
+            requires_grad=requires_grad_params,
+            loss=artifacts.LossType.CrossEntropyLoss,
+            optimizer=artifacts.OptimType.AdamW,
+            artifact_directory=temp_dir,
+        )
+
+        assert os.path.exists(os.path.join(temp_dir, "training_model.onnx"))
+        assert os.path.exists(os.path.join(temp_dir, "eval_model.onnx"))
+        assert os.path.exists(os.path.join(temp_dir, "optimizer_model.onnx"))
+        assert os.path.exists(os.path.join(temp_dir, "checkpoint"))
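A note on these tests: size_threshold=0 forces every initializer into the external-data file(s), so even this small test network exercises the external-data code path. Outside the test harness, the same save looks like this sketch (file names illustrative):

import onnx

model = onnx.load("model.onnx")
onnx.save_model(
    model,
    "model_ext.onnx",
    save_as_external_data=True,
    all_tensors_to_one_file=True,  # one sidecar file instead of one file per tensor
    size_threshold=0,              # externalize every initializer regardless of size
)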