Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable >2GB models + allow model paths to be passed for generate_artifacts API #20958

Merged
merged 14 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions orttraining/orttraining/python/training/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@


def generate_artifacts(
model: onnx.ModelProto,
model: Union[onnx.ModelProto, str],
requires_grad: Optional[List[str]] = None,
frozen_params: Optional[List[str]] = None,
loss: Optional[Union[LossType, onnxblock.Block]] = None,
Expand All @@ -61,7 +61,8 @@
All generated ModelProtos will use the same opsets defined by *model*.

Args:
model: The base model to be used for gradient graph generation.
model: The base model or path to the base model to be used for gradient graph generation. For models >2GB,
Fixed Show fixed Hide fixed
use the path to the base model.
requires_grad: List of names of model parameters that require gradient computation
frozen_params: List of names of model parameters that should be frozen.
loss: The loss function enum or onnxblock to be used for training. If None, no loss node is added to the graph.
Expand Down Expand Up @@ -159,18 +160,20 @@
training_model = None
eval_model = None
model_params = None
loaded_model = onnx.load(model) if isinstance(model, str) else model
model_path = model if isinstance(model, str) else None

custom_op_library_path = None
if custom_op_library is not None:
logging.info("Custom op library provided: %s", custom_op_library)
custom_op_library_path = pathlib.Path(custom_op_library)

with onnxblock.base(model), (
with onnxblock.base(loaded_model, model_path), (
onnxblock.custom_op_library(custom_op_library_path)
if custom_op_library is not None
else contextlib.nullcontext()
):
_ = training_block(*[output.name for output in model.graph.output])
_ = training_block(*[output.name for output in loaded_model.graph.output])
training_model, eval_model = training_block.to_model_proto()
model_params = training_block.parameters()

Expand Down Expand Up @@ -220,7 +223,7 @@
return

opset_version = None
for domain in model.opset_import:
for domain in loaded_model.opset_import:
if domain.domain == "" or domain.domain == "ai.onnx":
opset_version = domain.version
break
Expand Down
5 changes: 4 additions & 1 deletion orttraining/orttraining/python/training/onnxblock/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ def __call__(self, *args, **kwargs):

output = self.build(*args, **kwargs)

onnx.checker.check_model(self.base, True)
if accessor._GLOBAL_ACCESSOR.has_path:
onnx.checker.check_model(accessor._GLOBAL_ACCESSOR.path, True)
carzh marked this conversation as resolved.
Show resolved Hide resolved
else:
onnx.checker.check_model(self.base, True)

return output

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@ class ModelAccessor:

Attributes:
model: The onnx model that is manipulated by the onnx blocks.
model_path: The path to the base model. Can be None.
"""

def __init__(self, model: onnx.ModelProto):
def __init__(self, model: onnx.ModelProto, model_path: str):
self._model = model
self._path = model_path

@property
def model(self) -> onnx.ModelProto:
Expand All @@ -30,6 +32,22 @@ def model(self) -> onnx.ModelProto:
)
return self._model

@property
def path(self) -> str:
"""ModelAccessor property that gets the path to the base model."""

if self._path is None:
raise RuntimeError(
"The path to the onnx model was not set. Please use the context manager onnxblock.onnx_model to create the model and pass in a string."
)
return self._path

@property
def has_path(self) -> bool:
"""Returns True if ModelAccessor has a path to a model, False otherwise."""

return self._path is not None


# These variable resides in the global namespace.
# Different methods can access this global model and manipulate it.
Expand All @@ -39,7 +57,7 @@ def model(self) -> onnx.ModelProto:


@contextmanager
def base(model: onnx.ModelProto):
def base(model: onnx.ModelProto, model_path: str):
"""Registers the base model to be manipulated by the onnx blocks.

Example:
Expand All @@ -53,6 +71,7 @@ def base(model: onnx.ModelProto):

Args:
model: The base model to be manipulated by the onnx blocks.
model_path: The path to the base model. None if there is no model path to pass in.

Returns:
ModelAccessor: The model accessor that contains the modified model.
Expand All @@ -69,7 +88,7 @@ def base(model: onnx.ModelProto):
"model from scratch."
)

_GLOBAL_ACCESSOR = ModelAccessor(model_clone)
_GLOBAL_ACCESSOR = ModelAccessor(model_clone, model_path)
try:
yield _GLOBAL_ACCESSOR
finally:
Expand Down Expand Up @@ -112,7 +131,7 @@ def empty_base(opset_version: int | None = None):
)
)

_GLOBAL_ACCESSOR = ModelAccessor(model)
_GLOBAL_ACCESSOR = ModelAccessor(model, None)
try:
yield _GLOBAL_ACCESSOR
finally:
Expand Down
15 changes: 13 additions & 2 deletions orttraining/orttraining/python/training/onnxblock/onnxblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,12 @@ def __call__(self, *args, **kwargs):

output = self.build(*args, **kwargs)

self._model = onnx.shape_inference.infer_shapes(accessor._GLOBAL_ACCESSOR.model)
if accessor._GLOBAL_ACCESSOR.has_path:
onnx.shape_inference.infer_shapes_path(accessor._GLOBAL_ACCESSOR.path)
# the shape-inferred model is written back to the original path
self._model = onnx.load(accessor._GLOBAL_ACCESSOR.path)
else:
self._model = onnx.shape_inference.infer_shapes(accessor._GLOBAL_ACCESSOR.model)

_graph_utils.register_graph_outputs(self._model, output)

Expand Down Expand Up @@ -187,7 +192,13 @@ def __call__(self, *args, **kwargs):

output = self.build(*args, **kwargs)

model = onnx.shape_inference.infer_shapes(accessor._GLOBAL_ACCESSOR.model)
model = None
if accessor._GLOBAL_ACCESSOR.has_path:
onnx.shape_inference.infer_shapes_path(accessor._GLOBAL_ACCESSOR.path)
# the shape-inferred model is written back to the original path
model = onnx.load(accessor._GLOBAL_ACCESSOR.path)
else:
model = onnx.shape_inference.infer_shapes(accessor._GLOBAL_ACCESSOR.model)

_graph_utils.register_graph_outputs(model, output)

Expand Down
Loading