From d1e53e4c864c7e59312575e902456c8ca735b9e5 Mon Sep 17 00:00:00 2001 From: "pengwa@microsoft.com" Date: Wed, 3 Jan 2024 07:30:20 +0000 Subject: [PATCH 01/32] save --- .../ortmodule/_graph_execution_manager.py | 523 +++++++++++++++--- .../training/ortmodule/_inference_manager.py | 43 +- .../python/training/ortmodule/_io.py | 411 ++++++++------ .../training/ortmodule/_runtime_inspector.py | 3 + .../training/ortmodule/_training_manager.py | 176 +++--- .../python/training/utils/torch_io_helper.py | 8 +- .../python/training/utils/torch_to_onnx.py | 6 + .../python/orttraining_test_ortmodule_api.py | 160 ++++-- 8 files changed, 921 insertions(+), 409 deletions(-) create mode 100644 orttraining/orttraining/python/training/utils/torch_to_onnx.py diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 76943b954837b..ae6ccb1acdfff 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -9,8 +9,9 @@ import logging import os from abc import ABC, abstractmethod # noqa: F401 +from functools import partial from hashlib import md5 as hash_fn -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Mapping, Optional, OrderedDict, Sequence, Tuple import onnx import torch @@ -19,7 +20,13 @@ import onnxruntime from onnxruntime.capi import _pybind_state as C from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference -from onnxruntime.training.utils import ORTModelInputOutputSchemaType, PTable, onnx_dtype_to_pytorch_dtype +from onnxruntime.training.utils import ( + ORTModelInputOutputSchemaType, + ORTModelInputOutputType, + PrimitiveType, + PTable, + onnx_dtype_to_pytorch_dtype, +) from onnxruntime.training.utils.hooks import configure_ort_compatible_zero_stage3 from . import _are_deterministic_algorithms_enabled, _io, _logger, _onnx_models, _utils @@ -33,7 +40,7 @@ ) from ._gradient_accumulation_manager import GradientAccumulationManager from ._graph_execution_interface import GraphExecutionInterface -from ._io import _FlattenedModule, _InputInfo +from ._io import _FlattenedModule from ._runtime_inspector import RuntimeInspector from ._utils import check_function_has_param, get_rank from ._zero_stage3_compatibility import stage3_export_context @@ -51,7 +58,308 @@ def __init__(self, state, output_info: List[Tuple[torch.Size, torch.device, torc self.output_info = output_info -class GraphExecutionManager(GraphExecutionInterface): +def _get_onnx_file_name(name_prefix, name, export_mode): + suffix = "training" if export_mode == torch.onnx.TrainingMode.TRAINING else "inference" + return f"{name_prefix}_{name}_{suffix}.onnx" + + +def _save_model(model: onnx.ModelProto, file_path: str): + onnx.save(model, file_path) + + +class StaticGraphManager: + def __init__(self): + # Export graph infos + + self._pre_export_graph_info = _io._PreExportGraphInfo() + self._data_accessor = None + + # Value can be either torch.onnx.TrainingMode.TRAINING or torch.onnx.TrainingMode.EVAL + # To be instantiated in the concrete implementation of GraphExecutionManager + self._export_mode = None + # Exporter can take extra arguments for ORTModule extensions + # It cannot overlap with required/immutable arguments (validated in runtime) + self._export_extra_kwargs = {} + self._exported_graph_info = _io._ExportedGraphInfo() + self._module_parameters: List[inspect.Parameter] = [] + self._exported_model: Optional[onnx.ModelProto] = None + self._args_input_schema: Optional[ORTModelInputOutputSchemaType] = None + self._kwargs_input_schema: Optional[ORTModelInputOutputSchemaType] = None + + # Pre-grad graph infos + self._finalized_graph_info = _io._FinalizedGraphInfo() + self._finalized_model: Optional[onnx.ModelProto] = None + + # self._buffers_as_onnx_graph_input: List[torch.nn.parameter.Parameter] = [] # Cache the list of free buffers, which will be used as onnx graph inputs. + # self._params_as_onnx_graph_input: List[torch.nn.parameter.Parameter] = [] # Cache the list of parameters, which will be used as onnx graph inputs. + + def use_cached_exported_model_or_reexport( + self, + args: Sequence[ORTModelInputOutputType], + kwargs: Mapping[str, ORTModelInputOutputType], + device: Optional[torch.device], + ) -> Tuple[bool, _io._PreExportGraphInfo, _io._ExportedGraphInfo]: + """Create the exported graph if it does not exist, otherwise use the cached one""" + + need_export_model = not self._exported_model + need_export_model = need_export_model or self._original_model_has_changed + + # print("args: ", args, ", kwargs: ", kwargs) + + # Check graph inputs parsed from the model's forward function signature and current inputs, + # if they are different, we need to re-export the model. + pre_export_graph_info, data_accessor = _io.parse_inputs_for_onnx_export( + self._module_parameters, args, kwargs, True, device + ) + + # print(">>>>pre_export_graph_info.onnx_graph_input_names: ", pre_export_graph_info.onnx_graph_input_names) + # print( + # ">>>>pre_export_graph_info.onnx_graph_input_names_require_grad: ", + # pre_export_graph_info.onnx_graph_input_names_require_grad, + # ) + need_export_model = ( + need_export_model + or self._pre_export_graph_info.onnx_graph_input_names != pre_export_graph_info.onnx_graph_input_names + ) + + # Maybe we should also check parameters count or size, because user could modify the parameters after the export. + # But pre_export_graph_info did not contains any parameters as its inputs. + + # Extract the schema from the args and kwargs, and compare with the cached one. + # This check ideally is not needed as we already have the above check, but it is added as a safeguard. + flatten_args, args_schema = _io._extract_schema(copy.copy(args), device) + # print("!!!!!!!!!!!!!!!!!!kwargs", kwargs) + flatten_kwargs, kwargs_schema = _io._extract_schema(copy.copy(kwargs), device) + # print("!!!!!!!!!!!!!!!!!!flatten_kwargs", flatten_kwargs) + # schema = _io._extract_schema({"args": copy.copy(args), "kwargs": copy.copy(kwargs)}, device) + need_export_model = ( + need_export_model or args_schema != self._args_input_schema or kwargs_schema != self._kwargs_input_schema + ) + + if need_export_model: + # Set the schema before exporting the model, so that we can use the schema to unflatten the inputs + # during the flatten module forward run. + self._args_input_schema = args_schema + self._kwargs_input_schema = kwargs_schema + + def _unflatten_inputs( + num_positionals, + args_schema: Optional[ORTModelInputOutputSchemaType], + kwargs_schema: Optional[ORTModelInputOutputSchemaType], + inputs: Sequence[ORTModelInputOutputType], + ): + """Unflattens the inputs into args and kwargs + + The inputs are unflattened in the order they appear in the model's forward function arguments. + + Mainly used for PyTorch run for ONNX export. + """ + restored_args = _io.unflatten_data_using_schema(inputs[:num_positionals], args_schema) + restored_kwargs = _io.unflatten_data_using_schema(inputs[num_positionals:], kwargs_schema) + + return restored_args, restored_kwargs + + self._flattened_module._unflatten_functor = partial( + _unflatten_inputs, len(flatten_args), self._args_input_schema, self._kwargs_input_schema + ) + self._flattened_module.device = device + self._exported_model, exported_graph_info = self._export_model( + self._flattened_module, pre_export_graph_info, flatten_args + flatten_kwargs, {} + ) + self._pre_export_graph_info = pre_export_graph_info + self._data_accessor = data_accessor + + self._original_model_has_changed = False + + self._exported_graph_info = exported_graph_info + + # save the ortmodule exported model + if self._debug_options.save_onnx_models.save: + _save_model( + self._exported_model, + os.path.join( + self._debug_options.save_onnx_models.path, + _get_onnx_file_name( + self._debug_options.save_onnx_models.name_prefix, "torch_exported", self._export_mode + ), + ), + ) + + return need_export_model, pre_export_graph_info, self._exported_graph_info + + def _post_process( + self, exported_model: onnx.ModelProto, exported_graph_info: _io._ExportedGraphInfo + ) -> Tuple[onnx.ModelProto, _io._FinalizedGraphInfo]: + """Post process the exported model, for example, add extra information to the model proto""" + + # Deepcopy the exported model as pre-grad model, in case modification affects the exported model. + + # TODO(): Do pre-grad graph modification as needed, for memory efficient gradient management, etc. + # Currently, we don't do any modification, so just use the exported graph as pre-grad graph. + + finalized_model = copy.deepcopy(exported_model) + + finalized_graph_info = _io._FinalizedGraphInfo() + finalized_graph_info.onnx_graph_input_names = exported_graph_info.onnx_graph_input_names + finalized_graph_info.onnx_graph_input_names_require_grad = ( + exported_graph_info.onnx_graph_input_names_require_grad + ) + + self._finalized_model = finalized_model + self._finalized_graph_info = finalized_graph_info + + return finalized_model, finalized_graph_info + + def use_cached_pre_grad_model_or_reinitialize( + self, reexported_model: bool, pre_export_graph_info: _io._PreExportGraphInfo + ) -> bool: + if self._export_mode == torch.onnx.TrainingMode.TRAINING: + # initializer_names_to_train_set_user_model = [ + # name + # for name, param in self._flattened_module.named_parameters() + # if param.requires_grad and name in self._finalized_graph_info.onnx_graph_input_names + # ] + + if reexported_model: + pass + + # If inputs requiring gradient change from forward to the next, the module_gradient_graph_builder + # needs to be reinitialized so it can compute the backward output for the new inputs that require_grad + else: + # Reinitialize graph builder if the inputs or initializers requiring gradient have changed. + # This can happen when the user changes the model parameters after the onnx export. + # Model may have unused params dropped after export, so we only check those inputs existing in onnx graph. + + onnx_graph_input_requires_grads = [] + parameter_names = {k: v for k, v in self._flattened_module.named_parameters()} + for input_name in self._exported_graph_info.onnx_graph_input_names: + if input_name in parameter_names and parameter_names[input_name].requires_grad: + onnx_graph_input_requires_grads.append(input_name) + else: + # If not in parameter list, then it would come from user defined inputs. + if input_name in pre_export_graph_info.onnx_graph_input_names_require_grad: + onnx_graph_input_requires_grads.append(input_name) + + # print("onnx_graph_input_requires_grads: ", onnx_graph_input_requires_grads) + + if onnx_graph_input_requires_grads != self._exported_graph_info.onnx_graph_input_names_require_grad: + self._exported_graph_info.onnx_graph_input_names_require_grad = onnx_graph_input_requires_grads + else: + return False + + # print( + # "111111onnx_graph_input_names_require_grad: ", + # self._exported_graph_info.onnx_graph_input_names_require_grad, + # ) + self._finalized_model, self._finalized_graph_info = self._post_process( + self._exported_model, self._exported_graph_info + ) + else: + if not reexported_model: + return False + + self._finalized_model = self._exported_model + self._finalized_graph_info = _io._FinalizedGraphInfo() + self._finalized_graph_info.onnx_graph_input_names = self._exported_graph_info.onnx_graph_input_names + self._finalized_graph_info.onnx_graph_input_names_require_grad = ( + self._exported_graph_info.onnx_graph_input_names_require_grad + ) + + self._initializer_input_buffers_for_ort() + + print( + "_finalized_graph_info.onnx_graph_input_names_require_grad: ", + self._finalized_graph_info.onnx_graph_input_names_require_grad, + ) + print("_finalized_graph_info.onnx_graph_input_names: ", self._finalized_graph_info.onnx_graph_input_names) + print( + "o_finalized_graph_info.nnx_graph_input_names_user_defined: ", + self._finalized_graph_info._onnx_graph_input_names_user_defined, + ) + print( + "_finalized_graph_info.onnx_graph_input_names_require_grad_user_defined: ", + self._finalized_graph_info._onnx_graph_input_names_require_grad_user_defined, + ) + + self._initialize_graph_builder() + + return True + + def _initializer_input_buffers_for_ort(self): + parameter_names = {k: v for k, v in self._flattened_module.named_parameters()} + buffer_names = {k: v for k, v in self._flattened_module.named_buffers()} + for input_name in self._finalized_graph_info.onnx_graph_input_names: + if input_name in parameter_names: + self._finalized_graph_info._buffer_for_ort_runs[input_name] = parameter_names[input_name] + elif input_name in buffer_names: + self._finalized_graph_info._buffer_for_ort_runs[input_name] = buffer_names[input_name] + else: + self._finalized_graph_info._buffer_for_ort_runs[input_name] = None + # print(f"append new input_name into _onnx_graph_input_names_user_defined: {input_name}") + self._finalized_graph_info._onnx_graph_input_names_user_defined.append(input_name) + + if input_name in self._exported_graph_info.onnx_graph_input_names_require_grad: + self._finalized_graph_info._onnx_graph_input_names_require_grad_user_defined.append(input_name) + + # For user inputs, we will fill them dynamically during the forward run. + + # def flatten_inputs(self, args: Sequence[ORTModelInputOutputType], kwargs: Mapping[str, ORTModelInputOutputType], device): + # """Flattens the inputs and kwargs into a single tuple of inputs + + # The inputs are flattened in the order they appear in the model's forward function signature + # """ + # # Drop the schema directly, since we would assume self.args_input_schema and self.kwargs_input_schema are + # # always consistent with the model's forward function signature. + # flatten_args, args_schema = _io._extract_schema(args, device) + # flatten_kwargs, = _io._extract_schema(kwargs, device) + # self._num_positionals = len(flatten_args) + # return flatten_args + flatten_kwargs + + # def unflatten_inputs(self, inputs: Sequence[ORTModelInputOutputType]): + # """Unflattens the inputs into args and kwargs + + # The inputs are unflattened in the order they appear in the model's forward function arguments. + + # Mainly used for PyTorch run for ONNX export. + # """ + # restored_args = _io.unflatten_data_using_schema(inputs[: self._num_positionals], self._args_input_schema) + # restored_kwargs = _io.unflatten_data_using_schema(inputs[self._num_positionals :], self._kwargs_input_schema) + + # return restored_args, restored_kwargs + + def construct_inputs( + self, + inputs: Sequence[ORTModelInputOutputType], + kwargs: Mapping[str, ORTModelInputOutputType], + constant_as_tensor: bool, + device: torch.device, + ): + """Constructs the inputs for the forward method + + The inputs are constructed in the order they appear in the model's forward function signature + """ + # print("construct_inputs>>>>>", inputs, kwargs) + for name in self._finalized_graph_info._onnx_graph_input_names_user_defined: + if name in self._data_accessor: + assert name in self._finalized_graph_info._buffer_for_ort_runs, f"{name} is not in _buffer_for_ort_runs" + data = self._data_accessor[name](inputs, kwargs) + # print("data.requires_grad: ", data.requires_grad) + if PrimitiveType.is_primitive_type(data) and constant_as_tensor: + data = PrimitiveType.get_tensor(data, device) + self._finalized_graph_info._buffer_for_ort_runs[name] = data + else: + raise wrap_exception( + ORTModuleONNXModelException, + RuntimeError(f"Input is present in ONNX graph but not provided: {name}."), + ) + + # print("name of buffers: ", self._finalized_graph_info._buffer_for_ort_runs.keys()) + # print("name of onnx graph inputs: ", self._finalized_graph_info.onnx_graph_input_names) + + return self._finalized_graph_info._buffer_for_ort_runs + + +class GraphExecutionManager(GraphExecutionInterface, StaticGraphManager): def __init__( self, module: _FlattenedModule, @@ -62,6 +370,7 @@ def __init__( """Manages construction and execution of ONNX graphs""" super().__init__(module._original_module) + super(GraphExecutionInterface, self).__init__() # IMPORTANT: Debug and Fallback must the configured first self._debug_options = debug_options @@ -81,9 +390,9 @@ def __init__( # Model after inference optimization or gradient building. self._graph_builder = None self._graph_info = None - self._graph_initializer_names = set() - self._graph_initializer_names_to_train = set() - self._graph_initializers: List[torch.nn.parameter.Parameter] = [] + # self._graph_initializer_names = set() + # self._graph_initializer_names_to_train = set() + # self._graph_initializers: List[torch.nn.parameter.Parameter] = [] # TrainingAgent or InferenceAgent self._execution_agent = None @@ -97,16 +406,8 @@ def __init__( # Tracker for ORTModule model export, session creation overhead. self.time_tracker = _logger.TimeTracker() - # Value can be either torch.onnx.TrainingMode.TRAINING or torch.onnx.TrainingMode.EVAL - # To be instantiated in the concrete implementation of GraphExecutionManager - self._export_mode = None - - # Exporter can take extra arguments for ORTModule extensions - # It cannot overlap with required/immutable arguments (validated in runtime) - self._export_extra_kwargs = {} - # Input and output infos (including schema) for exported model. - self._input_info: Optional[_InputInfo] = None + # self._input_info: Optional[_InputInfo] = None self._module_output_schema: Optional[ORTModelInputOutputSchemaType] = None # Device where the model is placed. @@ -187,7 +488,8 @@ def forward(self): def _build_graph(self, config): if self._runtime_options.use_static_shape: - self._graph_builder.build(config, self._input_info.shape) + # (TODO): add the shape for the onnx graph inputs. + self._graph_builder.build(config) # , self._input_info.shape) else: self._graph_builder.build(config) @@ -264,7 +566,13 @@ def _get_session_config(self): @_logger.TrackTime(_logger.ORTModuleInitPhase.EXPORT) @_logger.SuppressLogs(_logger.ORTModuleInitPhase.EXPORT, is_ort_filter=False) - def _export_model(self, *inputs, **kwargs) -> bool: + def _export_model( + self, + flattened_module: torch.nn.Module, + pre_export_graph_info: _io._PreExportGraphInfo, + args: Sequence[ORTModelInputOutputType], + kwargs: Mapping[str, ORTModelInputOutputType], + ) -> Tuple[onnx.ModelProto, _io._ExportedGraphInfo]: # 1. Set the self._device from the user module # 2. Verify input schema matches the schema used on the previous model export # 3. Export the user model under self._export_training_flag mode @@ -280,44 +588,60 @@ def _export_model(self, *inputs, **kwargs) -> bool: # e.g., some sympy functions in symbolic_shape_infer will change Python's random state. random_states = _utils.get_random_states() - schema = _io._extract_schema({"args": copy.copy(inputs), "kwargs": copy.copy(kwargs)}, self._device) - if ( - self._onnx_models.exported_model - and schema == self._input_info.schema - and not self._original_model_has_changed - ): - # All required models have already been exported previously - return False - self._set_device_from_module(inputs, kwargs) + # schema = _io._extract_schema({"args": copy.copy(inputs), "kwargs": copy.copy(kwargs)}, self._device) + # if ( + # self._onnx_models.exported_model + # and schema == self._input_info.schema + # and not self._original_model_has_changed + # ): + # # All required models have already been exported previously + # return False + self._set_device_from_module(args, kwargs) from onnxruntime.training.utils.hooks._subscriber_manager import no_increase_global_step with no_increase_global_step(): - self._onnx_models.exported_model = self._get_exported_model(schema, *inputs, **kwargs) - if self._debug_options.save_onnx_models.save: - self._onnx_models.save_exported_model( - self._debug_options.save_onnx_models.path, - self._debug_options.save_onnx_models.name_prefix, - self._export_mode, - ) + exported_model = self._get_exported_model(flattened_module, pre_export_graph_info, args, kwargs) + + exported_graph_info = _io._ExportedGraphInfo() + exported_graph_info.onnx_graph_input_names = [input.name for input in exported_model.graph.input] + parameter_names = [name for name, _ in flattened_module.named_parameters()] + exported_graph_info.onnx_graph_input_names_require_grad = [ + input.name + for input in exported_model.graph.input + if input.name in parameter_names or input.name in pre_export_graph_info.onnx_graph_input_names_require_grad + ] + # if self._debug_options.save_onnx_models.save: + # self._onnx_models.save_exported_model( + # self._debug_options.save_onnx_models.path, + # self._debug_options.save_onnx_models.name_prefix, + # self._export_mode, + # ) if self._runtime_options.run_symbolic_shape_infer: - self._onnx_models.exported_model = SymbolicShapeInference.infer_shapes( - self._onnx_models.exported_model, auto_merge=True, guess_output_rank=True + exported_model = SymbolicShapeInference.infer_shapes( + exported_model, auto_merge=True, guess_output_rank=True ) # Restore the recorded random states _utils.set_random_states(random_states) - return True + return exported_model, exported_graph_info - def _get_exported_model(self, input_schema: ORTModelInputOutputSchemaType, *inputs, **kwargs) -> onnx.ModelProto: + def _get_exported_model( + self, + flattened_module: torch.nn.Module, + pre_export_graph_info: _io._PreExportGraphInfo, + args: Sequence[ORTModelInputOutputType], + kwargs: Mapping[str, ORTModelInputOutputType], + ) -> onnx.ModelProto: """Exports PyTorch `self._flattened_module` to ONNX for inferencing or training, using `*inputs` and `**kwargs` as input TODO: How to support dynamic axes? Dimensions are determined by samples """ - + # kwargs = {} + # inputs = flatten_args + flatten_kwargs # VERBOSE -> FULL export verbose log + FULL torch other logs from stdout and stderr (C++ backend) # DEVINFO -> FULL export verbose log + FULL torch other logs from stdout and stderr (C++ backend) # INFO -> [Rank 0] FULL export verbose log + FILTERED torch other logs from stdout and stderr (C++ backend) @@ -326,9 +650,9 @@ def _get_exported_model(self, input_schema: ORTModelInputOutputSchemaType, *inpu torch_exporter_verbose_log = self._debug_options.logging.log_level <= LogLevel.INFO # Setup dynamic axes for onnx model - self._input_info = _io.parse_inputs_for_onnx_export(self._module_parameters, None, input_schema, inputs, kwargs) + # self._input_info = _io.parse_inputs_for_onnx_export(self._module_parameters, None, input_schema, inputs, kwargs) need_deep_copy = self._runtime_options.deepcopy_before_model_export and _io.can_module_be_deep_cloned( - self._original_module, self._device + flattened_module, self._device ) if not need_deep_copy: if self._runtime_options.deepcopy_before_model_export: @@ -347,15 +671,17 @@ def _get_exported_model(self, input_schema: ORTModelInputOutputSchemaType, *inpu ) ( output_names, - output_dynamic_axes, + dynamic_axes, self._module_output_schema, ) = _io.parse_outputs_for_onnx_export_and_extract_schema( - self._original_module, inputs, kwargs, self._logger, self._device, need_deep_copy + flattened_module, args, kwargs, self._logger, self._device, need_deep_copy ) - self._input_info.dynamic_axes.update(output_dynamic_axes) + # self._input_info.dynamic_axes.update(output_dynamic_axes) + # Combine the dymaic axes from inputs and outputs + dynamic_axes.update(pre_export_graph_info.onnx_graph_input_dynamic_axes_map) # FlattenedModule needs _InputInfo to expand user input from *args to *args + **kwargs - self._flattened_module._input_info = self._input_info + # self._flattened_module._input_info = self._input_info self._logger.info("Exporting the PyTorch model to ONNX...") @@ -363,7 +689,7 @@ def _get_exported_model(self, input_schema: ORTModelInputOutputSchemaType, *inpu cache_dir = self._runtime_options.ortmodule_cache_dir if cache_dir: filename = os.path.join( - cache_dir, f"{hash_fn(str(self._flattened_module).encode()).hexdigest()}_{get_rank()}.onnx" + cache_dir, f"{hash_fn(str(flattened_module).encode()).hexdigest()}_{get_rank()}.onnx" ) if os.path.exists(cache_dir) and os.path.isfile(filename): self._logger.warning( @@ -375,27 +701,42 @@ def _get_exported_model(self, input_schema: ORTModelInputOutputSchemaType, *inpu # Export torch.nn.Module to ONNX f = io.BytesIO() - + print("pre_export_graph_info.onnx_graph_input_names: ", pre_export_graph_info.onnx_graph_input_names) + print( + "pre_export_graph_info.onnx_graph_input_names_require_grad: ", + pre_export_graph_info.onnx_graph_input_names_require_grad, + ) # Deepcopy inputs, since input values may change after model run. # NOTE: Inputs may contain tensors that have attributes preventing their deepcopy (example grad_fn). # Therefore, deepcopy only the data component of the input tensors for export. - sample_inputs_copy, sample_kwargs_copy = _io.deepcopy_model_input(*inputs, **kwargs) + + sample_inputs_copy, sample_kwargs_copy = _io.deepcopy_model_input(*args, **kwargs) + assert len(sample_kwargs_copy) == 0, "Currently, kwargs are not supported for ONNX export." + sample_inputs_as_tuple = sample_inputs_copy # NOTE: Flattening the input will change the 'input schema', resulting in a re-export - sample_inputs_as_tuple = tuple(self._input_info.flatten(sample_inputs_copy, sample_kwargs_copy, self._device)) + # sample_inputs_as_tuple = tuple(self._input_info.flatten(sample_inputs_copy, sample_kwargs_copy, self._device)) # Ops behaving differently under train/eval mode need to be exported with the # correct training flag to reflect the expected behavior. # For example, the Dropout node in a model is dropped under eval mode. assert self._export_mode is not None, "Please use a concrete instance of ExecutionManager" + # print("sample_inputs_as_tuple: ", [v.shape for v in sample_inputs_as_tuple]) + # print("sample_inputs_as_tuple: ", [v.dtype for v in sample_inputs_as_tuple]) + # print("pre_export_graph_info.onnx_graph_input_names: ", pre_export_graph_info.onnx_graph_input_names) + # print( + # "pre_export_graph_info.onnx_graph_input_dynamic_axes_map: ", + # pre_export_graph_info.onnx_graph_input_dynamic_axes_map, + # ) + try: with torch.no_grad(), stage3_export_context(self._runtime_options.enable_zero_stage3_support, self): required_export_kwargs = { - "input_names": self._input_info.names, + "input_names": pre_export_graph_info.onnx_graph_input_names, # did not contains paramerter as its input yet "output_names": output_names, "opset_version": self._runtime_options.onnx_opset_version, "do_constant_folding": False, "training": self._export_mode, - "dynamic_axes": self._input_info.dynamic_axes, + "dynamic_axes": dynamic_axes, "verbose": torch_exporter_verbose_log, "export_params": False, "keep_initializers_as_inputs": True, @@ -416,7 +757,7 @@ def _get_exported_model(self, input_schema: ORTModelInputOutputSchemaType, *inpu raise RuntimeError(error_msg) torch.onnx.export( - self._flattened_module, + flattened_module, sample_inputs_as_tuple, f, **required_export_kwargs, @@ -443,7 +784,7 @@ def _get_exported_model(self, input_schema: ORTModelInputOutputSchemaType, *inpu exported_model = post_processing_enable_zero_stage3_compat( exported_model, self._zero_stage3_param_map, - [name for name, _ in self._flattened_module.named_parameters()], + [name for name, _ in flattened_module.named_parameters()], ) # Cannot append pull weight trigger name to input names as following, otherwise, the later check ( @@ -456,7 +797,7 @@ def _get_exported_model(self, input_schema: ORTModelInputOutputSchemaType, *inpu if not os.path.exists(cache_dir): os.makedirs(cache_dir, exist_ok=True) filename = os.path.join( - cache_dir, f"{hash_fn(str(self._flattened_module).encode()).hexdigest()}_{get_rank()}.onnx" + cache_dir, f"{hash_fn(str(flattened_module).encode()).hexdigest()}_{get_rank()}.onnx" ) self._logger.info(f"Caching model for future runs to {filename}.") onnx.save(exported_model, filename) @@ -498,24 +839,27 @@ def _initialize_graph_builder(self): # All initializer names along with user inputs are a part of the onnx graph inputs # since the onnx model was exported with the flag keep_initializers_as_inputs=True - onnx_initializer_names = {p.name for p in self._onnx_models.exported_model.graph.input} + # onnx_initializer_names = {p.name for p in self._onnx_models.exported_model.graph.input} - # TODO: PyTorch exporter bug: changes the initializer order in ONNX model - initializer_names = [ - name for name, _ in self._flattened_module.named_parameters() if name in onnx_initializer_names - ] - initializer_names_to_train = [ - name - for name, param in self._flattened_module.named_parameters() - if param.requires_grad and name in onnx_initializer_names - ] + # onnx_initializer_names = self._finalized_graph_info.onnx_graph_input_names + + # # TODO: PyTorch exporter bug: changes the initializer order in ONNX model + # initializer_names = [ + # name for name, _ in self._flattened_module.named_parameters() if name in onnx_initializer_names + # ] + # initializer_names_to_train = [ + # name + # for name, param in self._flattened_module.named_parameters() + # if param.requires_grad and name in onnx_initializer_names + # ] # Build and optimize the full graph grad_builder_config = C.OrtModuleGraphBuilderConfiguration() - grad_builder_config.initializer_names = initializer_names - grad_builder_config.initializer_names_to_train = initializer_names_to_train + grad_builder_config.initializer_names = self._finalized_graph_info.onnx_graph_input_names + grad_builder_config.initializer_names_to_train = self._finalized_graph_info.onnx_graph_input_names_require_grad - input_names_require_grad = self._input_info.require_grad_names + # input_names_require_grad = self._input_info.require_grad_names + input_names_require_grad = self._finalized_graph_info._onnx_graph_input_names_require_grad_user_defined if self._runtime_options.enable_zero_stage3_support: from ._zero_stage3_compatibility import STAGE3_PULL_WEIGHT_TRIGGER_NAME @@ -532,18 +876,18 @@ def _initialize_graph_builder(self): # It is assumed here that the order and names of the inputs and outputs are not modified by the backend in any way # and are kept as they appear in the exported onnx model. - self._graph_builder.initialize(self._onnx_models.exported_model.SerializeToString(), grad_builder_config) + self._graph_builder.initialize(self._finalized_model.SerializeToString(), grad_builder_config) # TODO: Explore ways to make self._graph_info.initializer_names and self._graph_info.initializer_names_to_train # a set (unordered_set in the backend) that does not require a copy on each reference. - self._graph_initializer_names = set(initializer_names) - self._graph_initializer_names_to_train = set(initializer_names_to_train) + # self._graph_initializer_names = set(initializer_names) + # self._graph_initializer_names_to_train = set(initializer_names_to_train) # Initializers can be cached and used since they are expected not to be re-instantiated # between forward calls. - self._graph_initializers = [ - param for name, param in self._flattened_module.named_parameters() if name in self._graph_initializer_names - ] + # self._graph_initializers = [ + # param for name, param in self._flattened_module.named_parameters() if name in self._graph_initializer_names + # ] def signal_model_changed(self): """Signals the execution manager to re-export the model on the next forward call""" @@ -588,7 +932,7 @@ def _enable_conditional_optimizations( # Enable data sparsity inspection if sparse optimizer is ON or user wants to print input density. if self._runtime_options.enable_sparse_optimizer or self._runtime_options.print_input_density: self._runtime_inspector.enable_input_inspector( - self._onnx_models.exported_model, self._graph_builder.get_graph_info().user_input_names + self._exported_model, self._pre_export_graph_info.onnx_graph_input_names ) if self._runtime_options.enable_sparse_optimizer: @@ -599,17 +943,26 @@ def _enable_conditional_optimizations( if self._runtime_options.enable_zero_stage3_support: self._append_pull_weight_trigger_as_input(kwargs, detected_device) - _, embed_sparsity_results, label_sparsity_results = _io._combine_input_buffers_initializers( - self._graph_initializers, - self._graph_builder.get_graph_info().user_input_names, - self._input_info, - self._flattened_module.named_buffers(), - inputs, - kwargs, - detected_device, - self._runtime_inspector, - self._zero_stage3_param_map, - ) + prepared_input_map = self.construct_inputs(inputs, kwargs, True, self._device) + + embed_sparsity_results = OrderedDict() + label_sparsity_results = OrderedDict() + + for name, inp in prepared_input_map.items(): + found, embedding_density, label_density = self._runtime_inspector.inspect_input(name, inp) + if found: + if embedding_density < 100: + embed_sparsity_results[name] = embedding_density + if label_density < 100: + label_sparsity_results[name] = label_density + if ( + self._runtime_inspector.memory_ob.is_enabled() + and not self._runtime_inspector.memory_ob.symbolic_dim_collecting_completed + ): + self._runtime_inspector.memory_ob.collect_symbolic_dim_values( + self._finalized_graph_info.onnx_graph_input_dynamic_axes_map, prepared_input_map + ) + self._runtime_inspector.memory_ob.symbolic_dim_collecting_completed = True # Enable sparsity-based optimization when applicable. if len(label_sparsity_results) > 0: diff --git a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py index 6690af9b71bf1..93481f5ecdb03 100644 --- a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py @@ -11,7 +11,7 @@ from onnxruntime.capi import _pybind_state as C -from . import _are_deterministic_algorithms_enabled, _io, _use_deterministic_algorithms, _utils +from . import _are_deterministic_algorithms_enabled, _use_deterministic_algorithms, _utils from ._execution_agent import InferenceAgent from ._fallback import ORTModuleFallbackException, _FallbackManager, _FallbackPolicy from ._graph_execution_manager import GraphExecutionManager, _RunStateInfo @@ -110,15 +110,20 @@ def forward(self, *inputs, **kwargs): build_graph = False if ( self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT) is False - or not self._onnx_models.exported_model + or not self._exported_model ): self.time_tracker.start(ORTModuleInitPhase.EndToEnd) # Exporting module to ONNX for the first time - build_graph = self._export_model(*inputs, **kwargs) - if build_graph: - # If model was exported, then initialize the graph builder. - self._initialize_graph_builder() + build_graph, pre_grad_graph_info, exported_graph_info = self.use_cached_exported_model_or_reexport( + inputs, kwargs, self._device + ) + + build_graph = self.use_cached_pre_grad_model_or_reinitialize(build_graph, pre_grad_graph_info) + # build_graph = self._export_model(*inputs, **kwargs) + # if build_graph: + # # If model was exported, then initialize the graph builder. + # self._initialize_graph_builder() # Build the inference graph if build_graph: @@ -162,23 +167,22 @@ def forward(self, *inputs, **kwargs): if self._runtime_options.enable_zero_stage3_support: self._append_pull_weight_trigger_as_input(kwargs, self._device) - prepared_input_list, _, _ = _io._combine_input_buffers_initializers( - self._graph_initializers, - self._graph_info.user_input_names, - self._input_info, - self._flattened_module.named_buffers(), - inputs, - kwargs, - self._device, - self._runtime_inspector, - self._zero_stage3_param_map, - ) + prepared_input_map = self.construct_inputs(inputs, kwargs, True, self._device) + + if ( + self._runtime_inspector.memory_ob.is_enabled() + and not self._runtime_inspector.memory_ob.symbolic_dim_collecting_completed + ): + self._runtime_inspector.memory_ob.collect_symbolic_dim_values( + self._finalized_graph_info.onnx_graph_input_dynamic_axes_map, prepared_input_map + ) + self._runtime_inspector.memory_ob.symbolic_dim_collecting_completed = True user_outputs, _ = InferenceManager.execution_session_run_forward( self._execution_agent, self._onnx_models.optimized_model, self._device, - *prepared_input_list, + *prepared_input_map.values(), ) if ( @@ -190,6 +194,9 @@ def forward(self, *inputs, **kwargs): self._execution_agent._inference_session, False, self._runtime_options.tuning_results_path ) + # print("user_outputs: ", user_outputs) + # print("self._module_output_schema: ", self._module_output_schema) + return unflatten_user_output(self._module_output_schema, user_outputs) except ORTModuleFallbackException as e: # Exceptions subject to fallback are handled here diff --git a/orttraining/orttraining/python/training/ortmodule/_io.py b/orttraining/orttraining/python/training/ortmodule/_io.py index 7534cc46a21f1..68ae87ddf7526 100644 --- a/orttraining/orttraining/python/training/ortmodule/_io.py +++ b/orttraining/orttraining/python/training/ortmodule/_io.py @@ -6,11 +6,12 @@ import copy import gc import inspect +import warnings from collections import OrderedDict, abc +from functools import partial from logging import Logger -from typing import Dict, Iterator, List, Mapping, Optional, Sequence, Tuple +from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple -import onnx import torch from onnxruntime.training.utils import ( @@ -20,9 +21,9 @@ extract_data_and_schema, unflatten_data_using_schema, ) +from onnxruntime.training.utils.torch_io_helper import _TensorStub -from ._fallback import ORTModuleIOError, ORTModuleONNXModelException, wrap_exception -from ._runtime_inspector import RuntimeInspector +from ._fallback import ORTModuleIOError, wrap_exception class _OutputIdentityOp(torch.autograd.Function): @@ -76,6 +77,69 @@ def symbolic(g, self): return g.op("Identity", self) +class _PreExportGraphInfo: + def __init__(self): + # Input names parsed and then flatten from the model's forward function signature + self.onnx_graph_input_names: List[str] = [] + + # A subset of onnx_graph_input_names. + # Input names that require gradient parsed and then flatten from the model's forward function signature + # This should contains ONLY the user input names + self.onnx_graph_input_names_require_grad: List[str] = [] + + # Create symbolic names for each dimension of the graph input (e.g. onnx_graph_input_names). + # The key is the input name, the value is a dict of {dim_index: symbolic_dim_name} + # e.g. {"input1": {0: "input1_dim0", 1: "input1_dim1"}, "input2": {0: "input2_dim0"}} + self.onnx_graph_input_dynamic_axes_map: Dict[str, Dict[int, str]] = {} + + self.onnx_graph_input_shapes: List[List[int]] = [] + + +class _ExportedGraphInfo: + def __init__(self): + # Input names parsed and then flatten from the model's forward function signature + buffers + parameters (since we use + # keep_initializers_as_inputs=True for model export) + self.onnx_graph_input_names: List[str] = [] + + # A subset of onnx_graph_input_names. + # Input names that require gradient parsed and then flatten from the model's forward function signature + # This should contains both the user input names, the buffer names, and the parameter names (since we use + # keep_initializers_as_inputs=True for model export) + self.onnx_graph_input_names_require_grad: List[str] = [] + + def need_reexport(self, __value: object) -> bool: + assert isinstance( + __value, _ExportedGraphInfo + ), f"__value must be an instance of _ExportedGraphInfo, but got {type(__value)}" + + return self.onnx_graph_input_names != __value.onnx_graph_input_names + + +class _FinalizedGraphInfo: + def __init__(self): + # Input names for the pre-gradient-build graph. + # This may be different with the one in ExportedGraph since we may modify the graph inputs as needed + # for example when memory efficient gradient management is enabled. + self.onnx_graph_input_names: List[str] = [] + + # A subset of onnx_graph_input_names. + # Input names that require gradients for the pre-gradient-build graph. + self.onnx_graph_input_names_require_grad: List[str] = [] + + # Create symbolic names for each dimension of the graph input (e.g. onnx_graph_input_names). + # The key is the input name, the value is a dict of {dim_index: symbolic_dim_name} + # e.g. {"input1": {0: "input1_dim0", 1: "input1_dim1"}, "input2": {0: "input2_dim0"}} + self.onnx_graph_input_dynamic_axes_map: Dict[str, Dict[int, str]] = {} + + # self._graph_initializers: List[torch.nn.parameter.Parameter] = [] + + self._buffer_for_ort_runs: Dict[str, torch.Tensor] = OrderedDict() + self._onnx_graph_input_names_user_defined = [] # The ONNX graph input names excluding the parameters, buffers. + self._onnx_graph_input_names_require_grad_user_defined = ( + [] + ) # The ONNX graph input names excluding the parameters, buffers. + + def flatten_kwargs(kwargs, device): def _flatten_kwargs(value, name): if PrimitiveType.is_primitive_type(value): @@ -159,121 +223,6 @@ def unflatten( return args, kwargs -def _combine_input_buffers_initializers( - params: List[torch.nn.parameter.Parameter], - onnx_input_names: List[str], - input_info: Optional[_InputInfo], - named_buffer: Iterator[Tuple[str, torch.Tensor]], - inputs: Sequence[ORTModelInputOutputType], - kwargs: Mapping[str, ORTModelInputOutputType], - device: torch.device, - rt_inspector: RuntimeInspector, - zero_stage3_offload_param_map: Optional[Dict[str, torch.nn.parameter.Parameter]], -): - """Creates forward `*inputs` list from user input and PyTorch initializers - - ONNX Runtime forward requires an ordered list of: - * User input: computed from forward InferenceSession - * Initializers: computed from original PyTorch model parameters. - """ - - def _expand_inputs(current_input, non_none_inputs, name=""): - # The exporter handles input lists by expanding them so that each - # element of the list is its own input. - # ORTModule must match this behavior by also expanding the inputs. - if current_input is None or isinstance(current_input, str): - # Drop all None and string inputs - return - if isinstance(current_input, abc.Sequence): - # If the input is a sequence (like a list), expand the list so that - # each element of the list is an input by itself - for i, inp in enumerate(current_input): - _expand_inputs(inp, non_none_inputs, f"{name}_{i}" if name else str(i)) - elif isinstance(current_input, abc.Mapping): - # If the input is a mapping (like a dict), expand the dict so that - # each element of the dict is an input by itself - for key, val in current_input.items(): - _expand_inputs(val, non_none_inputs, f"{name}_{key}" if name else key) - else: - # else just collect all the non none inputs within non_none_inputs - if isinstance(non_none_inputs, abc.Sequence): - non_none_inputs.append(current_input) - elif isinstance(non_none_inputs, abc.Mapping): - non_none_inputs[name] = current_input - - # User inputs - non_none_inputs = [] - _expand_inputs(inputs, non_none_inputs) - flattened_kwargs_inputs = {} - _expand_inputs(kwargs, flattened_kwargs_inputs) - buffer_names_dict = None - result = [] - embed_sparsity_results = OrderedDict() - label_sparsity_results = OrderedDict() - onnx_input_to_value_map = OrderedDict() - - for input_idx, name in enumerate(onnx_input_names): - inp = None - if name in flattened_kwargs_inputs and flattened_kwargs_inputs[name] is not None: - # Only use keywords coming from user that are expected by ONNX model - inp = flattened_kwargs_inputs[name] - - if inp is None: - try: - # Only use positionals coming from user that are expected by ONNX model - # if input_idx >= len(input_info.names), IndexError will be thrown - if name != input_info.names[input_idx]: - # When ONNX drops unused inputs, get correct index from user input - # if name is not in input_info.names, ValueError will be thrown - input_idx = input_info.names.index(name) # noqa: PLW2901 - inp = non_none_inputs[input_idx] - except (IndexError, ValueError): - # ONNX input name is not present in input_info.names. - pass - - if inp is None: - # Registered buffers are translated to user_input+initializer in ONNX - if buffer_names_dict is None: - buffer_names_dict = {buffer_name: i for buffer_name, i in named_buffer} - try: # noqa: SIM105 - inp = buffer_names_dict[name] - except KeyError: - # ONNX input name is not present in the registered buffer dict. - pass - - if inp is not None: - if PrimitiveType.is_primitive_type(inp): - inp = PrimitiveType.get_tensor(inp, device) - - found, embedding_density, label_density = rt_inspector.inspect_input(name, inp) - if found: - if embedding_density < 100: - embed_sparsity_results[name] = embedding_density - if label_density < 100: - label_sparsity_results[name] = label_density - result.append(inp) - - onnx_input_to_value_map[name] = inp - else: - raise wrap_exception( - ORTModuleONNXModelException, RuntimeError(f"Input is present in ONNX graph but not provided: {name}.") - ) - - # params is a list of all initializers known to the onnx graph - if zero_stage3_offload_param_map: - for p in params: - if p not in zero_stage3_offload_param_map.values(): - result.append(p) - else: - result.extend(params) - - if rt_inspector.memory_ob.is_enabled() and not rt_inspector.memory_ob.symbolic_dim_collecting_completed: - rt_inspector.memory_ob.collect_symbolic_dim_values(input_info.dynamic_axes, onnx_input_to_value_map) - rt_inspector.memory_ob.symbolic_dim_collecting_completed = True - - return result, embed_sparsity_results, label_sparsity_results - - def deepcopy_model_input( *args, **kwargs ) -> Tuple[Sequence[ORTModelInputOutputType], Mapping[str, ORTModelInputOutputType]]: @@ -299,6 +248,9 @@ def extract_tensor(value): def unflatten_user_output(output_schema: Optional[ORTModelInputOutputSchemaType], outputs: List[torch.Tensor]): try: + # Need to distinguish between a single output and a tuple (having a single tensor) + if len(outputs) == 1 and output_schema is _TensorStub: + return outputs[0] return unflatten_data_using_schema(outputs, output_schema) except TypeError as e: raise wrap_exception( @@ -307,10 +259,12 @@ def unflatten_user_output(output_schema: Optional[ORTModelInputOutputSchemaType] ) from None -def _extract_schema(data: ORTModelInputOutputType, device) -> ORTModelInputOutputSchemaType: +def _extract_schema( + data: ORTModelInputOutputType, device +) -> Tuple[Sequence[ORTModelInputOutputType], ORTModelInputOutputSchemaType]: try: - _, schema = extract_data_and_schema(data, constant_as_tensor=True, device=device) - return schema + flatten_data, schema = extract_data_and_schema(data, constant_as_tensor=True, device=device) + return flatten_data, schema except TypeError as e: raise wrap_exception(ORTModuleIOError, TypeError(f"ORTModule fails to extract schema from data: {e}")) from None @@ -356,31 +310,31 @@ def _populate_output_names_and_dynamic_axes( return output_names, output_dynamic_axes -def _transform_output_to_flat_tuple(data): - """Converts the data to a flat tuple by iterating over the entire data structure""" +# def _transform_output_to_flat_tuple(data): +# """Converts the data to a flat tuple by iterating over the entire data structure""" - def _flatten_data(data, flat_data): - # Recursively traverse over the data and populate the flat_data with torch.Tensors +# def _flatten_data(data, flat_data): +# # Recursively traverse over the data and populate the flat_data with torch.Tensors - if data is None: - return - elif isinstance(data, torch.Tensor): - identity = _OutputIdentityOp.apply - flat_data.append(identity(data)) - elif isinstance(data, abc.Sequence): - for value in data: - _flatten_data(value, flat_data) - elif isinstance(data, abc.Mapping): - for _, value in sorted(data.items()): - _flatten_data(value, flat_data) - else: - raise wrap_exception( - ORTModuleIOError, TypeError(f"ORTModule does not support the following data type {type(data)}.") - ) +# if data is None: +# return +# elif isinstance(data, torch.Tensor): +# identity = _OutputIdentityOp.apply +# flat_data.append(identity(data)) +# elif isinstance(data, abc.Sequence): +# for value in data: +# _flatten_data(value, flat_data) +# elif isinstance(data, abc.Mapping): +# for _, value in sorted(data.items()): +# _flatten_data(value, flat_data) +# else: +# raise wrap_exception( +# ORTModuleIOError, TypeError(f"ORTModule does not support the following data type {type(data)}.") +# ) - flat_data = [] - _flatten_data(data, flat_data) - return tuple(flat_data) +# flat_data = [] +# _flatten_data(data, flat_data) +# return tuple(flat_data) class _FlattenedModule(torch.nn.Module): @@ -390,20 +344,41 @@ def __init__(self, original_module: torch.nn.Module): # Before `forward` is called, _ort_module must be assigned # Updated input info is needed to expand args into *args, **kwargs - self._input_info: Optional[_InputInfo] = None + # self._input_info: Optional[_InputInfo] = None + self._unflatten_functor: Optional[Callable] = None + + self.device: Optional[torch.device] = None + + self._output_schema: Optional[ORTModelInputOutputSchemaType] = None def forward(self, *args): - new_args, new_kwargs = self._input_info.unflatten(args) - return _transform_output_to_flat_tuple(self._original_module(*new_args, **new_kwargs)) + new_args, new_kwargs = self._unflatten_functor(args) + + # print("unflatten args: ", [v.shape for v in new_args]) + # print("unflatten kwargs: ", {k: v.shape for k, v in new_kwargs.items()}) + + original_outputs = self._original_module(*new_args, **new_kwargs) + + # Flatten the outputs + flatten_outputs, self._output_schema = _extract_schema(original_outputs, self.device) + + # Append _OutputIdentityOp to the outputs to support passthrough outputs + final_flatten_outputs = [] + for output in flatten_outputs: + final_flatten_outputs.append(_OutputIdentityOp.apply(output)) + + return final_flatten_outputs def parse_inputs_for_onnx_export( all_input_parameters: List[inspect.Parameter], - onnx_graph: Optional[onnx.ModelProto], - schema: ORTModelInputOutputSchemaType, + # onnx_graph: Optional[onnx.ModelProto], + # schema: ORTModelInputOutputSchemaType, args: Sequence[ORTModelInputOutputType], kwargs: Mapping[str, ORTModelInputOutputType], -) -> _InputInfo: + constant_as_tensor: bool, + device: torch.device, +) -> Tuple[_PreExportGraphInfo, Dict[str, Callable]]: """Parses through the model inputs and returns _InputInfo. Loop through all input parameters, try to flatten them into a 1-D list of inputs. For nested data in the inputs, @@ -430,55 +405,84 @@ def parse_inputs_for_onnx_export( """ + data_accessors: Dict[str, Callable] = OrderedDict() + def _add_dynamic_shape(name, input) -> Dict[str, Dict[int, str]]: dynamic_axes[name] = {} for dim_idx in range(len(input.shape)): dynamic_axes[name].update({dim_idx: f"{name}_dim{dim_idx}"}) return dynamic_axes - def _add_input(name, input_value, onnx_graph, onnx_graph_input_names): + def _warn_of_constant_inputs(data): + warnings.warn(f"Received input of type {type(data)} is treated as a constant by ORT by default.") + + def _add_input(name, input_value, onnx_graph_input_names, cur_func): """Returns number of expanded non none inputs that _add_input processed""" - if name in input_names or input_value is None or isinstance(input_value, str): + # in case the input is already handled. + if name in input_names: # or input_value is None or isinstance(input_value, str): # Drop all None and string inputs and return 0. return - if isinstance(input_value, abc.Sequence): + # InputInfo should contain all the names irrespective of whether they are + # a part of the onnx graph or not. + input_names.append(name) + + value = input_value + if value is None: + _warn_of_constant_inputs(value) + elif isinstance(value, str): + _warn_of_constant_inputs(value) + elif PrimitiveType.is_primitive_type(value): + if constant_as_tensor: + value = PrimitiveType.get_tensor(value, device) + else: + _warn_of_constant_inputs(value) + elif isinstance(value, abc.Sequence): # If the input is a sequence (like a list), expand the list so that # each element of the list is an input by itself. - for i, val in enumerate(input_value): + for i, val in enumerate(value): # Name each input with the index appended to the original name of the # argument. - _add_input(f"{name}_{i}", val, onnx_graph, onnx_graph_input_names) + _add_input( + f"{name}_{i}", + val, + onnx_graph_input_names, + partial(lambda i, args, kwargs: cur_func(args, kwargs)[i], i), + ) # Return here since the list by itself is not a valid input. # All the elements of the list have already been added as inputs individually. return - elif isinstance(input_value, abc.Mapping): + elif isinstance(value, abc.Mapping): # If the input is a mapping (like a dict), expand the dict so that # each element of the dict is an input by itself. - for key, val in input_value.items(): - _add_input(f"{name}_{key}", val, onnx_graph, onnx_graph_input_names) + for key, val in value.items(): + _add_input( + f"{name}_{key}", + val, + onnx_graph_input_names, + partial(lambda key, args, kwargs: cur_func(args, kwargs)[key], key), + ) # Return here since the dict by itself is not a valid input. # All the elements of the dict have already been added as inputs individually. return - # InputInfo should contain all the names irrespective of whether they are - # a part of the onnx graph or not. - input_names.append(name) - - if (onnx_graph is None or name in onnx_graph_input_names) and isinstance(input_value, torch.Tensor): - if input_value.requires_grad: + if isinstance(value, torch.Tensor): + onnx_graph_input_names.append(name) + data_accessors[name] = cur_func + if value.requires_grad: input_names_require_grad.append(name) - dynamic_axes.update(_add_dynamic_shape(name, input_value)) - input_shape.append(list(input_value.size())) + dynamic_axes.update(_add_dynamic_shape(name, value)) + input_shape.append(list(value.size())) # Ignore optional inputs explicitly specified as None # ONNX exporter may remove unused inputs onnx_graph_input_names: List[str] = [] - if onnx_graph is not None: - onnx_graph_input_names = {inp.name for inp in onnx_graph.graph.input} + # onnx_graph = None + # if onnx_graph is not None: + # onnx_graph_input_names = {inp.name for inp in onnx_graph.graph.input} input_names: List[str] = [] dynamic_axes: Dict[str, Dict[int, str]] = {} @@ -513,7 +517,13 @@ def _add_input(name, input_value, onnx_graph, onnx_graph_input_names): name = f"{input_parameter.name}_{var_positional_idx}" var_positional_idx += 1 inp = args[args_i] - _add_input(name, inp, onnx_graph, onnx_graph_input_names) + _add_input( + name, + inp, + # onnx_graph, + onnx_graph_input_names, + partial(lambda args_i, args, kwargs: args[args_i], args_i), + ) elif ( input_parameter.kind == inspect.Parameter.POSITIONAL_ONLY or input_parameter.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD @@ -523,24 +533,35 @@ def _add_input(name, input_value, onnx_graph, onnx_graph_input_names): name = input_parameter.name inp = None input_idx += var_positional_idx # noqa: PLW2901 + access_func = None if input_idx < len(args) and args[input_idx] is not None: inp = args[input_idx] + + access_func = partial(lambda input_idx, args, kwargs: args[input_idx], input_idx) + elif name in kwargs and kwargs[name] is not None: inp = kwargs[name] - _add_input(name, inp, onnx_graph, onnx_graph_input_names) + + access_func = partial(lambda name, args, kwargs: kwargs[name], name) + + _add_input(name, inp, onnx_graph_input_names, access_func) elif input_parameter.kind == inspect.Parameter.VAR_KEYWORD: # **kwargs is always the last argument of forward() for name, inp in kwargs.items(): - _add_input(name, inp, onnx_graph, onnx_graph_input_names) + _add_input( + name, + inp, + onnx_graph_input_names, + partial(lambda name, args, kwargs: kwargs[name], name), + ) - return _InputInfo( - names=input_names, - shape=input_shape, - require_grad_names=input_names_require_grad, - dynamic_axes=dynamic_axes, - schema=schema, - num_positionals=len(args), - ) + exported_graph = _PreExportGraphInfo() + exported_graph.onnx_graph_input_names = onnx_graph_input_names + exported_graph.onnx_graph_input_names_require_grad = input_names_require_grad + exported_graph.onnx_graph_input_dynamic_axes_map = dynamic_axes + exported_graph.onnx_graph_input_shapes = input_shape + + return exported_graph, data_accessors def calculate_total_parameter_size_in_bytes(module: torch.nn.Module) -> int: @@ -608,10 +629,36 @@ def parse_outputs_for_onnx_export_and_extract_schema( sample_outputs = model_copy(*sample_args_copy, **sample_kwargs_copy) + # print("sample_outputs: ", sample_outputs) + # Parse the output and extract the output_names and output_dynamic_axes to be used for onnx export - output_names, output_dynamic_axes = _parse_outputs_and_extract_names_and_dynamic_axes(sample_outputs) + # output_names, output_dynamic_axes = _parse_outputs_and_extract_names_and_dynamic_axes(sample_outputs) + + output_names: List[str] = [] + output_dynamic_axes: Dict[str, Dict[int, str]] = {} + + # # Naming the outputs with a hyphen ensures that there can be no input with the same + # # name, preventing collisions with other NodeArgs (for example an input to forward called output0) + # output_name = f"output-{output_idx[0]}" + # output_idx[0] += 1 + # output_names.append(output_name) + # output_dynamic_axes[output_name] = {} + # for dim_idx in range(len(output.shape)): + # output_dynamic_axes[output_name].update({dim_idx: f"{output_name}_dim{dim_idx}"}) + + for output_idx, output in enumerate(sample_outputs): + output_name = f"output-{output_idx}" + output_names.append(output_name) + output_dynamic_axes[output_name] = {} + for dim_idx in range(len(output.shape)): + output_dynamic_axes[output_name].update({dim_idx: f"{output_name}_dim{dim_idx}"}) + + _, flattend_module_output_schema = _extract_schema(sample_outputs, device) + + original_module_output_schema = model_copy._output_schema + + # print("output_schema: ", flattend_module_output_schema) - output_schema = _extract_schema(sample_outputs, device) if deep_copied: del model_copy gc.collect() @@ -620,4 +667,4 @@ def parse_outputs_for_onnx_export_and_extract_schema( # Release the memory cached by torch. torch.cuda.empty_cache() # Return output names, output dynamic axes and output schema - return output_names, output_dynamic_axes, output_schema + return output_names, output_dynamic_axes, original_module_output_schema diff --git a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py index 078ce4d27cd6f..45cb9f0671d95 100644 --- a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py +++ b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py @@ -65,6 +65,9 @@ def enable_input_inspector(self, model: ModelProto, user_input_names: List[str]) else: raise RuntimeError("Input density observer is already enabled.") + if model is None: + raise RuntimeError("ONNX model is not available when enabling input density inspection.") + return self.input_density_ob.initialize(model, user_input_names) def inspect_input(self, input_name, input_data) -> Tuple[bool, float, float]: diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index 5b2c673ce94cb..fe55b16537c5c 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -12,12 +12,12 @@ from onnxruntime.capi import _pybind_state as C from onnxruntime.capi.onnxruntime_inference_collection import get_ort_device_type -from . import _are_deterministic_algorithms_enabled, _io, _use_deterministic_algorithms, _utils +from . import _are_deterministic_algorithms_enabled, _use_deterministic_algorithms, _utils from ._execution_agent import TrainingAgent from ._fallback import ORTModuleFallbackException, _FallbackManager, _FallbackPolicy from ._gradient_accumulation_manager import GradientAccumulationManager from ._graph_execution_manager import GraphExecutionManager, _RunStateInfo -from ._io import _FlattenedModule, _InputInfo, unflatten_user_output +from ._io import _FlattenedModule, unflatten_user_output from ._logger import ORTModuleInitPhase, TrackTime from ._runtime_inspector import Phase from ._utils import save_tuning_results, set_tuning_results @@ -207,6 +207,9 @@ def backward(ctx, *grad_outputs): self._runtime_inspector.memory_ob.inspect_memory(Phase.POST_BACKWARD) + # print("transferred_backward_outputs: ", transferred_backward_outputs) + # print("self._gradient_map: ", self._gradient_map) + return tuple(transferred_backward_outputs[idx] if idx != -1 else None for idx in self._gradient_map) return _ORTModuleFunction @@ -247,27 +250,37 @@ def forward(self, *inputs, **kwargs): build_gradient_graph = False if ( self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT) is False - or not self._onnx_models.exported_model + or not self._exported_model ): self.time_tracker.start(ORTModuleInitPhase.EndToEnd) - build_gradient_graph = self._export_model(*inputs, **kwargs) + ( + build_gradient_graph, + pre_grad_graph_info, + exported_graph_info, + ) = self.use_cached_exported_model_or_reexport(inputs, kwargs, self._device) - if build_gradient_graph: - # If model was exported, then initialize the graph builder - self._initialize_graph_builder() - - # Since the schema was just extracted while trying to export the model and it was either - # saved to self._input_info.schema or checked for equality with the self._input_info.schema - # it should not need to be updated again. Pass it inside parse_inputs_for_onnx_export. - input_info = _io.parse_inputs_for_onnx_export( - self._module_parameters, self._onnx_models.exported_model, self._input_info.schema, inputs, kwargs + build_gradient_graph = self.use_cached_pre_grad_model_or_reinitialize( + build_gradient_graph, pre_grad_graph_info ) - # Reinitialize graph builder if the inputs or initializers requiring gradient have changed. - # Order of or operation is important here because we always need to call - # _reinitialize_graph_builder irrespective of the value of build_gradient_graph. - build_gradient_graph = self._reinitialize_graph_builder(input_info) or build_gradient_graph + # build_gradient_graph = self._export_model(*inputs, **kwargs) + + # if build_gradient_graph: + # # If model was exported, then initialize the graph builder + # self._initialize_graph_builder() + + # # Since the schema was just extracted while trying to export the model and it was either + # # saved to self._input_info.schema or checked for equality with the self._input_info.schema + # # it should not need to be updated again. Pass it inside parse_inputs_for_onnx_export. + # input_info = _io.parse_inputs_for_onnx_export( + # self._module_parameters, self._onnx_models.exported_model, self._input_info.schema, inputs, kwargs + # ) + + # # Reinitialize graph builder if the inputs or initializers requiring gradient have changed. + # # Order of or operation is important here because we always need to call + # # _reinitialize_graph_builder irrespective of the value of build_gradient_graph. + # build_gradient_graph = self._reinitialize_graph_builder(input_info) or build_gradient_graph # Build the gradient graph if build_gradient_graph: @@ -313,23 +326,47 @@ def forward(self, *inputs, **kwargs): if self._runtime_options.enable_zero_stage3_support: self._append_pull_weight_trigger_as_input(kwargs, self._device) - prepared_input_list, _, _ = _io._combine_input_buffers_initializers( - self._graph_initializers, - self._graph_info.user_input_names, - self._input_info, - self._flattened_module.named_buffers(), - inputs, - kwargs, - self._device, - self._runtime_inspector, - self._zero_stage3_param_map, - ) + # prepared_input_list = self.construct_inputs(inputs, kwargs) + + prepared_input_map = self.construct_inputs(inputs, kwargs, True, self._device) + + if ( + self._runtime_inspector.memory_ob.is_enabled() + and not self._runtime_inspector.memory_ob.symbolic_dim_collecting_completed + ): + self._runtime_inspector.memory_ob.collect_symbolic_dim_values( + self._finalized_graph_info.onnx_graph_input_dynamic_axes_map, prepared_input_map + ) + self._runtime_inspector.memory_ob.symbolic_dim_collecting_completed = True + + # prepared_input_list, _, _ = _io._combine_input_buffers_initializers( + # self._graph_initializers, + # self._graph_info.user_input_names, + # self._input_info, + # self._flattened_module.named_buffers(), + # inputs, + # kwargs, + # self._device, + # self._runtime_inspector, + # self._zero_stage3_param_map, + # ) + + for input_name, input_value in prepared_input_map.items(): + print( + f"input_name: {input_name}, shape: {input_value.shape}, dtype: {input_value.dtype}, requires_grad: {input_value.requires_grad}" + ) + + user_outputs = self._forward_class.apply(*prepared_input_map.values()) + # print("user_outputs: ", user_outputs) + # print("self._module_output_schema: ", self._module_output_schema) outputs = unflatten_user_output( self._module_output_schema, - self._forward_class.apply(*prepared_input_list), + user_outputs, ) + # print("outputs: ", outputs) + if ( create_execution_session and self._runtime_options.enable_tuning @@ -380,24 +417,33 @@ def _build_graph(self, graph_transformer_config): # Map each input/initializer to its gradient index in the graph output, or -1 is gradient is not required. self._gradient_map = [] - num_user_input_grads = len(self._input_info.require_grad_names) - require_grad_names_set = set(self._input_info.require_grad_names) - require_grad_names_index = 0 - for input_name in self._graph_info.user_input_names: - if input_name in require_grad_names_set: - self._gradient_map.append(require_grad_names_index) - require_grad_names_index += 1 - else: - self._gradient_map.append(-1) - initializer_index = num_user_input_grads - for initializer_name in self._graph_info.initializer_names: - if initializer_name in self._graph_initializer_names_to_train: - self._gradient_map.append(initializer_index) - initializer_index += 1 + index_for_input_requires_grad = 0 + for input_name in self._finalized_graph_info.onnx_graph_input_names: + if input_name in self._finalized_graph_info.onnx_graph_input_names_require_grad: + self._gradient_map.append(index_for_input_requires_grad) + index_for_input_requires_grad += 1 else: self._gradient_map.append(-1) + # num_user_input_grads = len(self._finalized_graph_info.onnx_graph_input_names_require_grad) + # require_grad_names_set = set(self._finalized_graph_info.onnx_graph_input_names_require_grad) + # require_grad_names_index = 0 + # for input_name in self._graph_info.user_input_names: + # if input_name in require_grad_names_set: + # self._gradient_map.append(require_grad_names_index) + # require_grad_names_index += 1 + # else: + # self._gradient_map.append(-1) + + # initializer_index = num_user_input_grads + # for initializer_name in self._graph_info.initializer_names: + # if initializer_name in self._finalized_graph_info.onnx_graph_input_names_require_grad: + # self._gradient_map.append(initializer_index) + # initializer_index += 1 + # else: + # self._gradient_map.append(-1) + @TrackTime(ORTModuleInitPhase.CREATE_SESSION) def _create_execution_agent(self): """Creates a TrainingAgent that can run the forward and backward graph on the training model""" @@ -480,28 +526,28 @@ def _create_execution_agent(self): self._execution_agent._inference_session, True, self._runtime_options.tuning_results_path ) - def _reinitialize_graph_builder(self, input_info: _InputInfo): - """Return true if the module graph builder was reinitialized""" - - # Model may have unused params dropped after export and not part of self._graph_initializer_names_to_train - # To see if any trainable initializers changed, compare self._graph_initializer_names_to_train - # with initializers in module named_parameters that are known to the onnx graph. - initializer_names_to_train_set_user_model = { - name - for name, param in self._flattened_module.named_parameters() - if param.requires_grad and name in self._graph_initializer_names - } - - # If inputs requiring gradient change from forward to the next, the module_gradient_graph_builder - # needs to be reinitialized so it can compute the backward output for the new inputs that require_grad - if ( - input_info.require_grad_names != self._input_info.require_grad_names - or initializer_names_to_train_set_user_model != self._graph_initializer_names_to_train - ): - self._input_info = input_info - self._initialize_graph_builder() - return True - return False + # def _reinitialize_graph_builder(self, input_info: _InputInfo): + # """Return true if the module graph builder was reinitialized""" + + # # Model may have unused params dropped after export and not part of self._graph_initializer_names_to_train + # # To see if any trainable initializers changed, compare self._graph_initializer_names_to_train + # # with initializers in module named_parameters that are known to the onnx graph. + # initializer_names_to_train_set_user_model = { + # name + # for name, param in self._flattened_module.named_parameters() + # if param.requires_grad and name in self._graph_initializer_names + # } + + # # If inputs requiring gradient change from forward to the next, the module_gradient_graph_builder + # # needs to be reinitialized so it can compute the backward output for the new inputs that require_grad + # if ( + # input_info.require_grad_names != self._input_info.require_grad_names + # or initializer_names_to_train_set_user_model != self._graph_initializer_names_to_train + # ): + # self._input_info = input_info + # self._initialize_graph_builder() + # return True + # return False def __getstate__(self): state = super().__getstate__() diff --git a/orttraining/orttraining/python/training/utils/torch_io_helper.py b/orttraining/orttraining/python/training/utils/torch_io_helper.py index 34cc1ca942a8c..4824ed7137021 100644 --- a/orttraining/orttraining/python/training/utils/torch_io_helper.py +++ b/orttraining/orttraining/python/training/utils/torch_io_helper.py @@ -5,7 +5,7 @@ import copy import warnings -from collections import abc +from collections import OrderedDict, abc from typing import List, Mapping, Optional, Sequence, Tuple, Union import torch @@ -221,8 +221,8 @@ def _flatten_from_data(data: ORTModelInputOutputType, prefix_name: str = ""): return stubbed_schema elif isinstance(data, abc.Mapping): dict_type = type(data) - stubbed_schema = {} - for key, val in sorted(data.items()): + stubbed_schema = OrderedDict() + for key, val in data.items(): stubbed_schema[key] = _flatten_from_data(val, f"{prefix_name}_{key}" if prefix_name else f"{key}") stubbed_schema = dict_type(**stubbed_schema) return stubbed_schema @@ -305,7 +305,7 @@ def _replace_stub_with_tensor_value(data_schema: ORTModelInputOutputSchemaType, return data_schema elif isinstance(data_schema, abc.Mapping): new_user_output = copy.copy(data_schema) - for key, schema_val in sorted(data_schema.items()): + for key, schema_val in data_schema.items(): new_user_output[key] = _replace_stub_with_tensor_value(schema_val, data) data_schema = new_user_output diff --git a/orttraining/orttraining/python/training/utils/torch_to_onnx.py b/orttraining/orttraining/python/training/utils/torch_to_onnx.py new file mode 100644 index 0000000000000..f5c1fb13cbbc2 --- /dev/null +++ b/orttraining/orttraining/python/training/utils/torch_to_onnx.py @@ -0,0 +1,6 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +from __future__ import annotations diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index f944d8bc5ef42..cb79cb712627c 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -447,7 +447,11 @@ def test_forward_call_single_positional_argument(): N, D_in, H, D_out = 64, 784, 500, 10 # noqa: N806 model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) - ort_model = ORTModule(model) + # ort_model = ORTModule(model) + from onnxruntime.training.ortmodule import DebugOptions, LogLevel, ORTModule + + ort_model = ORTModule(model, DebugOptions(save_onnx=True, log_level=LogLevel.INFO, onnx_prefix="2024_0102_01")) + # Check that the original forward signature is preserved. assert inspect.signature(model.forward) == inspect.signature(ort_model.forward) x = torch.randn(N, D_in, device=device) @@ -684,7 +688,15 @@ def test_input_requires_grad_saved(device): model = ORTModule(model) x = torch.randn(N, D_in, device=device, requires_grad=True) + 1 model(x) - assert model._torch_module._execution_manager(model._is_training())._input_info.require_grad_names == ["input1"] + assert model._torch_module._execution_manager( + model._is_training() + )._pre_export_graph_info.onnx_graph_input_names_require_grad == ["input1"] + assert ( + "input1" + in model._torch_module._execution_manager( + model._is_training() + )._finalized_graph_info.onnx_graph_input_names_require_grad + ) @pytest.mark.parametrize("device", ["cuda", "cpu"]) @@ -826,7 +838,14 @@ def forward(self, input): device = "cuda" pt_model = NeuralNetTranspose(perm).to(device) - ort_model = ORTModule(copy.deepcopy(pt_model)) + + from onnxruntime.training.ortmodule import DebugOptions, LogLevel, ORTModule + + ort_model = ORTModule( + copy.deepcopy(pt_model), DebugOptions(save_onnx=True, log_level=LogLevel.INFO, onnx_prefix="2024_0103_01") + ) + + # ort_model = ORTModule(copy.deepcopy(pt_model)) def run_step(model, x): prediction = model(x) @@ -835,11 +854,11 @@ def run_step(model, x): return prediction x = torch.randn(*shape, device=device, requires_grad=True) - pt_prediction = run_step(pt_model, x) + # pt_prediction = run_step(pt_model, x) ort_prediction = run_step(ort_model, x) - _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) - _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model) + # _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) + # _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model) @pytest.mark.parametrize( @@ -2617,7 +2636,10 @@ def test_exception_raised_for_custom_class_return_value_module(device): # ORT backend with pytest.raises(_fallback.ORTModuleIOError) as runtime_error: ort_model(x, y, z) - assert "ORTModule does not support the following model output type" in str(runtime_error.value) + assert ( + "ORTModule fails to extract schema from data: Unsupported flatten data type: " + in str(runtime_error.value) + ) del os.environ["ORTMODULE_SKIPCHECK_POLICY"] @@ -3468,21 +3490,30 @@ def test_forward_dynamic_args(): for _ in range(10): output = model(*args_size1) assert output is not None - hash_args_size1 = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema)) + hash_args_size1 = hash( + repr(model._torch_module._execution_manager(model._is_training())._args_input_schema) + + repr(model._torch_module._execution_manager(model._is_training())._kwargs_input_schema) + ) assert hash_args_size1 is not None # Decrease number of inputs and train some more for _ in range(10): output = model(*args_size2) assert output is not None - hash_args_size2 = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema)) + hash_args_size2 = hash( + repr(model._torch_module._execution_manager(model._is_training())._args_input_schema) + + repr(model._torch_module._execution_manager(model._is_training())._kwargs_input_schema) + ) assert hash_args_size2 != hash_args_size1 # Increase number of inputs and train some more for _ in range(10): output = model(*args_size3) assert output is not None - hash_args_size3 = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema)) + hash_args_size3 = hash( + repr(model._torch_module._execution_manager(model._is_training())._args_input_schema) + + repr(model._torch_module._execution_manager(model._is_training())._kwargs_input_schema) + ) assert hash_args_size3 != hash_args_size2 del os.environ["ORTMODULE_SKIPCHECK_POLICY"] @@ -3507,35 +3538,50 @@ def test_forward_dynamic_kwargs(): for _ in range(10): output = model(one) assert output is not None - hash_x = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema)) + hash_x = hash( + repr(model._torch_module._execution_manager(model._is_training())._args_input_schema) + + repr(model._torch_module._execution_manager(model._is_training())._kwargs_input_schema) + ) assert hash_x is not None # Train with x and y as inputs for _ in range(10): output = model(one, y=one) assert output is not None - hash_x_y = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema)) + hash_x_y = hash( + repr(model._torch_module._execution_manager(model._is_training())._args_input_schema) + + repr(model._torch_module._execution_manager(model._is_training())._kwargs_input_schema) + ) assert hash_x_y != hash_x # Train with x and z as inputs for _ in range(10): output = model(one, z=one) assert output is not None - hash_x_z = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema)) + hash_x_z = hash( + repr(model._torch_module._execution_manager(model._is_training())._args_input_schema) + + repr(model._torch_module._execution_manager(model._is_training())._kwargs_input_schema) + ) assert hash_x_z != hash_x_y # Train with x, y and z as inputs for _ in range(10): output = model(one, y=one, z=one) assert output is not None - hash_x_y_z = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema)) + hash_x_y_z = hash( + repr(model._torch_module._execution_manager(model._is_training())._args_input_schema) + + repr(model._torch_module._execution_manager(model._is_training())._kwargs_input_schema) + ) assert hash_x_y_z != hash_x_z # Return to original input with x as input for _ in range(10): output = model(one) assert output is not None - hash_x2 = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema)) + hash_x2 = hash( + repr(model._torch_module._execution_manager(model._is_training())._args_input_schema) + + repr(model._torch_module._execution_manager(model._is_training())._kwargs_input_schema) + ) assert hash_x2 != hash_x_y_z assert hash_x2 == hash_x @@ -3545,7 +3591,7 @@ def test_forward_dynamic_kwargs(): @pytest.mark.parametrize( "forward_function", [ # Only pos_X, pos_X as positionals - lambda model, pos_0, pos_1, kw_0, kw_1, args, kwargs: model(pos_0, pos_1), + # lambda model, pos_0, pos_1, kw_0, kw_1, args, kwargs: model(pos_0, pos_1), # Only pos_X, pos_X as keywords lambda model, pos_0, pos_1, kw_0, kw_1, args, kwargs: model(pos_0=pos_0, pos_1=pos_1), # pos_X + *args, pos_X as positionals @@ -3616,7 +3662,11 @@ def forward(self, pos_0, pos_1, *args, kw_0=None, kw_1=None, **kwargs): device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 # noqa: N806 model = KwargsNet(input_size=D_in, hidden_size=H, num_classes=D_out).to(device) - model = ORTModule(model) + # model = ORTModule(model) + + from onnxruntime.training.ortmodule import DebugOptions, LogLevel, ORTModule + + model = ORTModule(model, DebugOptions(save_onnx=True, log_level=LogLevel.INFO, onnx_prefix="2024_0102_02")) # Dummy inputs used pos_0 = torch.randn(N, D_in, device=device) @@ -3955,10 +4005,10 @@ def forward(self, input1, bool_argument): input1 = torch.randn(N, D_in, device=device) ort_model(input1, bool_arguments[0]) - exported_model1 = ort_model._torch_module._execution_manager(ort_model._is_training())._onnx_models.exported_model + exported_model1 = ort_model._torch_module._execution_manager(ort_model._is_training())._exported_model ort_model(input1, bool_arguments[1]) - exported_model2 = ort_model._torch_module._execution_manager(ort_model._is_training())._onnx_models.exported_model + exported_model2 = ort_model._torch_module._execution_manager(ort_model._is_training())._exported_model assert exported_model1 != exported_model2 @@ -4124,34 +4174,34 @@ def test_stateless_model_unspecified_device(): _test_helpers.assert_values_are_close(pt_y, ort_y) -@pytest.mark.parametrize( - "model", - [ - (UnusedBeginParameterNet(784, 500, 400, 10)), - (UnusedMiddleParameterNet(784, 500, 400, 10)), - (UnusedEndParameterNet(784, 500, 400, 10)), - ], -) -def test_unused_parameters_does_not_unnecessarily_reinitialize(model): - device = "cuda" - - N, D_in, H1, H2, D_out = 64, 784, 500, 400, 10 # noqa: F841, N806 - model = model.to(device) - ort_model = ORTModule(copy.deepcopy(model)) - training_manager = ort_model._torch_module._execution_manager(ort_model._is_training()) - - x = torch.randn(N, D_in, device=device) - _ = ort_model(x) - - input_info = _io.parse_inputs_for_onnx_export( - training_manager._module_parameters, - training_manager._onnx_models.exported_model, - training_manager._input_info.schema, - x, - {}, - ) +# @pytest.mark.parametrize( +# "model", +# [ +# (UnusedBeginParameterNet(784, 500, 400, 10)), +# (UnusedMiddleParameterNet(784, 500, 400, 10)), +# (UnusedEndParameterNet(784, 500, 400, 10)), +# ], +# ) +# def test_unused_parameters_does_not_unnecessarily_reinitialize(model): +# device = "cuda" + +# N, D_in, H1, H2, D_out = 64, 784, 500, 400, 10 +# model = model.to(device) +# ort_model = ORTModule(copy.deepcopy(model)) +# training_manager = ort_model._torch_module._execution_manager(ort_model._is_training()) + +# x = torch.randn(N, D_in, device=device) +# _ = ort_model(x) + +# input_info = _io.parse_inputs_for_onnx_export( +# training_manager._module_parameters, +# training_manager._exported_model, +# training_manager._input_info.schema, +# x, +# {}, +# ) - assert not training_manager._reinitialize_graph_builder(input_info) +# assert not training_manager._reinitialize_graph_builder(input_info) def test_load_state_dict_for_wrapped_ortmodule(): @@ -4783,7 +4833,7 @@ def forward(self, a): ort_model = ORTModule(pt_model) _ = ort_model(torch.randn(N, D_in, device=device)) - exported_model1 = ort_model._torch_module._execution_manager(True)._onnx_models.exported_model + exported_model1 = ort_model._torch_module._execution_manager(True)._exported_model for training_mode in [False, True]: assert ort_model._torch_module._execution_manager(training_mode)._original_model_has_changed is False @@ -4793,7 +4843,7 @@ def forward(self, a): assert ort_model._torch_module._execution_manager(training_mode)._original_model_has_changed is True _ = ort_model(torch.randn(N, D_in, device=device)) - exported_model2 = ort_model._torch_module._execution_manager(True)._onnx_models.exported_model + exported_model2 = ort_model._torch_module._execution_manager(True)._exported_model assert exported_model1 != exported_model2 @@ -5232,9 +5282,12 @@ def run_step(model, x): ort_prediction, ort_loss = run_step(ort_model, ort_x) pt_prediction, pt_loss = run_step(pt_model, pt_x) if step == 0: - model_onx = ort_model._torch_module._execution_manager._training_manager._onnx_models - for name in ["exported_model", "optimized_model"]: - onx = getattr(model_onx, name) + exported_model = ort_model._torch_module._execution_manager._training_manager._exported_model + # optimized_model = ort_model._torch_module._execution_manager._training_manager._optimized_model + for onx in [ + exported_model, + ]: + # "optimized_model"]: opv = None for op in onx.opset_import: if not op.domain: @@ -5275,7 +5328,7 @@ def test_opset_version_change(opset_version): prediction.backward() # Check opset version on ONNX model - exported_model = ort_model._torch_module._execution_manager(ort_model._is_training())._onnx_models.exported_model + exported_model = ort_model._torch_module._execution_manager(ort_model._is_training())._exported_model assert exported_model.opset_import[0].version == opset_version if original_env is not None: @@ -5710,9 +5763,6 @@ def run_step(model, input, positions): loss.backward() return loss - # batch_size = 3 - # sequence = 4 - if embed_is_sparse: input = torch.tensor([[0, 2, 3, 4], [2, 3, 1, 1], [1, 1, 1, 1]], device=device) else: From 527ccacb0af456332a498bf117fa7ad2e225705e Mon Sep 17 00:00:00 2001 From: "pengwa@microsoft.com" Date: Wed, 3 Jan 2024 17:18:59 +0000 Subject: [PATCH 02/32] save --- .../ortmodule/_graph_execution_manager.py | 698 ++-------------- .../ortmodule/_graph_transition_manager.py | 756 ++++++++++++++++++ .../training/ortmodule/_inference_manager.py | 43 +- .../python/training/ortmodule/_io.py | 180 ++--- .../training/ortmodule/_training_manager.py | 119 +-- .../ortmodule/_zero_stage3_compatibility.py | 10 +- .../python/training/ortmodule/ortmodule.py | 4 +- .../python/orttraining_test_ortmodule_api.py | 45 +- 8 files changed, 981 insertions(+), 874 deletions(-) create mode 100755 orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index ae6ccb1acdfff..3577ee5a04a33 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -4,14 +4,10 @@ # -------------------------------------------------------------------------- import copy -import inspect -import io import logging import os from abc import ABC, abstractmethod # noqa: F401 -from functools import partial -from hashlib import md5 as hash_fn -from typing import Dict, List, Mapping, Optional, OrderedDict, Sequence, Tuple +from typing import Dict, List, OrderedDict, Tuple import onnx import torch @@ -19,31 +15,17 @@ import onnxruntime from onnxruntime.capi import _pybind_state as C -from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference -from onnxruntime.training.utils import ( - ORTModelInputOutputSchemaType, - ORTModelInputOutputType, - PrimitiveType, - PTable, - onnx_dtype_to_pytorch_dtype, -) +from onnxruntime.training.utils import PTable, onnx_dtype_to_pytorch_dtype from onnxruntime.training.utils.hooks import configure_ort_compatible_zero_stage3 -from . import _are_deterministic_algorithms_enabled, _io, _logger, _onnx_models, _utils -from ._fallback import ( - ORTModuleDeviceException, - ORTModuleONNXModelException, - ORTModuleTorchModelException, - _FallbackManager, - _FallbackPolicy, - wrap_exception, -) +from . import _are_deterministic_algorithms_enabled, _logger, _onnx_models, _utils +from ._fallback import ORTModuleTorchModelException, _FallbackManager, _FallbackPolicy, wrap_exception from ._gradient_accumulation_manager import GradientAccumulationManager from ._graph_execution_interface import GraphExecutionInterface +from ._graph_transition_manager import GraphTransitionManager, PostExportProcessedModelInfo from ._io import _FlattenedModule from ._runtime_inspector import RuntimeInspector -from ._utils import check_function_has_param, get_rank -from ._zero_stage3_compatibility import stage3_export_context +from ._utils import get_rank from .options import DebugOptions, LogLevel, _MemoryOptimizationLevel, _RuntimeOptions from .torch_cpp_extensions.cpu.aten_op_executor import load_aten_op_executor_cpp_extension @@ -58,313 +40,13 @@ def __init__(self, state, output_info: List[Tuple[torch.Size, torch.device, torc self.output_info = output_info -def _get_onnx_file_name(name_prefix, name, export_mode): - suffix = "training" if export_mode == torch.onnx.TrainingMode.TRAINING else "inference" - return f"{name_prefix}_{name}_{suffix}.onnx" - - -def _save_model(model: onnx.ModelProto, file_path: str): - onnx.save(model, file_path) - - -class StaticGraphManager: - def __init__(self): - # Export graph infos - - self._pre_export_graph_info = _io._PreExportGraphInfo() - self._data_accessor = None - - # Value can be either torch.onnx.TrainingMode.TRAINING or torch.onnx.TrainingMode.EVAL - # To be instantiated in the concrete implementation of GraphExecutionManager - self._export_mode = None - # Exporter can take extra arguments for ORTModule extensions - # It cannot overlap with required/immutable arguments (validated in runtime) - self._export_extra_kwargs = {} - self._exported_graph_info = _io._ExportedGraphInfo() - self._module_parameters: List[inspect.Parameter] = [] - self._exported_model: Optional[onnx.ModelProto] = None - self._args_input_schema: Optional[ORTModelInputOutputSchemaType] = None - self._kwargs_input_schema: Optional[ORTModelInputOutputSchemaType] = None - - # Pre-grad graph infos - self._finalized_graph_info = _io._FinalizedGraphInfo() - self._finalized_model: Optional[onnx.ModelProto] = None - - # self._buffers_as_onnx_graph_input: List[torch.nn.parameter.Parameter] = [] # Cache the list of free buffers, which will be used as onnx graph inputs. - # self._params_as_onnx_graph_input: List[torch.nn.parameter.Parameter] = [] # Cache the list of parameters, which will be used as onnx graph inputs. - - def use_cached_exported_model_or_reexport( - self, - args: Sequence[ORTModelInputOutputType], - kwargs: Mapping[str, ORTModelInputOutputType], - device: Optional[torch.device], - ) -> Tuple[bool, _io._PreExportGraphInfo, _io._ExportedGraphInfo]: - """Create the exported graph if it does not exist, otherwise use the cached one""" - - need_export_model = not self._exported_model - need_export_model = need_export_model or self._original_model_has_changed - - # print("args: ", args, ", kwargs: ", kwargs) - - # Check graph inputs parsed from the model's forward function signature and current inputs, - # if they are different, we need to re-export the model. - pre_export_graph_info, data_accessor = _io.parse_inputs_for_onnx_export( - self._module_parameters, args, kwargs, True, device - ) - - # print(">>>>pre_export_graph_info.onnx_graph_input_names: ", pre_export_graph_info.onnx_graph_input_names) - # print( - # ">>>>pre_export_graph_info.onnx_graph_input_names_require_grad: ", - # pre_export_graph_info.onnx_graph_input_names_require_grad, - # ) - need_export_model = ( - need_export_model - or self._pre_export_graph_info.onnx_graph_input_names != pre_export_graph_info.onnx_graph_input_names - ) - - # Maybe we should also check parameters count or size, because user could modify the parameters after the export. - # But pre_export_graph_info did not contains any parameters as its inputs. - - # Extract the schema from the args and kwargs, and compare with the cached one. - # This check ideally is not needed as we already have the above check, but it is added as a safeguard. - flatten_args, args_schema = _io._extract_schema(copy.copy(args), device) - # print("!!!!!!!!!!!!!!!!!!kwargs", kwargs) - flatten_kwargs, kwargs_schema = _io._extract_schema(copy.copy(kwargs), device) - # print("!!!!!!!!!!!!!!!!!!flatten_kwargs", flatten_kwargs) - # schema = _io._extract_schema({"args": copy.copy(args), "kwargs": copy.copy(kwargs)}, device) - need_export_model = ( - need_export_model or args_schema != self._args_input_schema or kwargs_schema != self._kwargs_input_schema - ) - - if need_export_model: - # Set the schema before exporting the model, so that we can use the schema to unflatten the inputs - # during the flatten module forward run. - self._args_input_schema = args_schema - self._kwargs_input_schema = kwargs_schema - - def _unflatten_inputs( - num_positionals, - args_schema: Optional[ORTModelInputOutputSchemaType], - kwargs_schema: Optional[ORTModelInputOutputSchemaType], - inputs: Sequence[ORTModelInputOutputType], - ): - """Unflattens the inputs into args and kwargs - - The inputs are unflattened in the order they appear in the model's forward function arguments. - - Mainly used for PyTorch run for ONNX export. - """ - restored_args = _io.unflatten_data_using_schema(inputs[:num_positionals], args_schema) - restored_kwargs = _io.unflatten_data_using_schema(inputs[num_positionals:], kwargs_schema) - - return restored_args, restored_kwargs - - self._flattened_module._unflatten_functor = partial( - _unflatten_inputs, len(flatten_args), self._args_input_schema, self._kwargs_input_schema - ) - self._flattened_module.device = device - self._exported_model, exported_graph_info = self._export_model( - self._flattened_module, pre_export_graph_info, flatten_args + flatten_kwargs, {} - ) - self._pre_export_graph_info = pre_export_graph_info - self._data_accessor = data_accessor - - self._original_model_has_changed = False - - self._exported_graph_info = exported_graph_info - - # save the ortmodule exported model - if self._debug_options.save_onnx_models.save: - _save_model( - self._exported_model, - os.path.join( - self._debug_options.save_onnx_models.path, - _get_onnx_file_name( - self._debug_options.save_onnx_models.name_prefix, "torch_exported", self._export_mode - ), - ), - ) - - return need_export_model, pre_export_graph_info, self._exported_graph_info - - def _post_process( - self, exported_model: onnx.ModelProto, exported_graph_info: _io._ExportedGraphInfo - ) -> Tuple[onnx.ModelProto, _io._FinalizedGraphInfo]: - """Post process the exported model, for example, add extra information to the model proto""" - - # Deepcopy the exported model as pre-grad model, in case modification affects the exported model. - - # TODO(): Do pre-grad graph modification as needed, for memory efficient gradient management, etc. - # Currently, we don't do any modification, so just use the exported graph as pre-grad graph. - - finalized_model = copy.deepcopy(exported_model) - - finalized_graph_info = _io._FinalizedGraphInfo() - finalized_graph_info.onnx_graph_input_names = exported_graph_info.onnx_graph_input_names - finalized_graph_info.onnx_graph_input_names_require_grad = ( - exported_graph_info.onnx_graph_input_names_require_grad - ) - - self._finalized_model = finalized_model - self._finalized_graph_info = finalized_graph_info - - return finalized_model, finalized_graph_info - - def use_cached_pre_grad_model_or_reinitialize( - self, reexported_model: bool, pre_export_graph_info: _io._PreExportGraphInfo - ) -> bool: - if self._export_mode == torch.onnx.TrainingMode.TRAINING: - # initializer_names_to_train_set_user_model = [ - # name - # for name, param in self._flattened_module.named_parameters() - # if param.requires_grad and name in self._finalized_graph_info.onnx_graph_input_names - # ] - - if reexported_model: - pass - - # If inputs requiring gradient change from forward to the next, the module_gradient_graph_builder - # needs to be reinitialized so it can compute the backward output for the new inputs that require_grad - else: - # Reinitialize graph builder if the inputs or initializers requiring gradient have changed. - # This can happen when the user changes the model parameters after the onnx export. - # Model may have unused params dropped after export, so we only check those inputs existing in onnx graph. - - onnx_graph_input_requires_grads = [] - parameter_names = {k: v for k, v in self._flattened_module.named_parameters()} - for input_name in self._exported_graph_info.onnx_graph_input_names: - if input_name in parameter_names and parameter_names[input_name].requires_grad: - onnx_graph_input_requires_grads.append(input_name) - else: - # If not in parameter list, then it would come from user defined inputs. - if input_name in pre_export_graph_info.onnx_graph_input_names_require_grad: - onnx_graph_input_requires_grads.append(input_name) - - # print("onnx_graph_input_requires_grads: ", onnx_graph_input_requires_grads) - - if onnx_graph_input_requires_grads != self._exported_graph_info.onnx_graph_input_names_require_grad: - self._exported_graph_info.onnx_graph_input_names_require_grad = onnx_graph_input_requires_grads - else: - return False - - # print( - # "111111onnx_graph_input_names_require_grad: ", - # self._exported_graph_info.onnx_graph_input_names_require_grad, - # ) - self._finalized_model, self._finalized_graph_info = self._post_process( - self._exported_model, self._exported_graph_info - ) - else: - if not reexported_model: - return False - - self._finalized_model = self._exported_model - self._finalized_graph_info = _io._FinalizedGraphInfo() - self._finalized_graph_info.onnx_graph_input_names = self._exported_graph_info.onnx_graph_input_names - self._finalized_graph_info.onnx_graph_input_names_require_grad = ( - self._exported_graph_info.onnx_graph_input_names_require_grad - ) - - self._initializer_input_buffers_for_ort() - - print( - "_finalized_graph_info.onnx_graph_input_names_require_grad: ", - self._finalized_graph_info.onnx_graph_input_names_require_grad, - ) - print("_finalized_graph_info.onnx_graph_input_names: ", self._finalized_graph_info.onnx_graph_input_names) - print( - "o_finalized_graph_info.nnx_graph_input_names_user_defined: ", - self._finalized_graph_info._onnx_graph_input_names_user_defined, - ) - print( - "_finalized_graph_info.onnx_graph_input_names_require_grad_user_defined: ", - self._finalized_graph_info._onnx_graph_input_names_require_grad_user_defined, - ) - - self._initialize_graph_builder() - - return True - - def _initializer_input_buffers_for_ort(self): - parameter_names = {k: v for k, v in self._flattened_module.named_parameters()} - buffer_names = {k: v for k, v in self._flattened_module.named_buffers()} - for input_name in self._finalized_graph_info.onnx_graph_input_names: - if input_name in parameter_names: - self._finalized_graph_info._buffer_for_ort_runs[input_name] = parameter_names[input_name] - elif input_name in buffer_names: - self._finalized_graph_info._buffer_for_ort_runs[input_name] = buffer_names[input_name] - else: - self._finalized_graph_info._buffer_for_ort_runs[input_name] = None - # print(f"append new input_name into _onnx_graph_input_names_user_defined: {input_name}") - self._finalized_graph_info._onnx_graph_input_names_user_defined.append(input_name) - - if input_name in self._exported_graph_info.onnx_graph_input_names_require_grad: - self._finalized_graph_info._onnx_graph_input_names_require_grad_user_defined.append(input_name) - - # For user inputs, we will fill them dynamically during the forward run. - - # def flatten_inputs(self, args: Sequence[ORTModelInputOutputType], kwargs: Mapping[str, ORTModelInputOutputType], device): - # """Flattens the inputs and kwargs into a single tuple of inputs - - # The inputs are flattened in the order they appear in the model's forward function signature - # """ - # # Drop the schema directly, since we would assume self.args_input_schema and self.kwargs_input_schema are - # # always consistent with the model's forward function signature. - # flatten_args, args_schema = _io._extract_schema(args, device) - # flatten_kwargs, = _io._extract_schema(kwargs, device) - # self._num_positionals = len(flatten_args) - # return flatten_args + flatten_kwargs - - # def unflatten_inputs(self, inputs: Sequence[ORTModelInputOutputType]): - # """Unflattens the inputs into args and kwargs - - # The inputs are unflattened in the order they appear in the model's forward function arguments. - - # Mainly used for PyTorch run for ONNX export. - # """ - # restored_args = _io.unflatten_data_using_schema(inputs[: self._num_positionals], self._args_input_schema) - # restored_kwargs = _io.unflatten_data_using_schema(inputs[self._num_positionals :], self._kwargs_input_schema) - - # return restored_args, restored_kwargs - - def construct_inputs( - self, - inputs: Sequence[ORTModelInputOutputType], - kwargs: Mapping[str, ORTModelInputOutputType], - constant_as_tensor: bool, - device: torch.device, - ): - """Constructs the inputs for the forward method - - The inputs are constructed in the order they appear in the model's forward function signature - """ - # print("construct_inputs>>>>>", inputs, kwargs) - for name in self._finalized_graph_info._onnx_graph_input_names_user_defined: - if name in self._data_accessor: - assert name in self._finalized_graph_info._buffer_for_ort_runs, f"{name} is not in _buffer_for_ort_runs" - data = self._data_accessor[name](inputs, kwargs) - # print("data.requires_grad: ", data.requires_grad) - if PrimitiveType.is_primitive_type(data) and constant_as_tensor: - data = PrimitiveType.get_tensor(data, device) - self._finalized_graph_info._buffer_for_ort_runs[name] = data - else: - raise wrap_exception( - ORTModuleONNXModelException, - RuntimeError(f"Input is present in ONNX graph but not provided: {name}."), - ) - - # print("name of buffers: ", self._finalized_graph_info._buffer_for_ort_runs.keys()) - # print("name of onnx graph inputs: ", self._finalized_graph_info.onnx_graph_input_names) - - return self._finalized_graph_info._buffer_for_ort_runs - - -class GraphExecutionManager(GraphExecutionInterface, StaticGraphManager): +class GraphExecutionManager(GraphExecutionInterface): def __init__( self, module: _FlattenedModule, debug_options: DebugOptions, fallback_manager: _FallbackManager, + export_mode: int, logger: logging.Logger, ): """Manages construction and execution of ONNX graphs""" @@ -384,15 +66,32 @@ def __init__( # Original and flattened (transformed) output module self._flattened_module = module - # onnx models + # Device where the model is placed. + # self._device: Optional[torch.device] = _utils.get_device_from_module(module) + + # Model export and post export processing before inference optimization && building gradient. self._onnx_models = _onnx_models.ONNXModels() + self._export_mode = export_mode + self._graph_transition_manager = GraphTransitionManager( + flatten_module=module, + export_mode=export_mode, + save=debug_options.save_onnx_models.save, + save_path=debug_options.save_onnx_models.path, + save_name_prefix=debug_options.save_onnx_models.name_prefix, + deepcopy_before_model_export=self._runtime_options.deepcopy_before_model_export, + torch_exporter_verbose_log=self._debug_options.logging.log_level <= LogLevel.INFO, + enable_zero_stage3_support=self._runtime_options.enable_zero_stage3_support, + onnx_opset_version=self._runtime_options.onnx_opset_version, + enable_custom_autograd_function=self._runtime_options.enable_custom_autograd_function, + enable_symbolic_shape_infer=self._runtime_options.run_symbolic_shape_infer, + exported_model_cache_dir=self._runtime_options.ortmodule_cache_dir, + logger=logger, + ) + self._post_export_processed_model_info = None - # Model after inference optimization or gradient building. + # Model after inference optimization && gradient building. self._graph_builder = None self._graph_info = None - # self._graph_initializer_names = set() - # self._graph_initializer_names_to_train = set() - # self._graph_initializers: List[torch.nn.parameter.Parameter] = [] # TrainingAgent or InferenceAgent self._execution_agent = None @@ -406,32 +105,11 @@ def __init__( # Tracker for ORTModule model export, session creation overhead. self.time_tracker = _logger.TimeTracker() - # Input and output infos (including schema) for exported model. - # self._input_info: Optional[_InputInfo] = None - self._module_output_schema: Optional[ORTModelInputOutputSchemaType] = None - - # Device where the model is placed. - self._device: Optional[torch.device] = _utils.get_device_from_module(module) - - # Forward function input parameters of the original module. - self._module_parameters: List[inspect.Parameter] = list( - inspect.signature(self._original_module.forward).parameters.values() - ) - - # TODO: remove after PyTorch ONNX exporter supports VAR_KEYWORD parameters. - for input_parameter in self._module_parameters: - if input_parameter.kind == inspect.Parameter.VAR_KEYWORD: - self._logger.info("The model's forward method has **kwargs parameter which has EXPERIMENTAL support!") - self.is_rocm_pytorch = bool(torch.version.hip is not None and ROCM_HOME is not None) # WIP feature to enable caching in Gradient accumulation scenario. self._gradient_accumulation_manager = GradientAccumulationManager() - # Flag to re-export the model due to attribute change on the original module. - # Re-export will be avoided if _skip_check is enabled. - self._original_model_has_changed = False - # Load ATen operator executor extension. load_aten_op_executor_cpp_extension() @@ -564,257 +242,6 @@ def _get_session_config(self): return session_options, providers, provider_options - @_logger.TrackTime(_logger.ORTModuleInitPhase.EXPORT) - @_logger.SuppressLogs(_logger.ORTModuleInitPhase.EXPORT, is_ort_filter=False) - def _export_model( - self, - flattened_module: torch.nn.Module, - pre_export_graph_info: _io._PreExportGraphInfo, - args: Sequence[ORTModelInputOutputType], - kwargs: Mapping[str, ORTModelInputOutputType], - ) -> Tuple[onnx.ModelProto, _io._ExportedGraphInfo]: - # 1. Set the self._device from the user module - # 2. Verify input schema matches the schema used on the previous model export - # 3. Export the user model under self._export_training_flag mode - # Return True if the model needs to be exported, False if no export is required. - - # Note: Model is only exported when: - # 1. Model has never been exported before. - # 2. Model input schema has changed (changes in inputs requiring gradient, shape, boolean inputs values change, etc) - # Model is not re-exported when the model parameters change. This can happen when the model is stateful, - # or the user explicitly changed model parameters after the onnx export. - - # Record random states here and restore later in case any of them gets changed during the export, - # e.g., some sympy functions in symbolic_shape_infer will change Python's random state. - random_states = _utils.get_random_states() - - # schema = _io._extract_schema({"args": copy.copy(inputs), "kwargs": copy.copy(kwargs)}, self._device) - # if ( - # self._onnx_models.exported_model - # and schema == self._input_info.schema - # and not self._original_model_has_changed - # ): - # # All required models have already been exported previously - # return False - self._set_device_from_module(args, kwargs) - - from onnxruntime.training.utils.hooks._subscriber_manager import no_increase_global_step - - with no_increase_global_step(): - exported_model = self._get_exported_model(flattened_module, pre_export_graph_info, args, kwargs) - - exported_graph_info = _io._ExportedGraphInfo() - exported_graph_info.onnx_graph_input_names = [input.name for input in exported_model.graph.input] - parameter_names = [name for name, _ in flattened_module.named_parameters()] - exported_graph_info.onnx_graph_input_names_require_grad = [ - input.name - for input in exported_model.graph.input - if input.name in parameter_names or input.name in pre_export_graph_info.onnx_graph_input_names_require_grad - ] - # if self._debug_options.save_onnx_models.save: - # self._onnx_models.save_exported_model( - # self._debug_options.save_onnx_models.path, - # self._debug_options.save_onnx_models.name_prefix, - # self._export_mode, - # ) - - if self._runtime_options.run_symbolic_shape_infer: - exported_model = SymbolicShapeInference.infer_shapes( - exported_model, auto_merge=True, guess_output_rank=True - ) - - # Restore the recorded random states - _utils.set_random_states(random_states) - - return exported_model, exported_graph_info - - def _get_exported_model( - self, - flattened_module: torch.nn.Module, - pre_export_graph_info: _io._PreExportGraphInfo, - args: Sequence[ORTModelInputOutputType], - kwargs: Mapping[str, ORTModelInputOutputType], - ) -> onnx.ModelProto: - """Exports PyTorch `self._flattened_module` to ONNX for inferencing or training, - using `*inputs` and `**kwargs` as input - - TODO: How to support dynamic axes? Dimensions are determined by samples - """ - # kwargs = {} - # inputs = flatten_args + flatten_kwargs - # VERBOSE -> FULL export verbose log + FULL torch other logs from stdout and stderr (C++ backend) - # DEVINFO -> FULL export verbose log + FULL torch other logs from stdout and stderr (C++ backend) - # INFO -> [Rank 0] FULL export verbose log + FILTERED torch other logs from stdout and stderr (C++ backend) - # WARNING/ERROR -> [Rank 0] NO export verbose log + FILTERED torch other logs from stdout and stderr (C++ backend) - # Be noted: rank 0 log only is controlled by logger configured in _logger.py - torch_exporter_verbose_log = self._debug_options.logging.log_level <= LogLevel.INFO - - # Setup dynamic axes for onnx model - # self._input_info = _io.parse_inputs_for_onnx_export(self._module_parameters, None, input_schema, inputs, kwargs) - need_deep_copy = self._runtime_options.deepcopy_before_model_export and _io.can_module_be_deep_cloned( - flattened_module, self._device - ) - if not need_deep_copy: - if self._runtime_options.deepcopy_before_model_export: - self._logger.warning( - "Since the user requested not to deep copy this model, " - "the initial weights may not be preserved and could change slightly during the forward run. " - "This could cause a minor difference between the ORTModule and the PyTorch run for the " - "first iteration. The computation will proceed as normal, but this should be noted." - ) - else: - self._logger.warning( - "Due to the limited GPU memory execution manager does not create a deep copy of this model. " - "Therefore, the initial weights might be slightly altered during the forward run. " - "This could result in a minor discrepancy between the ORTModule and the PyTorch run for the " - "first iteration. The computation will continue as usual, but this should be noted." - ) - ( - output_names, - dynamic_axes, - self._module_output_schema, - ) = _io.parse_outputs_for_onnx_export_and_extract_schema( - flattened_module, args, kwargs, self._logger, self._device, need_deep_copy - ) - # self._input_info.dynamic_axes.update(output_dynamic_axes) - # Combine the dymaic axes from inputs and outputs - dynamic_axes.update(pre_export_graph_info.onnx_graph_input_dynamic_axes_map) - - # FlattenedModule needs _InputInfo to expand user input from *args to *args + **kwargs - # self._flattened_module._input_info = self._input_info - - self._logger.info("Exporting the PyTorch model to ONNX...") - - # Leverage cached model if available - cache_dir = self._runtime_options.ortmodule_cache_dir - if cache_dir: - filename = os.path.join( - cache_dir, f"{hash_fn(str(flattened_module).encode()).hexdigest()}_{get_rank()}.onnx" - ) - if os.path.exists(cache_dir) and os.path.isfile(filename): - self._logger.warning( - f"Cached model detected! Cached model will be used to save export and initialization time." - f"If you want the model to be re-exported then DELETE {filename}." - ) - exported_model = onnx.load(filename) - return exported_model - - # Export torch.nn.Module to ONNX - f = io.BytesIO() - print("pre_export_graph_info.onnx_graph_input_names: ", pre_export_graph_info.onnx_graph_input_names) - print( - "pre_export_graph_info.onnx_graph_input_names_require_grad: ", - pre_export_graph_info.onnx_graph_input_names_require_grad, - ) - # Deepcopy inputs, since input values may change after model run. - # NOTE: Inputs may contain tensors that have attributes preventing their deepcopy (example grad_fn). - # Therefore, deepcopy only the data component of the input tensors for export. - - sample_inputs_copy, sample_kwargs_copy = _io.deepcopy_model_input(*args, **kwargs) - assert len(sample_kwargs_copy) == 0, "Currently, kwargs are not supported for ONNX export." - sample_inputs_as_tuple = sample_inputs_copy - # NOTE: Flattening the input will change the 'input schema', resulting in a re-export - # sample_inputs_as_tuple = tuple(self._input_info.flatten(sample_inputs_copy, sample_kwargs_copy, self._device)) - # Ops behaving differently under train/eval mode need to be exported with the - # correct training flag to reflect the expected behavior. - # For example, the Dropout node in a model is dropped under eval mode. - assert self._export_mode is not None, "Please use a concrete instance of ExecutionManager" - - # print("sample_inputs_as_tuple: ", [v.shape for v in sample_inputs_as_tuple]) - # print("sample_inputs_as_tuple: ", [v.dtype for v in sample_inputs_as_tuple]) - # print("pre_export_graph_info.onnx_graph_input_names: ", pre_export_graph_info.onnx_graph_input_names) - # print( - # "pre_export_graph_info.onnx_graph_input_dynamic_axes_map: ", - # pre_export_graph_info.onnx_graph_input_dynamic_axes_map, - # ) - - try: - with torch.no_grad(), stage3_export_context(self._runtime_options.enable_zero_stage3_support, self): - required_export_kwargs = { - "input_names": pre_export_graph_info.onnx_graph_input_names, # did not contains paramerter as its input yet - "output_names": output_names, - "opset_version": self._runtime_options.onnx_opset_version, - "do_constant_folding": False, - "training": self._export_mode, - "dynamic_axes": dynamic_axes, - "verbose": torch_exporter_verbose_log, - "export_params": False, - "keep_initializers_as_inputs": True, - } - - if check_function_has_param(torch.onnx.export, "autograd_inlining"): - # From some PyTorch version, autograd_inlining is a valid argument. - # We allow it to be True if custom autograd function is disabled (where autograd.Function - # anyway is not supported in ONNX until it can be inlined). - required_export_kwargs[ - "autograd_inlining" - ] = not self._runtime_options.enable_custom_autograd_function - - invalid_args = self._export_extra_kwargs.keys() & required_export_kwargs.keys() - - if len(invalid_args) != 0: - error_msg = f"The following PyTorch exporter arguments cannot be specified: '{invalid_args}'." - raise RuntimeError(error_msg) - - torch.onnx.export( - flattened_module, - sample_inputs_as_tuple, - f, - **required_export_kwargs, - **self._export_extra_kwargs, - ) - except Exception as e: - raise wrap_exception( # noqa: B904 - ORTModuleONNXModelException, - RuntimeError( - f"There was an error while exporting the PyTorch model to ONNX: " - f"\n\n{_utils.get_exception_as_string(e)}" - ), - ) - exported_model = onnx.load_model_from_string(f.getvalue()) - - if self._runtime_options.enable_custom_autograd_function: - from ._custom_autograd_function_exporter import post_process_enabling_autograd_function - - exported_model = post_process_enabling_autograd_function(exported_model) - - if self._runtime_options.enable_zero_stage3_support: - from ._zero_stage3_compatibility import post_processing_enable_zero_stage3_compat - - exported_model = post_processing_enable_zero_stage3_compat( - exported_model, - self._zero_stage3_param_map, - [name for name, _ in flattened_module.named_parameters()], - ) - - # Cannot append pull weight trigger name to input names as following, otherwise, the later check ( - # https://github.com/microsoft/onnxruntime/blob/068300d97eb25e5b52324e7af54a45ed1fa6a4c3/orttraining/orttraining/python/training/ortmodule/_training_manager.py#L466C18-L466C18) - # find input info mismatch, will re-initialize the graph builder. - # self._input_info.require_grad_names.append(STAGE3_PULL_WEIGHT_TRIGGER_NAME) - - # Cache model for future runs - if cache_dir: - if not os.path.exists(cache_dir): - os.makedirs(cache_dir, exist_ok=True) - filename = os.path.join( - cache_dir, f"{hash_fn(str(flattened_module).encode()).hexdigest()}_{get_rank()}.onnx" - ) - self._logger.info(f"Caching model for future runs to {filename}.") - onnx.save(exported_model, filename) - - return exported_model - - def _set_device_from_module(self, inputs, kwargs): - """Get the device from the module and save it to self._device""" - - device = _utils.get_device_from_module(self._original_module) or _utils.get_device_from_inputs(inputs, kwargs) - if not self._device or self._device != device: - self._device = device - if not self._device: - raise wrap_exception( - ORTModuleDeviceException, RuntimeError("A device must be specified in the model or inputs!") - ) - def _get_graph_transformer_config(self) -> C.TrainingGraphTransformerConfiguration: graph_transformer_config = C.TrainingGraphTransformerConfiguration() graph_transformer_config.propagate_cast_ops_config = C.PropagateCastOpsConfiguration() @@ -834,37 +261,25 @@ def _get_graph_transformer_config(self) -> C.TrainingGraphTransformerConfigurati return graph_transformer_config @_logger.TrackTime(_logger.ORTModuleInitPhase.GRAPH_BUILDER_INIT) - def _initialize_graph_builder(self): + def _initialize_graph_builder(self, post_export_processed_model_info: PostExportProcessedModelInfo): """Creates a new OrtModuleGraphBuilder, initializes it and saves it to self._graph_builder""" - # All initializer names along with user inputs are a part of the onnx graph inputs - # since the onnx model was exported with the flag keep_initializers_as_inputs=True - # onnx_initializer_names = {p.name for p in self._onnx_models.exported_model.graph.input} - - # onnx_initializer_names = self._finalized_graph_info.onnx_graph_input_names - - # # TODO: PyTorch exporter bug: changes the initializer order in ONNX model - # initializer_names = [ - # name for name, _ in self._flattened_module.named_parameters() if name in onnx_initializer_names - # ] - # initializer_names_to_train = [ - # name - # for name, param in self._flattened_module.named_parameters() - # if param.requires_grad and name in onnx_initializer_names - # ] - # Build and optimize the full graph grad_builder_config = C.OrtModuleGraphBuilderConfiguration() - grad_builder_config.initializer_names = self._finalized_graph_info.onnx_graph_input_names - grad_builder_config.initializer_names_to_train = self._finalized_graph_info.onnx_graph_input_names_require_grad - - # input_names_require_grad = self._input_info.require_grad_names - input_names_require_grad = self._finalized_graph_info._onnx_graph_input_names_require_grad_user_defined + grad_builder_config.initializer_names = ( + post_export_processed_model_info.onnx_graph_input_names + ) # containing both user defined and buffers/parameters. + grad_builder_config.initializer_names_to_train = ( + post_export_processed_model_info.onnx_graph_input_names_require_grad + ) # containing both user defined and parameters requiring gradients. + + input_names_require_grad = post_export_processed_model_info.onnx_graph_input_names_require_grad_user_defined if self._runtime_options.enable_zero_stage3_support: from ._zero_stage3_compatibility import STAGE3_PULL_WEIGHT_TRIGGER_NAME # Add stage3 pull weight trigger name to require_grad_names, so that it will be included in the gradient graph. input_names_require_grad.append(STAGE3_PULL_WEIGHT_TRIGGER_NAME) + grad_builder_config.input_names_require_grad = input_names_require_grad grad_builder_config.build_gradient_graph = self._export_mode == torch.onnx.TrainingMode.TRAINING grad_builder_config.enable_caching = self._runtime_options.enable_grad_acc_optimization @@ -876,22 +291,9 @@ def _initialize_graph_builder(self): # It is assumed here that the order and names of the inputs and outputs are not modified by the backend in any way # and are kept as they appear in the exported onnx model. - self._graph_builder.initialize(self._finalized_model.SerializeToString(), grad_builder_config) - - # TODO: Explore ways to make self._graph_info.initializer_names and self._graph_info.initializer_names_to_train - # a set (unordered_set in the backend) that does not require a copy on each reference. - # self._graph_initializer_names = set(initializer_names) - # self._graph_initializer_names_to_train = set(initializer_names_to_train) - - # Initializers can be cached and used since they are expected not to be re-instantiated - # between forward calls. - # self._graph_initializers = [ - # param for name, param in self._flattened_module.named_parameters() if name in self._graph_initializer_names - # ] - - def signal_model_changed(self): - """Signals the execution manager to re-export the model on the next forward call""" - self._original_model_has_changed = True + self._graph_builder.initialize( + post_export_processed_model_info._post_export_processed_model.SerializeToString(), grad_builder_config + ) def __getstate__(self): state = copy.copy(self.__dict__) @@ -915,6 +317,10 @@ def __setstate__(self, state): _utils.reinitialize_graph_execution_manager(self) + @property + def _device(self): + return self._graph_transition_manager._device + @_logger.TrackTime(_logger.ORTModuleInitPhase.DETECTION) def _enable_conditional_optimizations( self, graph_transformer_config: C.TrainingGraphTransformerConfiguration, inputs: Tuple, kwargs: Dict @@ -932,7 +338,8 @@ def _enable_conditional_optimizations( # Enable data sparsity inspection if sparse optimizer is ON or user wants to print input density. if self._runtime_options.enable_sparse_optimizer or self._runtime_options.print_input_density: self._runtime_inspector.enable_input_inspector( - self._exported_model, self._pre_export_graph_info.onnx_graph_input_names + self._graph_transition_manager._exported_model_info.exported_model, + self._graph_transition_manager._model_info_for_export.onnx_graph_input_names, ) if self._runtime_options.enable_sparse_optimizer: @@ -943,7 +350,9 @@ def _enable_conditional_optimizations( if self._runtime_options.enable_zero_stage3_support: self._append_pull_weight_trigger_as_input(kwargs, detected_device) - prepared_input_map = self.construct_inputs(inputs, kwargs, True, self._device) + prepared_input_map = self._graph_transition_manager._post_export_processed_model_info.construct_inputs( + inputs, kwargs, True + ) embed_sparsity_results = OrderedDict() label_sparsity_results = OrderedDict() @@ -960,7 +369,8 @@ def _enable_conditional_optimizations( and not self._runtime_inspector.memory_ob.symbolic_dim_collecting_completed ): self._runtime_inspector.memory_ob.collect_symbolic_dim_values( - self._finalized_graph_info.onnx_graph_input_dynamic_axes_map, prepared_input_map + self._graph_transition_manager._post_export_processed_model_info.onnx_graph_input_dynamic_axes_map, + prepared_input_map, ) self._runtime_inspector.memory_ob.symbolic_dim_collecting_completed = True diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py new file mode 100755 index 0000000000000..eeb63ac4c20e0 --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py @@ -0,0 +1,756 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +from __future__ import annotations + +import copy +import inspect +import io +import logging +import os +from collections import OrderedDict +from hashlib import md5 as hash_fn +from typing import Mapping, Sequence + +import onnx +import torch + +from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference +from onnxruntime.training.utils import ( + ORTModelInputOutputSchemaType, + ORTModelInputOutputType, + PrimitiveType, + unflatten_data_using_schema, +) +from onnxruntime.training.utils.torch_io_helper import _TensorStub + +from . import _io, _utils +from ._fallback import ORTModuleDeviceException, ORTModuleIOError, ORTModuleONNXModelException, wrap_exception +from ._utils import check_function_has_param, get_rank +from ._zero_stage3_compatibility import stage3_export_context + + +def _get_onnx_file_name(name_prefix, name, export_mode): + suffix = "training" if export_mode == torch.onnx.TrainingMode.TRAINING else "inference" + return f"{name_prefix}_{name}_{suffix}.onnx" + + +def _save_model(model: onnx.ModelProto, file_path: str): + onnx.save(model, file_path) + + +class ExportedModelInfo: + def __init__( + self, + module_forward_args_schema: ORTModelInputOutputSchemaType, + module_forward_kwargs_schema: ORTModelInputOutputSchemaType, + onnx_graph_input_names: list[str], + onnx_graph_input_names_require_grad: list[str], + exported_model: onnx.ModelProto, + module_forward_output_schema: ORTModelInputOutputSchemaType, + ): + # Input names parsed and then flatten from the model's forward function signature + buffers + parameters (since we use + # keep_initializers_as_inputs=True for model export) + # Be noted: all inputs are used by the model for its compute. + self.onnx_graph_input_names: list[str] = onnx_graph_input_names + + # A subset of onnx_graph_input_names. + # Input names that require gradient parsed and then flatten from the model's forward function signature + # This should contains both the user input names, the buffer names, and the parameter names (since we use + # keep_initializers_as_inputs=True for model export) + # Be noted: all inputs are used by the model for its compute. + self.onnx_graph_input_names_require_grad: list[str] = onnx_graph_input_names_require_grad + + # Exported model proto. + self.exported_model: onnx.ModelProto | None = exported_model + + self.module_forward_args_schema: ORTModelInputOutputSchemaType | None = module_forward_args_schema + self.module_forward_kwargs_schema: ORTModelInputOutputSchemaType | None = module_forward_kwargs_schema + + self.module_forward_output_schema: ORTModelInputOutputSchemaType | None = module_forward_output_schema + + def __str__(self): + return f"""ExportedModelInfo class: + \tonnx_graph_input_names: {self.onnx_graph_input_names} + \tonnx_graph_input_names_require_grad: {self.onnx_graph_input_names_require_grad} + \tmodule_forward_args_schema: {self.module_forward_args_schema} + \tmodule_forward_kwargs_schema: {self.module_forward_kwargs_schema} + \tmodule_forward_output_schema: {self.module_forward_output_schema} + """ + + def __repro__(self): + return self.__str__() + + +class PostExportProcessedModelInfo: + def __init__( + self, + flatten_module: torch.nn.Module, + device: torch.device | None, + onnx_graph_input_names_user_defined: list[str], + onnx_graph_input_names_require_grad_user_defined: list[str], + onnx_graph_input_names: list[str], + onnx_graph_input_names_require_grad: list[str], + onnx_graph_input_dynamic_axes_map: dict[str, dict[int, str]], + module_forward_output_schema: ORTModelInputOutputSchemaType, + _post_export_processed_model: onnx.ModelProto, + data_accessor: list[callable], + ): + self.device = device + + self._flattened_module = flatten_module + + # Input names for the pre-gradient-build graph. + # This may be different with the one in ExportedGraph since we may modify the graph inputs as needed + # for example when memory efficient gradient management is enabled. + self.onnx_graph_input_names: list[str] = onnx_graph_input_names + + # A subset of onnx_graph_input_names. + # Input names that require gradients for the pre-gradient-build graph. + self.onnx_graph_input_names_require_grad: list[str] = onnx_graph_input_names_require_grad + + # Create symbolic names for each dimension of the graph input (e.g. onnx_graph_input_names). + # The key is the input name, the value is a dict of {dim_index: symbolic_dim_name} + # e.g. {"input1": {0: "input1_dim0", 1: "input1_dim1"}, "input2": {0: "input2_dim0"}} + self.onnx_graph_input_dynamic_axes_map: dict[str, dict[int, str]] = onnx_graph_input_dynamic_axes_map + + self.buffer_for_ort_runs: dict[str, torch.Tensor] = OrderedDict() + self.onnx_graph_input_names_user_defined = ( + onnx_graph_input_names_user_defined # The ONNX graph input names excluding the parameters, buffers. + ) + + # The ONNX graph input names excluding the parameters, buffers. + self.onnx_graph_input_names_require_grad_user_defined = onnx_graph_input_names_require_grad_user_defined + + self._post_export_processed_model: onnx.ModelProto | None = _post_export_processed_model + + # A function to access the input data from the args and kwargs. + # If it is not None, the length is same as onnx_graph_input_names. + # For i-th input name, we can use the i-th function to get the input data from args and kwargs. + self.data_accessor: list[callable] | None = data_accessor + + self.module_forward_output_schema: ORTModelInputOutputSchemaType | None = module_forward_output_schema + + # Create the buffers for the inputs that are either parameters or buffers in the original module. + # For user inputs, fill with None for now, and will be filled dynamically during the forward run. + parameter_names = {k: v for k, v in self._flattened_module.named_parameters()} + buffer_names = {k: v for k, v in self._flattened_module.named_buffers()} + for input_name in self.onnx_graph_input_names: + if input_name in parameter_names: + self.buffer_for_ort_runs[input_name] = parameter_names[input_name] + elif input_name in buffer_names: + self.buffer_for_ort_runs[input_name] = buffer_names[input_name] + else: + self.buffer_for_ort_runs[input_name] = None + + def __str__(self): + return f"""PostExportProcessedModelInfo class: + \tdevice: {self.device} + \tonnx_graph_input_names: {self.onnx_graph_input_names} + \tonnx_graph_input_names_require_grad: {self.onnx_graph_input_names_require_grad} + \tonnx_graph_input_dynamic_axes_map: {self.onnx_graph_input_dynamic_axes_map} + \tonnx_graph_input_names_user_defined: {self.onnx_graph_input_names_user_defined} + \tonnx_graph_input_names_require_grad_user_defined: {self.onnx_graph_input_names_require_grad_user_defined} + \tbuffer_for_ort_runs.keys(): {self.buffer_for_ort_runs.keys()} + """ + + def __repro__(self): + return self.__str__() + + def construct_inputs( + self, + args: Sequence[ORTModelInputOutputType], + kwargs: Mapping[str, ORTModelInputOutputType], + constant_as_tensor: bool, + ): + """Constructs the inputs for the forward method + + The inputs are constructed in the order they appear in the model's forward function signature + """ + # print("construct_inputs>>>>>", inputs, kwargs) + for name in self.onnx_graph_input_names_user_defined: + if name in self.data_accessor: + assert name in self.buffer_for_ort_runs, f"{name} is not in buffer_for_ort_runs" + data = self.data_accessor[name](args, kwargs) + # print("data.requires_grad: ", data.requires_grad) + if PrimitiveType.is_primitive_type(data) and constant_as_tensor: + data = PrimitiveType.get_tensor(data, self.device) + self.buffer_for_ort_runs[name] = data + else: + raise wrap_exception( + ORTModuleONNXModelException, + RuntimeError(f"Input is present in ONNX graph but not provided: {name}."), + ) + + return self.buffer_for_ort_runs + + def restore_outputs(self, ort_flatten_outputs: list[torch.Tensor]): + """Restores the outputs from the ORT forward run, back to the original data structure""" + try: + # Need to distinguish between a single output and a tuple (having a single tensor) + if len(ort_flatten_outputs) == 1 and self.module_forward_output_schema is _TensorStub: + return ort_flatten_outputs[0] + return unflatten_data_using_schema(ort_flatten_outputs, self.module_forward_output_schema) + except TypeError as e: + raise wrap_exception( + ORTModuleIOError, + TypeError(f"ORTModule fails to unflatten user output: {e}"), + ) from None + + +class GraphTransitionManager: + """Manage the graph transition from 1). PyTorch to ONNX export and 2). ONNX to ONNX post-export processing.""" + + def __init__( + self, + flatten_module: torch.nn.Module, + export_mode: int, + save: bool, + save_path: str, + save_name_prefix: str, + deepcopy_before_model_export: bool, + torch_exporter_verbose_log: bool, + enable_zero_stage3_support: bool, + onnx_opset_version: int, + enable_custom_autograd_function: bool, + enable_symbolic_shape_infer: bool, + exported_model_cache_dir: str, + logger: logging.Logger, + ): + self._device = _utils.get_device_from_module(flatten_module) + self._export_mode = export_mode + + # Debug options + self._save = save + self._save_path = save_path + self._save_name_prefix = save_name_prefix + + # Runtime options + self._deepcopy_before_model_export = deepcopy_before_model_export + self._torch_exporter_verbose_log = torch_exporter_verbose_log + self._enable_zero_stage3_support = enable_zero_stage3_support + self._onnx_opset_version = onnx_opset_version + self._enable_custom_autograd_function = enable_custom_autograd_function + self._run_symbolic_shape_infer = enable_symbolic_shape_infer + self._ortmodule_cache_dir = exported_model_cache_dir + + self._logger = logger + + # A signal to indicate if the original model has changed and need a re-export. + self._original_model_has_changed = False + + self._flatten_module = flatten_module + + # Forward function input parameters of the original module. + self._module_forward_func_parameters: list[inspect.Parameter] = list( + inspect.signature(self._flatten_module._original_module.forward).parameters.values() + ) + # TODO: remove after PyTorch ONNX exporter supports VAR_KEYWORD parameters. + for input_parameter in self._module_forward_func_parameters: + if input_parameter.kind == inspect.Parameter.VAR_KEYWORD: + logger.info("The model's forward method has **kwargs parameter which has EXPERIMENTAL support!") + + # Model info collected from the original module's forward function signature and args/kwargs, used for ONNX export. + self._model_info_for_export: _io.ModelInfoForExport | None = None + self._exported_model_info: ExportedModelInfo | None = None + + # Model info after export and post export processing. + self._post_export_processed_model_info = None + + @staticmethod + def _export_check( + prev_model_info_for_export: _io.ModelInfoForExport, + prev_exported_model_info: ExportedModelInfo, + args: Sequence[ORTModelInputOutputType], + kwargs: Mapping[str, ORTModelInputOutputType], + device: torch.device | None, + original_model_has_changed: bool, + module_forward_func_parameters: list[inspect.Parameter], + export_mode: int, + logger: logging.Logger, + ): + """Check if the model needs to be exported, if yes, return True. + + For the following cases, return True: + 1. The model has never been exported before. + 2. The model's input names parsed from args and kwargs changed. + 3. The model input schema parsed from args and kwargs has changed. + """ + + need_export_model = prev_exported_model_info is None # never exported before + + need_export_model = need_export_model or original_model_has_changed + + # Check graph inputs parsed from the model's forward function signature and current inputs, + # if they are different, we need to re-export the model. + model_info_for_export = _io.parse_inputs_for_onnx_export( + module_forward_func_parameters, args, kwargs, True, device, export_mode + ) + + need_export_model = ( + need_export_model + or prev_model_info_for_export.onnx_graph_input_names != model_info_for_export.onnx_graph_input_names + ) + + # Extract the schema from the args and kwargs, and compare with the cached one. + # This check ideally is not needed as we already have the above check, but it is added as a safeguard. + flatten_args, args_schema = _io._extract_schema(copy.copy(args), device) + flatten_kwargs, kwargs_schema = _io._extract_schema(copy.copy(kwargs), device) + + need_export_model = ( + need_export_model + or args_schema != prev_exported_model_info.module_forward_args_schema + or kwargs_schema != prev_exported_model_info.module_forward_kwargs_schema + ) + + logger.info( + f"_export_check completed - need_export_model: {need_export_model}, model_info_for_export: {model_info_for_export}" + ) + + return need_export_model, model_info_for_export, flatten_args, flatten_kwargs, args_schema, kwargs_schema + + @staticmethod + def _reprocess_check( + flatten_module, prev_exported_model_info, export_mode, cur_model_info_for_export: _io.ModelInfoForExport + ) -> bool: + """Check if the exported model needs to be re-processed, if yes, return True and the updated onnx_graph_input_requires_grads. + + For the following cases, return True: + 1. The export mode is TRAINING and the model's input names (including both user input and module parameters) + requiring gradient change. + """ + if export_mode == torch.onnx.TrainingMode.TRAINING: + # If inputs requiring gradient change from forward to the next, the module_gradient_graph_builder + # needs to be reinitialized so it can compute the backward output for the new inputs that require_grad + + # Reinitialize graph builder if the inputs or initializers requiring gradient have changed. + # This can happen when the user changes the model parameters after the onnx export. + # Model may have unused params dropped after export, so we only check those inputs existing in onnx graph. + + onnx_graph_input_requires_grads = [] + parameter_names = {k: v for k, v in flatten_module.named_parameters()} + for input_name in prev_exported_model_info.onnx_graph_input_names: + if input_name in parameter_names and parameter_names[input_name].requires_grad: + onnx_graph_input_requires_grads.append(input_name) + else: + # If not in parameter list, then it would come from user defined inputs. + if input_name in cur_model_info_for_export.onnx_graph_input_names_require_grad: + onnx_graph_input_requires_grads.append(input_name) + + if onnx_graph_input_requires_grads == prev_exported_model_info.onnx_graph_input_names_require_grad: + return False, [] + return True, onnx_graph_input_requires_grads + + return False, [] + + @staticmethod + def _post_export_process( + flatten_module, + device, + exported_model_info: ExportedModelInfo, + model_info_for_export: _io.ModelInfoForExport, + logger: logging.Logger, + ): + """Post process the exported model, generate the processed model which will be used for initializing graph builder.""" + + # Deepcopy the exported model, in case modification affects the exported model. + + # TODO(): Do pre-grad graph modification as needed, for memory efficient gradient management, etc. + # Currently, we don't do any modifications. + + post_processed_model = copy.deepcopy(exported_model_info.exported_model) + + # Get the intersection of all user defined input names (parsed from forward function signature) and + # the exported model input names including both user defined names and parameter/buffer names. + # It is possible some user defined input names are not in the exported model input names, if it is not used + # by the model for its compute. + onnx_graph_input_names_user_defined = [ + input_name + for input_name in model_info_for_export.onnx_graph_input_names + if input_name in exported_model_info.onnx_graph_input_names + ] + onnx_graph_input_names_require_grad = [ + input_name + for input_name in model_info_for_export.onnx_graph_input_names_require_grad + if input_name in exported_model_info.onnx_graph_input_names_require_grad + ] + + post_export_processed_model_info = PostExportProcessedModelInfo( + flatten_module, + device, + onnx_graph_input_names_user_defined, + onnx_graph_input_names_require_grad, + exported_model_info.onnx_graph_input_names, + exported_model_info.onnx_graph_input_names_require_grad, + model_info_for_export.onnx_graph_input_dynamic_axes_map, + exported_model_info.module_forward_output_schema, + post_processed_model, + model_info_for_export.data_accessor, + ) + + logger.info( + f"_post_export_process completed, post-export processed graph infos: {post_export_processed_model_info}" + ) + + return post_export_processed_model_info + + def use_cache_or_reconstruct( + self, args: Sequence[ORTModelInputOutputType], kwargs: Mapping[str, ORTModelInputOutputType] + ) -> tuple[bool, PostExportProcessedModelInfo]: + """Check if the model can be reused, otherwise, reconstruct the model. + + Return True if the model can be reused, otherwise, return False. + The model can be reused when the following conditions are met: + a. The model has been exported before, and the inputs (args/outputs) schemas are the same as the previous ones. + b. The graph inputs requiring gradient are the same as the previous ones. + + """ + ( + need_export_model, + cur_model_info_for_export, + flatten_args, + flatten_kwargs, + cur_args_schema, + cur_kwargs_schema, + ) = GraphTransitionManager._export_check( + prev_exported_model_info=self._model_info_for_export, + prev_model_info_for_export=self._exported_model_info, + args=args, + kwargs=kwargs, + device=self._device, + original_model_has_changed=self._original_model_has_changed, + module_forward_func_parameters=self._module_forward_func_parameters, + export_mode=self._export_mode, + logger=self._logger, + ) + + if need_export_model: + # Set the information used to unflatten the inputs during the flatten module forward run. + # Must be set before calling exporting the model. + self._flatten_module._device = self._device + self._flatten_module._args_schema = cur_args_schema + self._flatten_module._kwargs_schema = cur_kwargs_schema + self._flatten_module._num_positionals = len(flatten_args) + + flatten_inputs = flatten_args + flatten_kwargs + self._set_device_from_module(flatten_inputs, {}) + + # Start exporting the model by passing the 1-D flatten tensor list containing all args plus kwargs. + ( + exported_model, + module_output_schema, + onnx_graph_input_names, + onnx_graph_input_names_require_grad, + ) = GraphTransitionManager._export_model( + flattened_module=self._flatten_module, + model_info_for_export=cur_model_info_for_export, + flatten_module_inputs=flatten_inputs, + run_symbolic_shape_infer=self._run_symbolic_shape_infer, + deepcopy_before_model_export=self._deepcopy_before_model_export, + device=self._device, + ortmodule_cache_dir=self._ortmodule_cache_dir, + enable_custom_autograd_function=self._enable_custom_autograd_function, + enable_zero_stage3_support=self._enable_zero_stage3_support, + onnx_opset_version=self._onnx_opset_version, + torch_exporter_verbose_log=self._torch_exporter_verbose_log, + stage3_param_handle=self, + logger=self._logger, + ) + + self._exported_model_info = ExportedModelInfo( + module_forward_args_schema=cur_args_schema, + module_forward_kwargs_schema=cur_kwargs_schema, + onnx_graph_input_names=onnx_graph_input_names, + onnx_graph_input_names_require_grad=onnx_graph_input_names_require_grad, + exported_model=exported_model, + module_forward_output_schema=module_output_schema, + ) + + self._model_info_for_export = cur_model_info_for_export + + # Reset the signal to indicate the original model has changed. + self._original_model_has_changed = False + + # Save the exported model + if self._save: + _save_model( + self._exported_model_info.exported_model, + os.path.join( + self._save_path, + _get_onnx_file_name(self._save_name_prefix, "torch_exported", self._export_mode), + ), + ) + + self._logger.info(f"do_export completed, exported graph infos: {self._exported_model_info}") + + need_re_processed = False + if need_export_model: + need_re_processed = True + else: + need_re_processed, updated_onnx_graph_input_requires_grads = GraphTransitionManager._reprocess_check( + flatten_module=self._flatten_module, + prev_exported_model_info=self._exported_model_info, + export_mode=self._export_mode, + cur_model_info_for_export=cur_model_info_for_export, + ) + if need_re_processed: + # Update the onnx_graph_input_names_require_grads to make it a new default baseline to compare using new iteration data. + self._exported_model_info.onnx_graph_input_names_require_grad = updated_onnx_graph_input_requires_grads + + if need_re_processed: + # At this point, the exported model is ready, and we can start post export processing. + self._post_export_processed_model_info = GraphTransitionManager._post_export_process( + flatten_module=self._flatten_module, + device=self._device, + exported_model_info=self._exported_model_info, + model_info_for_export=self._model_info_for_export, + logger=self._logger, + ) + + # Save the post_processed model + if self._save: + _save_model( + self._post_export_processed_model_info._post_export_processed_model, + os.path.join( + self._save_path, + _get_onnx_file_name(self._save_name_prefix, "post_processed", self._export_mode), + ), + ) + + return need_re_processed, self._post_export_processed_model_info + + # @_logger.TrackTime(_logger.ORTModuleInitPhase.EXPORT) + # @_logger.SuppressLogs(_logger.ORTModuleInitPhase.EXPORT, is_ort_filter=False) + @staticmethod + def _export_model( + flattened_module: torch.nn.Module, + model_info_for_export: _io.ModelInfoForExport, + flatten_module_inputs: Sequence[ORTModelInputOutputType], + run_symbolic_shape_infer: bool, + deepcopy_before_model_export: bool, + device: torch.device, + ortmodule_cache_dir: str, + enable_custom_autograd_function: bool, + enable_zero_stage3_support: bool, + onnx_opset_version: int, + torch_exporter_verbose_log: bool, + stage3_param_handle: type, + logger: logging.Logger, + ) -> tuple[onnx.ModelProto, ORTModelInputOutputSchemaType, list[str], list[str]]: + # Record random states here and restore later in case any of them gets changed during the export, + # e.g., some sympy functions in symbolic_shape_infer will change Python's random state. + random_states = _utils.get_random_states() + + from onnxruntime.training.utils.hooks._subscriber_manager import no_increase_global_step + + with no_increase_global_step(): + exported_model, module_output_schema = GraphTransitionManager._get_exported_model( + flattened_module=flattened_module, + model_info_for_export=model_info_for_export, + flatten_module_inputs=flatten_module_inputs, + deepcopy_before_model_export=deepcopy_before_model_export, + device=device, + ortmodule_cache_dir=ortmodule_cache_dir, + enable_custom_autograd_function=enable_custom_autograd_function, + enable_zero_stage3_support=enable_zero_stage3_support, + onnx_opset_version=onnx_opset_version, + torch_exporter_verbose_log=torch_exporter_verbose_log, + stage3_param_handle=stage3_param_handle, + logger=logger, + ) + + onnx_graph_input_names = [input.name for input in exported_model.graph.input] + parameter_names = [name for name, _ in flattened_module.named_parameters()] + onnx_graph_input_names_require_grad = [ + input.name + for input in exported_model.graph.input + if input.name in parameter_names or input.name in model_info_for_export.onnx_graph_input_names_require_grad + ] + + if run_symbolic_shape_infer: + exported_model = SymbolicShapeInference.infer_shapes( + exported_model, auto_merge=True, guess_output_rank=True + ) + + # Restore the recorded random states + _utils.set_random_states(random_states) + + return exported_model, module_output_schema, onnx_graph_input_names, onnx_graph_input_names_require_grad + + @staticmethod + def _get_exported_model( + flattened_module: torch.nn.Module, + model_info_for_export: _io.ModelInfoForExport, + flatten_module_inputs: Sequence[ORTModelInputOutputType], + deepcopy_before_model_export: bool, + device: torch.device, + ortmodule_cache_dir: str, + enable_custom_autograd_function: bool, + enable_zero_stage3_support: bool, + onnx_opset_version: int, + torch_exporter_verbose_log: bool, + stage3_param_handle: type, + logger: logging.Logger, + ) -> tuple[onnx.ModelProto, ORTModelInputOutputSchemaType]: + """Exports PyTorch `flattened_module` to ONNX for inferencing or training.""" + + need_deep_copy = deepcopy_before_model_export and _io.can_module_be_deep_cloned(flattened_module, device) + if not need_deep_copy: + if deepcopy_before_model_export: + logger.warning( + "Since the user requested not to deep copy this model, " + "the initial weights may not be preserved and could change slightly during the forward run. " + "This could cause a minor difference between the ORTModule and the PyTorch run for the " + "first iteration. The computation will proceed as normal, but this should be noted." + ) + else: + logger.warning( + "Due to the limited GPU memory execution manager does not create a deep copy of this model. " + "Therefore, the initial weights might be slightly altered during the forward run. " + "This could result in a minor discrepancy between the ORTModule and the PyTorch run for the " + "first iteration. The computation will continue as usual, but this should be noted." + ) + ( + output_names, + dynamic_axes, + module_output_schema, + ) = _io.parse_outputs_for_onnx_export_and_extract_schema( + flattened_module, flatten_module_inputs, logger, need_deep_copy + ) + + # Combine the dynamic axes from inputs and outputs + dynamic_axes.update(model_info_for_export.onnx_graph_input_dynamic_axes_map) + + logger.info("Exporting the PyTorch model to ONNX...") + + # Leverage cached model if available + cache_dir = ortmodule_cache_dir + if cache_dir: + filename = os.path.join( + cache_dir, f"{hash_fn(str(flattened_module).encode()).hexdigest()}_{get_rank()}.onnx" + ) + if os.path.exists(cache_dir) and os.path.isfile(filename): + logger.warning( + f"Cached model detected! Cached model will be used to save export and initialization time." + f"If you want the model to be re-exported then DELETE {filename}." + ) + exported_model = onnx.load(filename) + return exported_model, module_output_schema + + # Export torch.nn.Module to ONNX + f = io.BytesIO() + print("pre_export_graph_info.onnx_graph_input_names: ", model_info_for_export.onnx_graph_input_names) + print( + "pre_export_graph_info.onnx_graph_input_names_require_grad: ", + model_info_for_export.onnx_graph_input_names_require_grad, + ) + + # Deepcopy inputs, since input values may change after model run. + # NOTE: Inputs may contain tensors that have attributes preventing their deepcopy (example grad_fn). + # Therefore, deepcopy only the data component of the input tensors for export. + kwargs = {} + sample_inputs_copy, sample_kwargs_copy = _io.deepcopy_model_input(*flatten_module_inputs, **kwargs) + assert len(sample_kwargs_copy) == 0, "Currently, kwargs are not supported for ONNX export." + sample_inputs_as_tuple = sample_inputs_copy + + print(">>>shapes for the flatten args and kwargs", [v.shape for v in sample_inputs_as_tuple]) + + # Ops behaving differently under train/eval mode need to be exported with the + # correct training flag to reflect the expected behavior. + # For example, the Dropout node in a model is dropped under eval mode. + assert model_info_for_export.export_mode is not None, "Please use a concrete instance of ExecutionManager" + + try: + with torch.no_grad(), stage3_export_context( + enable_zero_stage3_support, stage3_param_handle, flattened_module + ): + required_export_kwargs = { + "input_names": model_info_for_export.onnx_graph_input_names, # did not contains paramerter as its input yet + "output_names": output_names, + "opset_version": onnx_opset_version, + "do_constant_folding": False, + "training": model_info_for_export.export_mode, + "dynamic_axes": dynamic_axes, + "verbose": torch_exporter_verbose_log, + "export_params": False, + "keep_initializers_as_inputs": True, + } + + if check_function_has_param(torch.onnx.export, "autograd_inlining"): + # From some PyTorch version, autograd_inlining is a valid argument. + # We allow it to be True if custom autograd function is disabled (where autograd.Function + # anyway is not supported in ONNX until it can be inlined). + required_export_kwargs["autograd_inlining"] = not enable_custom_autograd_function + + invalid_args = model_info_for_export.export_extra_kwargs.keys() & required_export_kwargs.keys() + + if len(invalid_args) != 0: + error_msg = f"The following PyTorch exporter arguments cannot be specified: '{invalid_args}'." + raise RuntimeError(error_msg) + + torch.onnx.export( + flattened_module, + sample_inputs_as_tuple, + f, + **required_export_kwargs, + **model_info_for_export.export_extra_kwargs, + ) + except Exception as e: + raise wrap_exception( # noqa: B904 + ORTModuleONNXModelException, + RuntimeError( + f"There was an error while exporting the PyTorch model to ONNX: " + f"\n\n{_utils.get_exception_as_string(e)}" + ), + ) + exported_model = onnx.load_model_from_string(f.getvalue()) + + if enable_custom_autograd_function: + from ._custom_autograd_function_exporter import post_process_enabling_autograd_function + + exported_model = post_process_enabling_autograd_function(exported_model) + + if enable_zero_stage3_support: + from ._zero_stage3_compatibility import post_processing_enable_zero_stage3_compat + + exported_model = post_processing_enable_zero_stage3_compat( + exported_model, + stage3_param_handle._zero_stage3_param_map, + [name for name, _ in flattened_module.named_parameters()], + ) + + # Cannot append pull weight trigger name to input names as following, otherwise, the later check ( + # https://github.com/microsoft/onnxruntime/blob/068300d97eb25e5b52324e7af54a45ed1fa6a4c3/orttraining/orttraining/python/training/ortmodule/_training_manager.py#L466C18-L466C18) + # find input info mismatch, will re-initialize the graph builder. + # self._input_info.require_grad_names.append(STAGE3_PULL_WEIGHT_TRIGGER_NAME) + + # Cache model for future runs + if cache_dir: + if not os.path.exists(cache_dir): + os.makedirs(cache_dir, exist_ok=True) + filename = os.path.join( + cache_dir, f"{hash_fn(str(flattened_module).encode()).hexdigest()}_{get_rank()}.onnx" + ) + logger.info(f"Caching model for future runs to {filename}.") + onnx.save(exported_model, filename) + + return exported_model, module_output_schema + + def _set_device_from_module(self, inputs, kwargs): + """Get the device from the module and save it to self._device""" + + device = _utils.get_device_from_module(self._flatten_module._original_module) or _utils.get_device_from_inputs( + inputs, kwargs + ) + if not self._device or self._device != device: + self._device = device + if not self._device: + raise wrap_exception( + ORTModuleDeviceException, RuntimeError("A device must be specified in the model or inputs!") + ) + + def signal_model_changed(self): + """Signals the execution manager to re-export the model on the next forward call""" + self._original_model_has_changed = True diff --git a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py index 93481f5ecdb03..bd6a4410323b7 100644 --- a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py @@ -15,7 +15,6 @@ from ._execution_agent import InferenceAgent from ._fallback import ORTModuleFallbackException, _FallbackManager, _FallbackPolicy from ._graph_execution_manager import GraphExecutionManager, _RunStateInfo -from ._io import unflatten_user_output from ._logger import ORTModuleInitPhase, TrackTime from ._utils import save_tuning_results, set_tuning_results from .options import DebugOptions, _SkipCheck @@ -28,8 +27,7 @@ class InferenceManager(GraphExecutionManager): """ def __init__(self, model, debug_options: DebugOptions, fallback_manager: _FallbackManager, logger: Logger): - super().__init__(model, debug_options, fallback_manager, logger) - self._export_mode = torch.onnx.TrainingMode.EVAL + super().__init__(model, debug_options, fallback_manager, torch.onnx.TrainingMode.EVAL, logger) @staticmethod def execution_session_run_forward( @@ -110,20 +108,18 @@ def forward(self, *inputs, **kwargs): build_graph = False if ( self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT) is False - or not self._exported_model + or not self._graph_transition_manager._exported_model_info ): self.time_tracker.start(ORTModuleInitPhase.EndToEnd) # Exporting module to ONNX for the first time - build_graph, pre_grad_graph_info, exported_graph_info = self.use_cached_exported_model_or_reexport( - inputs, kwargs, self._device - ) - build_graph = self.use_cached_pre_grad_model_or_reinitialize(build_graph, pre_grad_graph_info) - # build_graph = self._export_model(*inputs, **kwargs) - # if build_graph: - # # If model was exported, then initialize the graph builder. - # self._initialize_graph_builder() + build_graph, post_export_processed_model_info = self._graph_transition_manager.use_cache_or_reconstruct( + inputs, kwargs + ) + if build_graph: + # TODO(): do we need call it for inferencing mode??? + self._initialize_graph_builder(post_export_processed_model_info) # Build the inference graph if build_graph: @@ -145,13 +141,13 @@ def forward(self, *inputs, **kwargs): create_execution_session = ( build_graph - or self._device != module_device + or self._graph_transition_manager._device != module_device or torch.are_deterministic_algorithms_enabled() is not _are_deterministic_algorithms_enabled() ) _use_deterministic_algorithms(torch.are_deterministic_algorithms_enabled()) - if self._device != module_device: - self._device = module_device + if self._graph_transition_manager._device != module_device: + self._graph_transition_manager._device = module_device if create_execution_session: # Create execution session creates the inference_session @@ -162,26 +158,29 @@ def forward(self, *inputs, **kwargs): if self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False: # Assert that the input and model device match - _utils._check_same_device(self._device, "Input argument to forward", *inputs) + _utils._check_same_device(self._graph_transition_manager._device, "Input argument to forward", *inputs) if self._runtime_options.enable_zero_stage3_support: - self._append_pull_weight_trigger_as_input(kwargs, self._device) + self._append_pull_weight_trigger_as_input(kwargs, self._graph_transition_manager._device) - prepared_input_map = self.construct_inputs(inputs, kwargs, True, self._device) + prepared_input_map = self._graph_transition_manager._post_export_processed_model_info.construct_inputs( + inputs, kwargs, True + ) if ( self._runtime_inspector.memory_ob.is_enabled() and not self._runtime_inspector.memory_ob.symbolic_dim_collecting_completed ): self._runtime_inspector.memory_ob.collect_symbolic_dim_values( - self._finalized_graph_info.onnx_graph_input_dynamic_axes_map, prepared_input_map + self._graph_transition_manager._post_export_processed_model_info.onnx_graph_input_dynamic_axes_map, + prepared_input_map, ) self._runtime_inspector.memory_ob.symbolic_dim_collecting_completed = True user_outputs, _ = InferenceManager.execution_session_run_forward( self._execution_agent, self._onnx_models.optimized_model, - self._device, + self._graph_transition_manager._device, *prepared_input_map.values(), ) @@ -196,8 +195,8 @@ def forward(self, *inputs, **kwargs): # print("user_outputs: ", user_outputs) # print("self._module_output_schema: ", self._module_output_schema) - - return unflatten_user_output(self._module_output_schema, user_outputs) + return self._graph_transition_manager._post_export_processed_model_info.restore_outputs(user_outputs) + # return unflatten_user_output(self._module_output_schema, user_outputs) except ORTModuleFallbackException as e: # Exceptions subject to fallback are handled here self._fallback_manager.handle_exception(exception=e, log_level=self._debug_options.logging.log_level) diff --git a/orttraining/orttraining/python/training/ortmodule/_io.py b/orttraining/orttraining/python/training/ortmodule/_io.py index 68ae87ddf7526..735d92a723e57 100644 --- a/orttraining/orttraining/python/training/ortmodule/_io.py +++ b/orttraining/orttraining/python/training/ortmodule/_io.py @@ -77,69 +77,6 @@ def symbolic(g, self): return g.op("Identity", self) -class _PreExportGraphInfo: - def __init__(self): - # Input names parsed and then flatten from the model's forward function signature - self.onnx_graph_input_names: List[str] = [] - - # A subset of onnx_graph_input_names. - # Input names that require gradient parsed and then flatten from the model's forward function signature - # This should contains ONLY the user input names - self.onnx_graph_input_names_require_grad: List[str] = [] - - # Create symbolic names for each dimension of the graph input (e.g. onnx_graph_input_names). - # The key is the input name, the value is a dict of {dim_index: symbolic_dim_name} - # e.g. {"input1": {0: "input1_dim0", 1: "input1_dim1"}, "input2": {0: "input2_dim0"}} - self.onnx_graph_input_dynamic_axes_map: Dict[str, Dict[int, str]] = {} - - self.onnx_graph_input_shapes: List[List[int]] = [] - - -class _ExportedGraphInfo: - def __init__(self): - # Input names parsed and then flatten from the model's forward function signature + buffers + parameters (since we use - # keep_initializers_as_inputs=True for model export) - self.onnx_graph_input_names: List[str] = [] - - # A subset of onnx_graph_input_names. - # Input names that require gradient parsed and then flatten from the model's forward function signature - # This should contains both the user input names, the buffer names, and the parameter names (since we use - # keep_initializers_as_inputs=True for model export) - self.onnx_graph_input_names_require_grad: List[str] = [] - - def need_reexport(self, __value: object) -> bool: - assert isinstance( - __value, _ExportedGraphInfo - ), f"__value must be an instance of _ExportedGraphInfo, but got {type(__value)}" - - return self.onnx_graph_input_names != __value.onnx_graph_input_names - - -class _FinalizedGraphInfo: - def __init__(self): - # Input names for the pre-gradient-build graph. - # This may be different with the one in ExportedGraph since we may modify the graph inputs as needed - # for example when memory efficient gradient management is enabled. - self.onnx_graph_input_names: List[str] = [] - - # A subset of onnx_graph_input_names. - # Input names that require gradients for the pre-gradient-build graph. - self.onnx_graph_input_names_require_grad: List[str] = [] - - # Create symbolic names for each dimension of the graph input (e.g. onnx_graph_input_names). - # The key is the input name, the value is a dict of {dim_index: symbolic_dim_name} - # e.g. {"input1": {0: "input1_dim0", 1: "input1_dim1"}, "input2": {0: "input2_dim0"}} - self.onnx_graph_input_dynamic_axes_map: Dict[str, Dict[int, str]] = {} - - # self._graph_initializers: List[torch.nn.parameter.Parameter] = [] - - self._buffer_for_ort_runs: Dict[str, torch.Tensor] = OrderedDict() - self._onnx_graph_input_names_user_defined = [] # The ONNX graph input names excluding the parameters, buffers. - self._onnx_graph_input_names_require_grad_user_defined = ( - [] - ) # The ONNX graph input names excluding the parameters, buffers. - - def flatten_kwargs(kwargs, device): def _flatten_kwargs(value, name): if PrimitiveType.is_primitive_type(value): @@ -342,17 +279,24 @@ def __init__(self, original_module: torch.nn.Module): super().__init__() self._original_module: torch.nn.Module = original_module - # Before `forward` is called, _ort_module must be assigned - # Updated input info is needed to expand args into *args, **kwargs - # self._input_info: Optional[_InputInfo] = None - self._unflatten_functor: Optional[Callable] = None - - self.device: Optional[torch.device] = None - + # Before ONNX export, we flatten the args and kwargs into a 1-D list of tensors to make torch.export happy. + # As a result, we need to unflatten the args and kwargs back to the original structure before calling the + # original module's forward function. + # So we need set those information that are needed to unflatten the args and kwargs, before calling the + # torch.export. + self._device: Optional[torch.device] = None + self._args_schema: Optional[ORTModelInputOutputSchemaType] = None + self._kwargs_schema: Optional[ORTModelInputOutputSchemaType] = None + self._num_positionals: Optional[int] = None + + # Similarly, to make torch.export happy, we need to flatten the original module's outputs into a 1-D list of tensors. + # Need to keep the output schema to unflatten the outputs back to the original structure. + # Then those code depends on the original structure of the outputs can work properly. self._output_schema: Optional[ORTModelInputOutputSchemaType] = None def forward(self, *args): - new_args, new_kwargs = self._unflatten_functor(args) + new_args = unflatten_data_using_schema(args[: self._num_positionals], self._args_schema) + new_kwargs = unflatten_data_using_schema(args[self._num_positionals :], self._kwargs_schema) # print("unflatten args: ", [v.shape for v in new_args]) # print("unflatten kwargs: ", {k: v.shape for k, v in new_kwargs.items()}) @@ -360,7 +304,7 @@ def forward(self, *args): original_outputs = self._original_module(*new_args, **new_kwargs) # Flatten the outputs - flatten_outputs, self._output_schema = _extract_schema(original_outputs, self.device) + flatten_outputs, self._output_schema = _extract_schema(original_outputs, self._device) # Append _OutputIdentityOp to the outputs to support passthrough outputs final_flatten_outputs = [] @@ -370,15 +314,67 @@ def forward(self, *args): return final_flatten_outputs +class ModelInfoForExport: + def __init__( + self, + onnx_graph_input_names: List[str], + onnx_graph_input_names_require_grad: List[str], + onnx_graph_input_dynamic_axes_map: Dict[str, Dict[int, str]], + onnx_graph_input_shapes: List[List[int]], + data_accessor: Optional[List[callable]] = None, + export_mode: Optional[int] = None, + ): + # Value can be either torch.onnx.TrainingMode.TRAINING or torch.onnx.TrainingMode.EVAL + self.export_mode = export_mode + + # Exporter can take extra arguments for ORTModule extensions + # It cannot overlap with required/immutable arguments (validated in runtime) + self.export_extra_kwargs = {} + + # Input names parsed and then flatten from the model's forward function signature. + # This should contains ONLY the user defined input names + # Be noted: some of the input might not be used by the model for its compute. + self.onnx_graph_input_names: List[str] = onnx_graph_input_names + + # A subset of onnx_graph_input_names. + # Input names that require gradient parsed and then flatten from the model's forward function signature + # This should contains ONLY the user defined input names + # Be noted: some of the input might not be used by the model for its compute. + self.onnx_graph_input_names_require_grad: List[str] = onnx_graph_input_names_require_grad + + # Create symbolic names for each dimension of the graph input (e.g. onnx_graph_input_names). + # The key is the input name, the value is a dict of {dim_index: symbolic_dim_name} + # e.g. {"input1": {0: "input1_dim0", 1: "input1_dim1"}, "input2": {0: "input2_dim0"}} + self.onnx_graph_input_dynamic_axes_map: Dict[str, Dict[int, str]] = onnx_graph_input_dynamic_axes_map + + self.onnx_graph_input_shapes: List[List[int]] = onnx_graph_input_shapes + + # A function to access the input data from the args and kwargs. + # If it is not None, the length is same as onnx_graph_input_names. + # For i-th input name, we can use the i-th function to get the input data from args and kwargs. + self.data_accessor: Optional[List[callable]] = data_accessor + + def __str__(self) -> str: + return f"""ModelInfoForExport class: + \tExport mode: {self.export_mode} + \tExport extra kwargs: {self.export_extra_kwargs} + \tInput names: {self.onnx_graph_input_names} + \tInput names require grad: {self.onnx_graph_input_names_require_grad} + \tInput dynamic axes: {self.onnx_graph_input_dynamic_axes_map} + \tInput shapes: {self.onnx_graph_input_shapes}""" + + def __repr__(self) -> str: + return self.__str__() + + def parse_inputs_for_onnx_export( all_input_parameters: List[inspect.Parameter], - # onnx_graph: Optional[onnx.ModelProto], - # schema: ORTModelInputOutputSchemaType, args: Sequence[ORTModelInputOutputType], kwargs: Mapping[str, ORTModelInputOutputType], constant_as_tensor: bool, device: torch.device, -) -> Tuple[_PreExportGraphInfo, Dict[str, Callable]]: + export_mode: int, +) -> ModelInfoForExport: """Parses through the model inputs and returns _InputInfo. Loop through all input parameters, try to flatten them into a 1-D list of inputs. For nested data in the inputs, @@ -398,10 +394,10 @@ def parse_inputs_for_onnx_export( Args: all_input_parameters: All inspected input parameters from the original model forward function signature. - onnx_graph: (optional) The onnx graph. - schema: The schema extracted from the positional arguments and keyword arguments of the model. args: The positional arguments of the model. kwargs: The keyword arguments of the model. + constant_as_tensor: Whether to treat constant inputs as tensors. + device: The device to be used for constant inputs. """ @@ -555,13 +551,22 @@ def _add_input(name, input_value, onnx_graph_input_names, cur_func): partial(lambda name, args, kwargs: kwargs[name], name), ) - exported_graph = _PreExportGraphInfo() - exported_graph.onnx_graph_input_names = onnx_graph_input_names - exported_graph.onnx_graph_input_names_require_grad = input_names_require_grad - exported_graph.onnx_graph_input_dynamic_axes_map = dynamic_axes - exported_graph.onnx_graph_input_shapes = input_shape + exported_graph = ModelInfoForExport( + onnx_graph_input_names=onnx_graph_input_names, + onnx_graph_input_names_require_grad=input_names_require_grad, + onnx_graph_input_dynamic_axes_map=dynamic_axes, + onnx_graph_input_shapes=input_shape, + data_accessor=data_accessors, + export_mode=export_mode, + ) + # exported_graph.onnx_graph_input_names = onnx_graph_input_names + # exported_graph.onnx_graph_input_names_require_grad = input_names_require_grad + # exported_graph.onnx_graph_input_dynamic_axes_map = dynamic_axes + # exported_graph.onnx_graph_input_shapes = input_shape + # exported_graph.data_accessor = data_accessors + # exported_graph._export_mode = export_mode - return exported_graph, data_accessors + return exported_graph def calculate_total_parameter_size_in_bytes(module: torch.nn.Module) -> int: @@ -598,20 +603,19 @@ def can_module_be_deep_cloned(module: torch.nn.Module, device: Optional[torch.de def parse_outputs_for_onnx_export_and_extract_schema( module, - args: Sequence[ORTModelInputOutputType], - kwargs: Mapping[str, ORTModelInputOutputType], + flatten_args: Sequence[ORTModelInputOutputType], logger: Logger, - device: Optional[torch.device], clone_module: bool, ): # Perform a forward call to grab outputs output_names = None output_dynamic_axes = None deep_copied = False + kwargs = {} logger.info("Running model forward to infer output schema and dynamic axes...") with torch.no_grad(): # Deepcopy inputs, since input values may change after model run. - sample_args_copy, sample_kwargs_copy = deepcopy_model_input(*args, **kwargs) + sample_args_copy, sample_kwargs_copy = deepcopy_model_input(*flatten_args, **kwargs) try: if clone_module: # Deepcopy model, in case model is stateful and changes after model run. @@ -653,8 +657,6 @@ def parse_outputs_for_onnx_export_and_extract_schema( for dim_idx in range(len(output.shape)): output_dynamic_axes[output_name].update({dim_idx: f"{output_name}_dim{dim_idx}"}) - _, flattend_module_output_schema = _extract_schema(sample_outputs, device) - original_module_output_schema = model_copy._output_schema # print("output_schema: ", flattend_module_output_schema) diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index fe55b16537c5c..8fba3c0dece6f 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -17,7 +17,7 @@ from ._fallback import ORTModuleFallbackException, _FallbackManager, _FallbackPolicy from ._gradient_accumulation_manager import GradientAccumulationManager from ._graph_execution_manager import GraphExecutionManager, _RunStateInfo -from ._io import _FlattenedModule, unflatten_user_output +from ._io import _FlattenedModule from ._logger import ORTModuleInitPhase, TrackTime from ._runtime_inspector import Phase from ._utils import save_tuning_results, set_tuning_results @@ -38,9 +38,8 @@ def __init__( fallback_manager: _FallbackManager, logger: Logger, ): - super().__init__(model, debug_options, fallback_manager, logger) + super().__init__(model, debug_options, fallback_manager, torch.onnx.TrainingMode.TRAINING, logger) - self._export_mode = torch.onnx.TrainingMode.TRAINING self._forward_class = self._create_autofunction_class() @staticmethod @@ -250,37 +249,17 @@ def forward(self, *inputs, **kwargs): build_gradient_graph = False if ( self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT) is False - or not self._exported_model + or not self._graph_transition_manager._exported_model_info ): self.time_tracker.start(ORTModuleInitPhase.EndToEnd) ( build_gradient_graph, - pre_grad_graph_info, - exported_graph_info, - ) = self.use_cached_exported_model_or_reexport(inputs, kwargs, self._device) + post_export_processed_model_info, + ) = self._graph_transition_manager.use_cache_or_reconstruct(inputs, kwargs) - build_gradient_graph = self.use_cached_pre_grad_model_or_reinitialize( - build_gradient_graph, pre_grad_graph_info - ) - - # build_gradient_graph = self._export_model(*inputs, **kwargs) - - # if build_gradient_graph: - # # If model was exported, then initialize the graph builder - # self._initialize_graph_builder() - - # # Since the schema was just extracted while trying to export the model and it was either - # # saved to self._input_info.schema or checked for equality with the self._input_info.schema - # # it should not need to be updated again. Pass it inside parse_inputs_for_onnx_export. - # input_info = _io.parse_inputs_for_onnx_export( - # self._module_parameters, self._onnx_models.exported_model, self._input_info.schema, inputs, kwargs - # ) - - # # Reinitialize graph builder if the inputs or initializers requiring gradient have changed. - # # Order of or operation is important here because we always need to call - # # _reinitialize_graph_builder irrespective of the value of build_gradient_graph. - # build_gradient_graph = self._reinitialize_graph_builder(input_info) or build_gradient_graph + if build_gradient_graph: + self._initialize_graph_builder(post_export_processed_model_info) # Build the gradient graph if build_gradient_graph: @@ -308,7 +287,7 @@ def forward(self, *inputs, **kwargs): ) _use_deterministic_algorithms(torch.are_deterministic_algorithms_enabled()) if self._device != device: - self._device = device + self._graph_transition_manager._device = device if create_execution_session: # Create execution session creates the training_session @@ -328,42 +307,23 @@ def forward(self, *inputs, **kwargs): # prepared_input_list = self.construct_inputs(inputs, kwargs) - prepared_input_map = self.construct_inputs(inputs, kwargs, True, self._device) + prepared_input_map = self._graph_transition_manager._post_export_processed_model_info.construct_inputs( + inputs, kwargs, True + ) if ( self._runtime_inspector.memory_ob.is_enabled() and not self._runtime_inspector.memory_ob.symbolic_dim_collecting_completed ): self._runtime_inspector.memory_ob.collect_symbolic_dim_values( - self._finalized_graph_info.onnx_graph_input_dynamic_axes_map, prepared_input_map + self._graph_transition_manager._post_export_processed_model_info.onnx_graph_input_dynamic_axes_map, + prepared_input_map, ) self._runtime_inspector.memory_ob.symbolic_dim_collecting_completed = True - # prepared_input_list, _, _ = _io._combine_input_buffers_initializers( - # self._graph_initializers, - # self._graph_info.user_input_names, - # self._input_info, - # self._flattened_module.named_buffers(), - # inputs, - # kwargs, - # self._device, - # self._runtime_inspector, - # self._zero_stage3_param_map, - # ) - - for input_name, input_value in prepared_input_map.items(): - print( - f"input_name: {input_name}, shape: {input_value.shape}, dtype: {input_value.dtype}, requires_grad: {input_value.requires_grad}" - ) - user_outputs = self._forward_class.apply(*prepared_input_map.values()) - # print("user_outputs: ", user_outputs) - # print("self._module_output_schema: ", self._module_output_schema) - outputs = unflatten_user_output( - self._module_output_schema, - user_outputs, - ) + outputs = self._graph_transition_manager._post_export_processed_model_info.restore_outputs(user_outputs) # print("outputs: ", outputs) @@ -419,38 +379,24 @@ def _build_graph(self, graph_transformer_config): self._gradient_map = [] index_for_input_requires_grad = 0 - for input_name in self._finalized_graph_info.onnx_graph_input_names: - if input_name in self._finalized_graph_info.onnx_graph_input_names_require_grad: + for input_name in self._graph_transition_manager._post_export_processed_model_info.onnx_graph_input_names: + if ( + input_name + in self._graph_transition_manager._post_export_processed_model_info.onnx_graph_input_names_require_grad + ): self._gradient_map.append(index_for_input_requires_grad) index_for_input_requires_grad += 1 else: self._gradient_map.append(-1) - # num_user_input_grads = len(self._finalized_graph_info.onnx_graph_input_names_require_grad) - # require_grad_names_set = set(self._finalized_graph_info.onnx_graph_input_names_require_grad) - # require_grad_names_index = 0 - # for input_name in self._graph_info.user_input_names: - # if input_name in require_grad_names_set: - # self._gradient_map.append(require_grad_names_index) - # require_grad_names_index += 1 - # else: - # self._gradient_map.append(-1) - - # initializer_index = num_user_input_grads - # for initializer_name in self._graph_info.initializer_names: - # if initializer_name in self._finalized_graph_info.onnx_graph_input_names_require_grad: - # self._gradient_map.append(initializer_index) - # initializer_index += 1 - # else: - # self._gradient_map.append(-1) - @TrackTime(ORTModuleInitPhase.CREATE_SESSION) def _create_execution_agent(self): """Creates a TrainingAgent that can run the forward and backward graph on the training model""" session_options, providers, provider_options = self._get_session_config() fw_feed_names = [input.name for input in self._onnx_models.optimized_model.graph.input] - device_type = self._device if type(self._device) is str else self._device.type.lower() # noqa: E721 + device_type = self._device if isinstance(self._device, str) else self._device.type.lower() + if device_type == "ort": fw_outputs_device_info = [C.get_ort_device(self._device.index)] * ( len(self._graph_info.user_output_names) + len(self._graph_info.frontier_node_arg_map) @@ -526,29 +472,6 @@ def _create_execution_agent(self): self._execution_agent._inference_session, True, self._runtime_options.tuning_results_path ) - # def _reinitialize_graph_builder(self, input_info: _InputInfo): - # """Return true if the module graph builder was reinitialized""" - - # # Model may have unused params dropped after export and not part of self._graph_initializer_names_to_train - # # To see if any trainable initializers changed, compare self._graph_initializer_names_to_train - # # with initializers in module named_parameters that are known to the onnx graph. - # initializer_names_to_train_set_user_model = { - # name - # for name, param in self._flattened_module.named_parameters() - # if param.requires_grad and name in self._graph_initializer_names - # } - - # # If inputs requiring gradient change from forward to the next, the module_gradient_graph_builder - # # needs to be reinitialized so it can compute the backward output for the new inputs that require_grad - # if ( - # input_info.require_grad_names != self._input_info.require_grad_names - # or initializer_names_to_train_set_user_model != self._graph_initializer_names_to_train - # ): - # self._input_info = input_info - # self._initialize_graph_builder() - # return True - # return False - def __getstate__(self): state = super().__getstate__() diff --git a/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py index ff110c431d300..11d978e71d8a8 100644 --- a/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py +++ b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py @@ -395,7 +395,7 @@ def _update_python_op_input_related_attributes( @contextmanager -def stage3_export_context(enable: bool, graph_execution_manager): +def stage3_export_context(enable: bool, stage3_param_handle, flattened_module): """Context manager for stage3 specific model export. Some export functions are overridden when entering the context; the original functions are restored when exiting the context. @@ -411,9 +411,7 @@ def stage3_export_context(enable: bool, graph_execution_manager): # Delay collecting stage3 parameters here instead of in the graph execution manager, # to make sure DeepSpeed initialization is done, so that the parameters ds_status are correct. - graph_execution_manager._zero_stage3_param_map = _get_all_zero_stage3_params( - graph_execution_manager._flattened_module - ) + stage3_param_handle._zero_stage3_param_map = _get_all_zero_stage3_params(flattened_module) try: from torch.onnx._internal import _beartype @@ -428,8 +426,8 @@ def _get_tensor_rank(x) -> Optional[int]: from torch.onnx.symbolic_helper import _is_tensor input_name = x.debugName() - if input_name in graph_execution_manager._zero_stage3_param_map: - rank = len(graph_execution_manager._zero_stage3_param_map[input_name].ds_shape) + if input_name in stage3_param_handle._zero_stage3_param_map: + rank = len(stage3_param_handle._zero_stage3_param_map[input_name].ds_shape) return rank if not _is_tensor(x) or x.type() is None: diff --git a/orttraining/orttraining/python/training/ortmodule/ortmodule.py b/orttraining/orttraining/python/training/ortmodule/ortmodule.py index b5c52bdaef3c6..ba6f7c2d0c03a 100644 --- a/orttraining/orttraining/python/training/ortmodule/ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule/ortmodule.py @@ -306,7 +306,9 @@ def __setattr__(self, name: str, value) -> None: # Re-export will be avoided if _skip_check is enabled. if isinstance(self._torch_module, TorchModuleORT): for training_mode in [False, True]: - self._torch_module._execution_manager(training_mode).signal_model_changed() + self._torch_module._execution_manager( + training_mode + )._graph_transition_manager.signal_model_changed() else: # Setting any new attributes should be done on ORTModule only when 'torch_module' is not defined diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index cb79cb712627c..7cececb1bad7a 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -624,7 +624,7 @@ def test_torch_nn_module_to_api(original_device, to_argument): x = x.to(to_argument) model(x) assert _utils.get_device_str( - model._torch_module._execution_manager(model._is_training())._device + model._torch_module._execution_manager(model._is_training())._graph_transition_manager._device ) == _utils.get_device_str(torch.device(to_argument)) @@ -690,12 +690,12 @@ def test_input_requires_grad_saved(device): model(x) assert model._torch_module._execution_manager( model._is_training() - )._pre_export_graph_info.onnx_graph_input_names_require_grad == ["input1"] + )._graph_transition_manager._model_info_for_export.onnx_graph_input_names_require_grad == ["input1"] assert ( "input1" in model._torch_module._execution_manager( model._is_training() - )._finalized_graph_info.onnx_graph_input_names_require_grad + )._graph_transition_manager._post_export_processed_model_info.onnx_graph_input_names_require_grad ) @@ -855,7 +855,7 @@ def run_step(model, x): x = torch.randn(*shape, device=device, requires_grad=True) # pt_prediction = run_step(pt_model, x) - ort_prediction = run_step(ort_model, x) + run_step(ort_model, x) # _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) # _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model) @@ -2967,7 +2967,7 @@ def test_forward_data_and_model_on_different_devices(data_device, model_device): with pytest.raises(_fallback.ORTModuleDeviceException) as runtime_error: ort_model(x) assert ( - f"Input argument to forward found on device {torch.device(x.device)}, but expected it to be on module device {ort_model._torch_module._execution_manager(ort_model._is_training())._device}." + f"Input argument to forward found on device {torch.device(x.device)}, but expected it to be on module device {ort_model._torch_module._execution_manager(ort_model._is_training())._graph_transition_manager._device}." in str(runtime_error.value) ) @@ -3723,7 +3723,10 @@ def forward(self, a, b, c, d, *args, kw_0=None, **kwargs): # Modeling device = "cuda" model = UnusedNet().to(device) - model = ORTModule(model) + from onnxruntime.training.ortmodule import DebugOptions, LogLevel, ORTModule + + model = ORTModule(model, DebugOptions(log_level=LogLevel.INFO)) + # model = ORTModule(model) # Dummy data one = torch.FloatTensor([1]).to(device) @@ -4005,10 +4008,14 @@ def forward(self, input1, bool_argument): input1 = torch.randn(N, D_in, device=device) ort_model(input1, bool_arguments[0]) - exported_model1 = ort_model._torch_module._execution_manager(ort_model._is_training())._exported_model + exported_model1 = ort_model._torch_module._execution_manager( + ort_model._is_training() + )._graph_transition_manager._exported_model_info.exported_model ort_model(input1, bool_arguments[1]) - exported_model2 = ort_model._torch_module._execution_manager(ort_model._is_training())._exported_model + exported_model2 = ort_model._torch_module._execution_manager( + ort_model._is_training() + )._graph_transition_manager._exported_model_info.exported_model assert exported_model1 != exported_model2 @@ -4195,7 +4202,7 @@ def test_stateless_model_unspecified_device(): # input_info = _io.parse_inputs_for_onnx_export( # training_manager._module_parameters, -# training_manager._exported_model, +# training_manager._graph_transition_manager._exported_model_info.exported_model, # training_manager._input_info.schema, # x, # {}, @@ -4678,7 +4685,9 @@ def forward(self, batch, **kwargs): device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 # noqa: F841, N806 pt_model = ListDictKwargsNet(N, D_in).to(device) - ort_model = ORTModule(copy.deepcopy(pt_model), DebugOptions(save_onnx=True, onnx_prefix="kwargsanddict")) + ort_model = ORTModule( + copy.deepcopy(pt_model), DebugOptions(save_onnx=True, log_level=LogLevel.INFO, onnx_prefix="kwargsanddict") + ) x = { "one_value": [torch.randn(N, D_in, device=device)], "two_value": [torch.randn(N, D_in, device=device), torch.randn(N, D_in, device=device)], @@ -4833,7 +4842,9 @@ def forward(self, a): ort_model = ORTModule(pt_model) _ = ort_model(torch.randn(N, D_in, device=device)) - exported_model1 = ort_model._torch_module._execution_manager(True)._exported_model + exported_model1 = ort_model._torch_module._execution_manager( + True + )._graph_transition_manager._exported_model_info.exported_model for training_mode in [False, True]: assert ort_model._torch_module._execution_manager(training_mode)._original_model_has_changed is False @@ -4843,7 +4854,9 @@ def forward(self, a): assert ort_model._torch_module._execution_manager(training_mode)._original_model_has_changed is True _ = ort_model(torch.randn(N, D_in, device=device)) - exported_model2 = ort_model._torch_module._execution_manager(True)._exported_model + exported_model2 = ort_model._torch_module._execution_manager( + True + )._graph_transition_manager._exported_model_info.exported_model assert exported_model1 != exported_model2 @@ -5282,7 +5295,9 @@ def run_step(model, x): ort_prediction, ort_loss = run_step(ort_model, ort_x) pt_prediction, pt_loss = run_step(pt_model, pt_x) if step == 0: - exported_model = ort_model._torch_module._execution_manager._training_manager._exported_model + exported_model = ( + ort_model._torch_module._execution_manager._training_manager._graph_transition_manager._exported_model_info.exported_model + ) # optimized_model = ort_model._torch_module._execution_manager._training_manager._optimized_model for onx in [ exported_model, @@ -5328,7 +5343,9 @@ def test_opset_version_change(opset_version): prediction.backward() # Check opset version on ONNX model - exported_model = ort_model._torch_module._execution_manager(ort_model._is_training())._exported_model + exported_model = ort_model._torch_module._execution_manager( + ort_model._is_training() + )._graph_transition_manager._exported_model_info.exported_model assert exported_model.opset_import[0].version == opset_version if original_env is not None: From 613136a871c27231658af9bca76bf5be1dba5d8f Mon Sep 17 00:00:00 2001 From: "pengwa@microsoft.com" Date: Wed, 3 Jan 2024 17:49:54 +0000 Subject: [PATCH 03/32] save --- .../ortmodule/_graph_transition_manager.py | 259 +++++++++--------- .../training/ortmodule/_inference_manager.py | 24 +- .../python/training/ortmodule/_io.py | 196 ------------- .../training/ortmodule/_training_manager.py | 10 - .../python/orttraining_test_ortmodule_api.py | 61 ++--- 5 files changed, 163 insertions(+), 387 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py index eeb63ac4c20e0..5da9d00b1abd5 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py @@ -95,7 +95,7 @@ def __init__( onnx_graph_input_names_require_grad: list[str], onnx_graph_input_dynamic_axes_map: dict[str, dict[int, str]], module_forward_output_schema: ORTModelInputOutputSchemaType, - _post_export_processed_model: onnx.ModelProto, + post_export_processed_model: onnx.ModelProto, data_accessor: list[callable], ): self.device = device @@ -124,13 +124,14 @@ def __init__( # The ONNX graph input names excluding the parameters, buffers. self.onnx_graph_input_names_require_grad_user_defined = onnx_graph_input_names_require_grad_user_defined - self._post_export_processed_model: onnx.ModelProto | None = _post_export_processed_model + self._post_export_processed_model: onnx.ModelProto | None = post_export_processed_model # A function to access the input data from the args and kwargs. # If it is not None, the length is same as onnx_graph_input_names. # For i-th input name, we can use the i-th function to get the input data from args and kwargs. self.data_accessor: list[callable] | None = data_accessor + # Used for unflattening the outputs from the ORT forward run. self.module_forward_output_schema: ORTModelInputOutputSchemaType | None = module_forward_output_schema # Create the buffers for the inputs that are either parameters or buffers in the original module. @@ -169,12 +170,11 @@ def construct_inputs( The inputs are constructed in the order they appear in the model's forward function signature """ - # print("construct_inputs>>>>>", inputs, kwargs) + for name in self.onnx_graph_input_names_user_defined: if name in self.data_accessor: assert name in self.buffer_for_ort_runs, f"{name} is not in buffer_for_ort_runs" data = self.data_accessor[name](args, kwargs) - # print("data.requires_grad: ", data.requires_grad) if PrimitiveType.is_primitive_type(data) and constant_as_tensor: data = PrimitiveType.get_tensor(data, self.device) self.buffer_for_ort_runs[name] = data @@ -188,6 +188,7 @@ def construct_inputs( def restore_outputs(self, ort_flatten_outputs: list[torch.Tensor]): """Restores the outputs from the ORT forward run, back to the original data structure""" + try: # Need to distinguish between a single output and a tuple (having a single tensor) if len(ort_flatten_outputs) == 1 and self.module_forward_output_schema is _TensorStub: @@ -259,6 +260,131 @@ def __init__( # Model info after export and post export processing. self._post_export_processed_model_info = None + def use_cache_or_reconstruct( + self, args: Sequence[ORTModelInputOutputType], kwargs: Mapping[str, ORTModelInputOutputType] + ) -> tuple[bool, PostExportProcessedModelInfo]: + """Check if the model can be reused, otherwise, reconstruct the model. + + Return True if the model can be reused, otherwise, return False. + The model can be reused when the following conditions are met: + a. The model has been exported before, and the inputs (args/outputs) schemas are the same as the previous ones. + b. In training mode, the graph inputs requiring gradient are the same as the previous ones. + + """ + ( + need_export_model, + cur_model_info_for_export, + flatten_args, + flatten_kwargs, + cur_args_schema, + cur_kwargs_schema, + ) = GraphTransitionManager._export_check( + prev_exported_model_info=self._model_info_for_export, + prev_model_info_for_export=self._exported_model_info, + args=args, + kwargs=kwargs, + device=self._device, + original_model_has_changed=self._original_model_has_changed, + module_forward_func_parameters=self._module_forward_func_parameters, + export_mode=self._export_mode, + logger=self._logger, + ) + + if need_export_model: + # Set the information used to unflatten the inputs during the flatten module forward run. + # Must be set before calling exporting the model. + self._flatten_module._device = self._device + self._flatten_module._args_schema = cur_args_schema + self._flatten_module._kwargs_schema = cur_kwargs_schema + self._flatten_module._num_positionals = len(flatten_args) + + flatten_inputs = flatten_args + flatten_kwargs + self._set_device_from_module(flatten_inputs, {}) + + # Start exporting the model by passing the 1-D flatten tensor list containing all args plus kwargs. + ( + exported_model, + module_output_schema, + onnx_graph_input_names, + onnx_graph_input_names_require_grad, + ) = GraphTransitionManager._export_model( + flattened_module=self._flatten_module, + model_info_for_export=cur_model_info_for_export, + flatten_module_inputs=flatten_inputs, + run_symbolic_shape_infer=self._run_symbolic_shape_infer, + deepcopy_before_model_export=self._deepcopy_before_model_export, + device=self._device, + ortmodule_cache_dir=self._ortmodule_cache_dir, + enable_custom_autograd_function=self._enable_custom_autograd_function, + enable_zero_stage3_support=self._enable_zero_stage3_support, + onnx_opset_version=self._onnx_opset_version, + torch_exporter_verbose_log=self._torch_exporter_verbose_log, + stage3_param_handle=self, + logger=self._logger, + ) + + self._exported_model_info = ExportedModelInfo( + module_forward_args_schema=cur_args_schema, + module_forward_kwargs_schema=cur_kwargs_schema, + onnx_graph_input_names=onnx_graph_input_names, + onnx_graph_input_names_require_grad=onnx_graph_input_names_require_grad, + exported_model=exported_model, + module_forward_output_schema=module_output_schema, + ) + + self._model_info_for_export = cur_model_info_for_export + + # Reset the signal to indicate the original model has changed. + self._original_model_has_changed = False + + # Save the exported model + if self._save: + _save_model( + self._exported_model_info.exported_model, + os.path.join( + self._save_path, + _get_onnx_file_name(self._save_name_prefix, "torch_exported", self._export_mode), + ), + ) + + self._logger.info(f"do_export completed, exported graph infos: {self._exported_model_info}") + + need_re_processed = False + if need_export_model: + need_re_processed = True + else: + need_re_processed, updated_onnx_graph_input_requires_grads = GraphTransitionManager._reprocess_check( + flatten_module=self._flatten_module, + prev_exported_model_info=self._exported_model_info, + export_mode=self._export_mode, + cur_model_info_for_export=cur_model_info_for_export, + ) + if need_re_processed: + # Update the onnx_graph_input_names_require_grads to make it a new default baseline to compare using new iteration data. + self._exported_model_info.onnx_graph_input_names_require_grad = updated_onnx_graph_input_requires_grads + + if need_re_processed: + # At this point, the exported model is ready, and we can start post export processing. + self._post_export_processed_model_info = GraphTransitionManager._post_export_process( + flatten_module=self._flatten_module, + device=self._device, + exported_model_info=self._exported_model_info, + model_info_for_export=self._model_info_for_export, + logger=self._logger, + ) + + # Save the post_processed model + if self._save: + _save_model( + self._post_export_processed_model_info._post_export_processed_model, + os.path.join( + self._save_path, + _get_onnx_file_name(self._save_name_prefix, "post_processed", self._export_mode), + ), + ) + + return need_re_processed, self._post_export_processed_model_info + @staticmethod def _export_check( prev_model_info_for_export: _io.ModelInfoForExport, @@ -396,131 +522,6 @@ def _post_export_process( return post_export_processed_model_info - def use_cache_or_reconstruct( - self, args: Sequence[ORTModelInputOutputType], kwargs: Mapping[str, ORTModelInputOutputType] - ) -> tuple[bool, PostExportProcessedModelInfo]: - """Check if the model can be reused, otherwise, reconstruct the model. - - Return True if the model can be reused, otherwise, return False. - The model can be reused when the following conditions are met: - a. The model has been exported before, and the inputs (args/outputs) schemas are the same as the previous ones. - b. The graph inputs requiring gradient are the same as the previous ones. - - """ - ( - need_export_model, - cur_model_info_for_export, - flatten_args, - flatten_kwargs, - cur_args_schema, - cur_kwargs_schema, - ) = GraphTransitionManager._export_check( - prev_exported_model_info=self._model_info_for_export, - prev_model_info_for_export=self._exported_model_info, - args=args, - kwargs=kwargs, - device=self._device, - original_model_has_changed=self._original_model_has_changed, - module_forward_func_parameters=self._module_forward_func_parameters, - export_mode=self._export_mode, - logger=self._logger, - ) - - if need_export_model: - # Set the information used to unflatten the inputs during the flatten module forward run. - # Must be set before calling exporting the model. - self._flatten_module._device = self._device - self._flatten_module._args_schema = cur_args_schema - self._flatten_module._kwargs_schema = cur_kwargs_schema - self._flatten_module._num_positionals = len(flatten_args) - - flatten_inputs = flatten_args + flatten_kwargs - self._set_device_from_module(flatten_inputs, {}) - - # Start exporting the model by passing the 1-D flatten tensor list containing all args plus kwargs. - ( - exported_model, - module_output_schema, - onnx_graph_input_names, - onnx_graph_input_names_require_grad, - ) = GraphTransitionManager._export_model( - flattened_module=self._flatten_module, - model_info_for_export=cur_model_info_for_export, - flatten_module_inputs=flatten_inputs, - run_symbolic_shape_infer=self._run_symbolic_shape_infer, - deepcopy_before_model_export=self._deepcopy_before_model_export, - device=self._device, - ortmodule_cache_dir=self._ortmodule_cache_dir, - enable_custom_autograd_function=self._enable_custom_autograd_function, - enable_zero_stage3_support=self._enable_zero_stage3_support, - onnx_opset_version=self._onnx_opset_version, - torch_exporter_verbose_log=self._torch_exporter_verbose_log, - stage3_param_handle=self, - logger=self._logger, - ) - - self._exported_model_info = ExportedModelInfo( - module_forward_args_schema=cur_args_schema, - module_forward_kwargs_schema=cur_kwargs_schema, - onnx_graph_input_names=onnx_graph_input_names, - onnx_graph_input_names_require_grad=onnx_graph_input_names_require_grad, - exported_model=exported_model, - module_forward_output_schema=module_output_schema, - ) - - self._model_info_for_export = cur_model_info_for_export - - # Reset the signal to indicate the original model has changed. - self._original_model_has_changed = False - - # Save the exported model - if self._save: - _save_model( - self._exported_model_info.exported_model, - os.path.join( - self._save_path, - _get_onnx_file_name(self._save_name_prefix, "torch_exported", self._export_mode), - ), - ) - - self._logger.info(f"do_export completed, exported graph infos: {self._exported_model_info}") - - need_re_processed = False - if need_export_model: - need_re_processed = True - else: - need_re_processed, updated_onnx_graph_input_requires_grads = GraphTransitionManager._reprocess_check( - flatten_module=self._flatten_module, - prev_exported_model_info=self._exported_model_info, - export_mode=self._export_mode, - cur_model_info_for_export=cur_model_info_for_export, - ) - if need_re_processed: - # Update the onnx_graph_input_names_require_grads to make it a new default baseline to compare using new iteration data. - self._exported_model_info.onnx_graph_input_names_require_grad = updated_onnx_graph_input_requires_grads - - if need_re_processed: - # At this point, the exported model is ready, and we can start post export processing. - self._post_export_processed_model_info = GraphTransitionManager._post_export_process( - flatten_module=self._flatten_module, - device=self._device, - exported_model_info=self._exported_model_info, - model_info_for_export=self._model_info_for_export, - logger=self._logger, - ) - - # Save the post_processed model - if self._save: - _save_model( - self._post_export_processed_model_info._post_export_processed_model, - os.path.join( - self._save_path, - _get_onnx_file_name(self._save_name_prefix, "post_processed", self._export_mode), - ), - ) - - return need_re_processed, self._post_export_processed_model_info - # @_logger.TrackTime(_logger.ORTModuleInitPhase.EXPORT) # @_logger.SuppressLogs(_logger.ORTModuleInitPhase.EXPORT, is_ort_filter=False) @staticmethod diff --git a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py index bd6a4410323b7..e41bb9fe947b8 100644 --- a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py @@ -141,12 +141,12 @@ def forward(self, *inputs, **kwargs): create_execution_session = ( build_graph - or self._graph_transition_manager._device != module_device + or self._device != module_device or torch.are_deterministic_algorithms_enabled() is not _are_deterministic_algorithms_enabled() ) _use_deterministic_algorithms(torch.are_deterministic_algorithms_enabled()) - if self._graph_transition_manager._device != module_device: + if self._device != module_device: self._graph_transition_manager._device = module_device if create_execution_session: @@ -158,29 +158,19 @@ def forward(self, *inputs, **kwargs): if self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False: # Assert that the input and model device match - _utils._check_same_device(self._graph_transition_manager._device, "Input argument to forward", *inputs) + _utils._check_same_device(self._device, "Input argument to forward", *inputs) if self._runtime_options.enable_zero_stage3_support: - self._append_pull_weight_trigger_as_input(kwargs, self._graph_transition_manager._device) + self._append_pull_weight_trigger_as_input(kwargs, self._device) prepared_input_map = self._graph_transition_manager._post_export_processed_model_info.construct_inputs( inputs, kwargs, True ) - if ( - self._runtime_inspector.memory_ob.is_enabled() - and not self._runtime_inspector.memory_ob.symbolic_dim_collecting_completed - ): - self._runtime_inspector.memory_ob.collect_symbolic_dim_values( - self._graph_transition_manager._post_export_processed_model_info.onnx_graph_input_dynamic_axes_map, - prepared_input_map, - ) - self._runtime_inspector.memory_ob.symbolic_dim_collecting_completed = True - user_outputs, _ = InferenceManager.execution_session_run_forward( self._execution_agent, self._onnx_models.optimized_model, - self._graph_transition_manager._device, + self._device, *prepared_input_map.values(), ) @@ -193,10 +183,8 @@ def forward(self, *inputs, **kwargs): self._execution_agent._inference_session, False, self._runtime_options.tuning_results_path ) - # print("user_outputs: ", user_outputs) - # print("self._module_output_schema: ", self._module_output_schema) return self._graph_transition_manager._post_export_processed_model_info.restore_outputs(user_outputs) - # return unflatten_user_output(self._module_output_schema, user_outputs) + except ORTModuleFallbackException as e: # Exceptions subject to fallback are handled here self._fallback_manager.handle_exception(exception=e, log_level=self._debug_options.logging.log_level) diff --git a/orttraining/orttraining/python/training/ortmodule/_io.py b/orttraining/orttraining/python/training/ortmodule/_io.py index 735d92a723e57..7eb60c4ab407b 100644 --- a/orttraining/orttraining/python/training/ortmodule/_io.py +++ b/orttraining/orttraining/python/training/ortmodule/_io.py @@ -21,7 +21,6 @@ extract_data_and_schema, unflatten_data_using_schema, ) -from onnxruntime.training.utils.torch_io_helper import _TensorStub from ._fallback import ORTModuleIOError, wrap_exception @@ -77,89 +76,6 @@ def symbolic(g, self): return g.op("Identity", self) -def flatten_kwargs(kwargs, device): - def _flatten_kwargs(value, name): - if PrimitiveType.is_primitive_type(value): - flattened_kwargs[name] = PrimitiveType.get_tensor(value, device) - elif isinstance(value, torch.Tensor): - flattened_kwargs[name] = value - elif isinstance(value, abc.Sequence): - # If the input is a sequence (like a list), expand the list so that - # each element of the list has a corresponding entry in the flattened - # kwargs dict - for idx, val in enumerate(value): - _flatten_kwargs(val, f"{name}_{idx}") - elif isinstance(value, abc.Mapping): - # If the input is a mapping (like a dict), expand the dict so that - # each element of the dict has an entry in the flattened kwargs dict - for key, val in value.items(): - _flatten_kwargs(val, f"{name}_{key}") - - flattened_kwargs = {} - for key, value in kwargs.items(): - _flatten_kwargs(value, key) - - return flattened_kwargs - - -class _InputInfo: - def __init__( - self, - names: List[str], - shape: List[List[int]], - require_grad_names: Optional[List[str]] = None, - dynamic_axes: Optional[Dict[str, Dict[int, str]]] = None, - schema: Optional[ORTModelInputOutputSchemaType] = None, - num_positionals=0, - ): - self.names: List[str] = names - self.shape: List[List[int]] = shape - self.require_grad_names: List[str] = require_grad_names if require_grad_names else [] - self.dynamic_axes: Dict[str, Dict[int, str]] = dynamic_axes if dynamic_axes else {} - self.schema: ORTModelInputOutputSchemaType = schema if schema else [] - self.num_positionals = num_positionals - self.kwargs = None - - def __repr__(self) -> str: - return f"""_InputInfo class: - \tNames: {self.names} - \tShape: {self.shape} - \tRequire gradient: {self.require_grad_names} - \tDynamic axes: {self.dynamic_axes} - \tSchema: {self.schema} - \t#Positionals (total): {self.num_positionals}""" - - def flatten( - self, - args: Sequence[ORTModelInputOutputType], - kwargs: Mapping[str, ORTModelInputOutputType], - device: torch.device, - ) -> Sequence[ORTModelInputOutputType]: - """Flatten args and kwargs in a single tuple of tensors with strict ordering""" - - ret = [PrimitiveType.get_tensor(arg, device) if PrimitiveType.is_primitive_type(arg) else arg for arg in args] - flattened_kwargs = flatten_kwargs(kwargs, device) - ret += [flattened_kwargs[name] for name in self.names if name in flattened_kwargs] - self.kwargs = kwargs - - # if kwargs is empty, append an empty dictionary at the end of the sample inputs to make exporter - # happy. This is because the exporter is confused with kwargs and dictionary inputs otherwise. - if not kwargs: - ret.append({}) - - return ret - - def unflatten( - self, flat_args: Sequence[ORTModelInputOutputType] - ) -> Tuple[Sequence[ORTModelInputOutputType], Mapping[str, ORTModelInputOutputType]]: - """Unflatten tuple of tensors into args and kwargs""" - - args = tuple(flat_args[: self.num_positionals]) - kwargs = self.kwargs - self.kwargs = None - return args, kwargs - - def deepcopy_model_input( *args, **kwargs ) -> Tuple[Sequence[ORTModelInputOutputType], Mapping[str, ORTModelInputOutputType]]: @@ -183,19 +99,6 @@ def extract_tensor(value): return sample_args_copy, sample_kwargs_copy -def unflatten_user_output(output_schema: Optional[ORTModelInputOutputSchemaType], outputs: List[torch.Tensor]): - try: - # Need to distinguish between a single output and a tuple (having a single tensor) - if len(outputs) == 1 and output_schema is _TensorStub: - return outputs[0] - return unflatten_data_using_schema(outputs, output_schema) - except TypeError as e: - raise wrap_exception( - ORTModuleIOError, - TypeError(f"ORTModule fails to unflatten user output: {e}"), - ) from None - - def _extract_schema( data: ORTModelInputOutputType, device ) -> Tuple[Sequence[ORTModelInputOutputType], ORTModelInputOutputSchemaType]: @@ -206,74 +109,6 @@ def _extract_schema( raise wrap_exception(ORTModuleIOError, TypeError(f"ORTModule fails to extract schema from data: {e}")) from None -def _parse_outputs_and_extract_names_and_dynamic_axes(module_output) -> Tuple[List[str], Dict[str, Dict[int, str]]]: - """Parses through the module output and returns output names and dynamic axes""" - - def _populate_output_names_and_dynamic_axes( - output, output_names: List[str], output_dynamic_axes: Dict[str, Dict[int, str]], output_idx: List[int] - ): - # Depth first traversal to traverse through the entire output collecting output names and dynamic axes - - if output is None: - return - elif isinstance(output, torch.Tensor): - # Naming the outputs with a hyphen ensures that there can be no input with the same - # name, preventing collisions with other NodeArgs (for example an input to forward called output0) - output_name = f"output-{output_idx[0]}" - output_idx[0] += 1 - output_names.append(output_name) - output_dynamic_axes[output_name] = {} - for dim_idx in range(len(output.shape)): - output_dynamic_axes[output_name].update({dim_idx: f"{output_name}_dim{dim_idx}"}) - return - - if isinstance(output, abc.Sequence): - for value in output: - _populate_output_names_and_dynamic_axes(value, output_names, output_dynamic_axes, output_idx) - elif isinstance(output, abc.Mapping): - for _, value in sorted(output.items()): - _populate_output_names_and_dynamic_axes(value, output_names, output_dynamic_axes, output_idx) - else: - raise wrap_exception( - ORTModuleIOError, - TypeError(f"ORTModule does not support the following model output type {type(output)}"), - ) - - output_names: List[str] = [] - output_dynamic_axes: Dict[str, Dict[int, str]] = {} - output_idx: List[int] = [0] - _populate_output_names_and_dynamic_axes(module_output, output_names, output_dynamic_axes, output_idx) - - return output_names, output_dynamic_axes - - -# def _transform_output_to_flat_tuple(data): -# """Converts the data to a flat tuple by iterating over the entire data structure""" - -# def _flatten_data(data, flat_data): -# # Recursively traverse over the data and populate the flat_data with torch.Tensors - -# if data is None: -# return -# elif isinstance(data, torch.Tensor): -# identity = _OutputIdentityOp.apply -# flat_data.append(identity(data)) -# elif isinstance(data, abc.Sequence): -# for value in data: -# _flatten_data(value, flat_data) -# elif isinstance(data, abc.Mapping): -# for _, value in sorted(data.items()): -# _flatten_data(value, flat_data) -# else: -# raise wrap_exception( -# ORTModuleIOError, TypeError(f"ORTModule does not support the following data type {type(data)}.") -# ) - -# flat_data = [] -# _flatten_data(data, flat_data) -# return tuple(flat_data) - - class _FlattenedModule(torch.nn.Module): def __init__(self, original_module: torch.nn.Module): super().__init__() @@ -298,9 +133,6 @@ def forward(self, *args): new_args = unflatten_data_using_schema(args[: self._num_positionals], self._args_schema) new_kwargs = unflatten_data_using_schema(args[self._num_positionals :], self._kwargs_schema) - # print("unflatten args: ", [v.shape for v in new_args]) - # print("unflatten kwargs: ", {k: v.shape for k, v in new_kwargs.items()}) - original_outputs = self._original_module(*new_args, **new_kwargs) # Flatten the outputs @@ -473,13 +305,7 @@ def _add_input(name, input_value, onnx_graph_input_names, cur_func): dynamic_axes.update(_add_dynamic_shape(name, value)) input_shape.append(list(value.size())) - # Ignore optional inputs explicitly specified as None - # ONNX exporter may remove unused inputs onnx_graph_input_names: List[str] = [] - # onnx_graph = None - # if onnx_graph is not None: - # onnx_graph_input_names = {inp.name for inp in onnx_graph.graph.input} - input_names: List[str] = [] dynamic_axes: Dict[str, Dict[int, str]] = {} input_names_require_grad: List[str] = [] @@ -559,12 +385,6 @@ def _add_input(name, input_value, onnx_graph_input_names, cur_func): data_accessor=data_accessors, export_mode=export_mode, ) - # exported_graph.onnx_graph_input_names = onnx_graph_input_names - # exported_graph.onnx_graph_input_names_require_grad = input_names_require_grad - # exported_graph.onnx_graph_input_dynamic_axes_map = dynamic_axes - # exported_graph.onnx_graph_input_shapes = input_shape - # exported_graph.data_accessor = data_accessors - # exported_graph._export_mode = export_mode return exported_graph @@ -633,23 +453,9 @@ def parse_outputs_for_onnx_export_and_extract_schema( sample_outputs = model_copy(*sample_args_copy, **sample_kwargs_copy) - # print("sample_outputs: ", sample_outputs) - # Parse the output and extract the output_names and output_dynamic_axes to be used for onnx export - # output_names, output_dynamic_axes = _parse_outputs_and_extract_names_and_dynamic_axes(sample_outputs) - output_names: List[str] = [] output_dynamic_axes: Dict[str, Dict[int, str]] = {} - - # # Naming the outputs with a hyphen ensures that there can be no input with the same - # # name, preventing collisions with other NodeArgs (for example an input to forward called output0) - # output_name = f"output-{output_idx[0]}" - # output_idx[0] += 1 - # output_names.append(output_name) - # output_dynamic_axes[output_name] = {} - # for dim_idx in range(len(output.shape)): - # output_dynamic_axes[output_name].update({dim_idx: f"{output_name}_dim{dim_idx}"}) - for output_idx, output in enumerate(sample_outputs): output_name = f"output-{output_idx}" output_names.append(output_name) @@ -659,8 +465,6 @@ def parse_outputs_for_onnx_export_and_extract_schema( original_module_output_schema = model_copy._output_schema - # print("output_schema: ", flattend_module_output_schema) - if deep_copied: del model_copy gc.collect() diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index 8fba3c0dece6f..dce59ec837e36 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -311,16 +311,6 @@ def forward(self, *inputs, **kwargs): inputs, kwargs, True ) - if ( - self._runtime_inspector.memory_ob.is_enabled() - and not self._runtime_inspector.memory_ob.symbolic_dim_collecting_completed - ): - self._runtime_inspector.memory_ob.collect_symbolic_dim_values( - self._graph_transition_manager._post_export_processed_model_info.onnx_graph_input_dynamic_axes_map, - prepared_input_map, - ) - self._runtime_inspector.memory_ob.symbolic_dim_collecting_completed = True - user_outputs = self._forward_class.apply(*prepared_input_map.values()) outputs = self._graph_transition_manager._post_export_processed_model_info.restore_outputs(user_outputs) diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 7cececb1bad7a..a711386760f47 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -3466,6 +3466,13 @@ def train_step(model, x): _test_helpers.assert_values_are_close(pt_out, ort_out) +def _repr_schema(ortmodule): + tm = ortmodule._torch_module._execution_manager(ortmodule._is_training())._graph_transition_manager + return repr(tm._exported_model_info.module_forward_args_schema) + repr( + tm._exported_model_info.module_forward_kwargs_schema + ) + + def test_forward_dynamic_args(): os.environ["ORTMODULE_SKIPCHECK_POLICY"] = "SKIP_CHECK_DISABLED" @@ -3490,30 +3497,21 @@ def test_forward_dynamic_args(): for _ in range(10): output = model(*args_size1) assert output is not None - hash_args_size1 = hash( - repr(model._torch_module._execution_manager(model._is_training())._args_input_schema) - + repr(model._torch_module._execution_manager(model._is_training())._kwargs_input_schema) - ) + hash_args_size1 = hash(_repr_schema(model)) assert hash_args_size1 is not None # Decrease number of inputs and train some more for _ in range(10): output = model(*args_size2) assert output is not None - hash_args_size2 = hash( - repr(model._torch_module._execution_manager(model._is_training())._args_input_schema) - + repr(model._torch_module._execution_manager(model._is_training())._kwargs_input_schema) - ) + hash_args_size2 = hash(_repr_schema(model)) assert hash_args_size2 != hash_args_size1 # Increase number of inputs and train some more for _ in range(10): output = model(*args_size3) assert output is not None - hash_args_size3 = hash( - repr(model._torch_module._execution_manager(model._is_training())._args_input_schema) - + repr(model._torch_module._execution_manager(model._is_training())._kwargs_input_schema) - ) + hash_args_size3 = hash(_repr_schema(model)) assert hash_args_size3 != hash_args_size2 del os.environ["ORTMODULE_SKIPCHECK_POLICY"] @@ -3538,50 +3536,35 @@ def test_forward_dynamic_kwargs(): for _ in range(10): output = model(one) assert output is not None - hash_x = hash( - repr(model._torch_module._execution_manager(model._is_training())._args_input_schema) - + repr(model._torch_module._execution_manager(model._is_training())._kwargs_input_schema) - ) + hash_x = hash(_repr_schema(model)) assert hash_x is not None # Train with x and y as inputs for _ in range(10): output = model(one, y=one) assert output is not None - hash_x_y = hash( - repr(model._torch_module._execution_manager(model._is_training())._args_input_schema) - + repr(model._torch_module._execution_manager(model._is_training())._kwargs_input_schema) - ) + hash_x_y = hash(_repr_schema(model)) assert hash_x_y != hash_x # Train with x and z as inputs for _ in range(10): output = model(one, z=one) assert output is not None - hash_x_z = hash( - repr(model._torch_module._execution_manager(model._is_training())._args_input_schema) - + repr(model._torch_module._execution_manager(model._is_training())._kwargs_input_schema) - ) + hash_x_z = hash(_repr_schema(model)) assert hash_x_z != hash_x_y # Train with x, y and z as inputs for _ in range(10): output = model(one, y=one, z=one) assert output is not None - hash_x_y_z = hash( - repr(model._torch_module._execution_manager(model._is_training())._args_input_schema) - + repr(model._torch_module._execution_manager(model._is_training())._kwargs_input_schema) - ) + hash_x_y_z = hash(_repr_schema(model)) assert hash_x_y_z != hash_x_z # Return to original input with x as input for _ in range(10): output = model(one) assert output is not None - hash_x2 = hash( - repr(model._torch_module._execution_manager(model._is_training())._args_input_schema) - + repr(model._torch_module._execution_manager(model._is_training())._kwargs_input_schema) - ) + hash_x2 = hash(_repr_schema(model)) assert hash_x2 != hash_x_y_z assert hash_x2 == hash_x @@ -4847,11 +4830,21 @@ def forward(self, a): )._graph_transition_manager._exported_model_info.exported_model for training_mode in [False, True]: - assert ort_model._torch_module._execution_manager(training_mode)._original_model_has_changed is False + assert ( + ort_model._torch_module._execution_manager( + training_mode + )._graph_transition_manager._original_model_has_changed + is False + ) ort_model.input_flag = False for training_mode in [False, True]: - assert ort_model._torch_module._execution_manager(training_mode)._original_model_has_changed is True + assert ( + ort_model._torch_module._execution_manager( + training_mode + )._graph_transition_manager._original_model_has_changed + is True + ) _ = ort_model(torch.randn(N, D_in, device=device)) exported_model2 = ort_model._torch_module._execution_manager( From 96e3d2c3a68fd84ce068a20c2c67dbfb57831d3a Mon Sep 17 00:00:00 2001 From: "pengwa@microsoft.com" Date: Thu, 4 Jan 2024 13:46:50 +0000 Subject: [PATCH 04/32] fix all tests --- .../ortmodule/_graph_execution_manager.py | 58 +++-- .../ortmodule/_graph_transition_manager.py | 238 +++++++++--------- .../training/ortmodule/_inference_manager.py | 11 +- .../python/training/ortmodule/_io.py | 45 ++-- .../python/training/ortmodule/_onnx_models.py | 10 +- .../training/ortmodule/_training_manager.py | 12 +- .../python/training/ortmodule/_utils.py | 12 +- .../python/orttraining_test_ortmodule_api.py | 112 +++------ 8 files changed, 230 insertions(+), 268 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 3577ee5a04a33..94dbdab7c4184 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -7,7 +7,7 @@ import logging import os from abc import ABC, abstractmethod # noqa: F401 -from typing import Dict, List, OrderedDict, Tuple +from typing import Dict, List, Optional, OrderedDict, Tuple import onnx import torch @@ -66,28 +66,9 @@ def __init__( # Original and flattened (transformed) output module self._flattened_module = module - # Device where the model is placed. - # self._device: Optional[torch.device] = _utils.get_device_from_module(module) - - # Model export and post export processing before inference optimization && building gradient. self._onnx_models = _onnx_models.ONNXModels() self._export_mode = export_mode - self._graph_transition_manager = GraphTransitionManager( - flatten_module=module, - export_mode=export_mode, - save=debug_options.save_onnx_models.save, - save_path=debug_options.save_onnx_models.path, - save_name_prefix=debug_options.save_onnx_models.name_prefix, - deepcopy_before_model_export=self._runtime_options.deepcopy_before_model_export, - torch_exporter_verbose_log=self._debug_options.logging.log_level <= LogLevel.INFO, - enable_zero_stage3_support=self._runtime_options.enable_zero_stage3_support, - onnx_opset_version=self._runtime_options.onnx_opset_version, - enable_custom_autograd_function=self._runtime_options.enable_custom_autograd_function, - enable_symbolic_shape_infer=self._runtime_options.run_symbolic_shape_infer, - exported_model_cache_dir=self._runtime_options.ortmodule_cache_dir, - logger=logger, - ) - self._post_export_processed_model_info = None + self._graph_transition_manager: Optional[GraphTransitionManager] = None # Model after inference optimization && gradient building. self._graph_builder = None @@ -127,6 +108,8 @@ def __init__( configure_ort_compatible_zero_stage3(debug=False, stats_output_dir="ort_output", stats_overwrite=True) + self._initialize_graph_transition_manager() + def _get_torch_gpu_allocator_function_addresses(self): if self._runtime_options.use_external_gpu_allocator and torch.cuda.is_available(): # CPP extension to get torch GPU allocator's alloc and free function addresses @@ -136,6 +119,24 @@ def _get_torch_gpu_allocator_function_addresses(self): self._torch_free = torch_gpu_allocator.gpu_caching_allocator_raw_delete_address() self._torch_empty_cache = torch_gpu_allocator.gpu_caching_allocator_empty_cache_address() + def _initialize_graph_transition_manager(self): + """Creates a new GraphTransitionManager, initializes it and saves it to self._graph_transition_manager""" + self._graph_transition_manager = GraphTransitionManager( + flatten_module=self._flattened_module, + export_mode=self._export_mode, + save=self._debug_options.save_onnx_models.save, + save_path=self._debug_options.save_onnx_models.path, + save_name_prefix=self._debug_options.save_onnx_models.name_prefix, + deepcopy_before_model_export=self._runtime_options.deepcopy_before_model_export, + torch_exporter_verbose_log=self._debug_options.logging.log_level <= LogLevel.INFO, + enable_zero_stage3_support=self._runtime_options.enable_zero_stage3_support, + onnx_opset_version=self._runtime_options.onnx_opset_version, + enable_custom_autograd_function=self._runtime_options.enable_custom_autograd_function, + enable_symbolic_shape_infer=self._runtime_options.run_symbolic_shape_infer, + exported_model_cache_dir=self._runtime_options.ortmodule_cache_dir, + logger=self._logger, + ) + def _validate_module_type(self, module): """Raises ORTModuleTorchModelException if the module is not a torch.nn.Module""" @@ -166,8 +167,9 @@ def forward(self): def _build_graph(self, config): if self._runtime_options.use_static_shape: - # (TODO): add the shape for the onnx graph inputs. - self._graph_builder.build(config) # , self._input_info.shape) + self._graph_builder.build( + config, self._graph_transition_manager._model_info_for_export.onnx_graph_input_shapes + ) else: self._graph_builder.build(config) @@ -302,6 +304,7 @@ def __getstate__(self): "_onnx_models", "_graph_builder", "_graph_info", + "_graph_transition_manager", "_execution_agent", "_torch_alloc", "_torch_free", @@ -317,8 +320,11 @@ def __setstate__(self, state): _utils.reinitialize_graph_execution_manager(self) + self._initialize_graph_transition_manager() + @property def _device(self): + # Graph transition manager is responsible for detecting and managing the device to use. return self._graph_transition_manager._device @_logger.TrackTime(_logger.ORTModuleInitPhase.DETECTION) @@ -343,15 +349,13 @@ def _enable_conditional_optimizations( ) if self._runtime_options.enable_sparse_optimizer: - detected_device = _utils.get_device_from_module(self._original_module) or _utils.get_device_from_inputs( - inputs, kwargs - ) + detected_device = _utils.get_device_from_module_and_inputs(self._original_module, inputs, kwargs) if self._runtime_options.enable_zero_stage3_support: self._append_pull_weight_trigger_as_input(kwargs, detected_device) prepared_input_map = self._graph_transition_manager._post_export_processed_model_info.construct_inputs( - inputs, kwargs, True + inputs, kwargs, True, self._device ) embed_sparsity_results = OrderedDict() diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py index 5da9d00b1abd5..76232de51d63b 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py @@ -28,26 +28,22 @@ from . import _io, _utils from ._fallback import ORTModuleDeviceException, ORTModuleIOError, ORTModuleONNXModelException, wrap_exception +from ._onnx_models import _get_onnx_file_name, _save_model from ._utils import check_function_has_param, get_rank from ._zero_stage3_compatibility import stage3_export_context -def _get_onnx_file_name(name_prefix, name, export_mode): - suffix = "training" if export_mode == torch.onnx.TrainingMode.TRAINING else "inference" - return f"{name_prefix}_{name}_{suffix}.onnx" - - -def _save_model(model: onnx.ModelProto, file_path: str): - onnx.save(model, file_path) - - class ExportedModelInfo: + """Encapsulates the information of the exported model.""" + def __init__( self, module_forward_args_schema: ORTModelInputOutputSchemaType, module_forward_kwargs_schema: ORTModelInputOutputSchemaType, onnx_graph_input_names: list[str], onnx_graph_input_names_require_grad: list[str], + onnx_graph_input_names_user_defined: list[str], + onnx_graph_input_names_require_grad_user_defined: list[str], exported_model: onnx.ModelProto, module_forward_output_schema: ORTModelInputOutputSchemaType, ): @@ -63,12 +59,21 @@ def __init__( # Be noted: all inputs are used by the model for its compute. self.onnx_graph_input_names_require_grad: list[str] = onnx_graph_input_names_require_grad + self.onnx_graph_input_names_user_defined = ( + onnx_graph_input_names_user_defined # The ONNX graph input names excluding the parameters, buffers. + ) + + # The ONNX graph input names excluding the parameters, buffers. + self.onnx_graph_input_names_require_grad_user_defined = onnx_graph_input_names_require_grad_user_defined + # Exported model proto. self.exported_model: onnx.ModelProto | None = exported_model + # Used as a baseline to compare with the current inputs (args/kwargs) to see if the model needs to be re-exported. self.module_forward_args_schema: ORTModelInputOutputSchemaType | None = module_forward_args_schema self.module_forward_kwargs_schema: ORTModelInputOutputSchemaType | None = module_forward_kwargs_schema + # Used for unflattening the outputs from the ORT forward run. self.module_forward_output_schema: ORTModelInputOutputSchemaType | None = module_forward_output_schema def __str__(self): @@ -85,10 +90,11 @@ def __repro__(self): class PostExportProcessedModelInfo: + """Encapsulates the information of the post-export processed model.""" + def __init__( self, flatten_module: torch.nn.Module, - device: torch.device | None, onnx_graph_input_names_user_defined: list[str], onnx_graph_input_names_require_grad_user_defined: list[str], onnx_graph_input_names: list[str], @@ -98,8 +104,6 @@ def __init__( post_export_processed_model: onnx.ModelProto, data_accessor: list[callable], ): - self.device = device - self._flattened_module = flatten_module # Input names for the pre-gradient-build graph. @@ -148,7 +152,6 @@ def __init__( def __str__(self): return f"""PostExportProcessedModelInfo class: - \tdevice: {self.device} \tonnx_graph_input_names: {self.onnx_graph_input_names} \tonnx_graph_input_names_require_grad: {self.onnx_graph_input_names_require_grad} \tonnx_graph_input_dynamic_axes_map: {self.onnx_graph_input_dynamic_axes_map} @@ -165,6 +168,7 @@ def construct_inputs( args: Sequence[ORTModelInputOutputType], kwargs: Mapping[str, ORTModelInputOutputType], constant_as_tensor: bool, + device: torch.device, ): """Constructs the inputs for the forward method @@ -176,7 +180,7 @@ def construct_inputs( assert name in self.buffer_for_ort_runs, f"{name} is not in buffer_for_ort_runs" data = self.data_accessor[name](args, kwargs) if PrimitiveType.is_primitive_type(data) and constant_as_tensor: - data = PrimitiveType.get_tensor(data, self.device) + data = PrimitiveType.get_tensor(data, device) self.buffer_for_ort_runs[name] = data else: raise wrap_exception( @@ -213,14 +217,14 @@ def __init__( save_name_prefix: str, deepcopy_before_model_export: bool, torch_exporter_verbose_log: bool, - enable_zero_stage3_support: bool, + enable_zero_stage3_support: bool, # TODO(): implement as a plugin onnx_opset_version: int, - enable_custom_autograd_function: bool, + enable_custom_autograd_function: bool, # TODO(): implement as a plugin enable_symbolic_shape_infer: bool, exported_model_cache_dir: str, logger: logging.Logger, ): - self._device = _utils.get_device_from_module(flatten_module) + self._device = _utils._get_device_from_module(flatten_module) self._export_mode = export_mode # Debug options @@ -237,6 +241,8 @@ def __init__( self._run_symbolic_shape_infer = enable_symbolic_shape_infer self._ortmodule_cache_dir = exported_model_cache_dir + self._export_extra_kwargs = {} + self._logger = logger # A signal to indicate if the original model has changed and need a re-export. @@ -260,10 +266,10 @@ def __init__( # Model info after export and post export processing. self._post_export_processed_model_info = None - def use_cache_or_reconstruct( + def use_cache_or_reconstruct_post_processed_model( self, args: Sequence[ORTModelInputOutputType], kwargs: Mapping[str, ORTModelInputOutputType] ) -> tuple[bool, PostExportProcessedModelInfo]: - """Check if the model can be reused, otherwise, reconstruct the model. + """Check if the post-export processed ONNX model can be reused, otherwise, reconstruct the model. Return True if the model can be reused, otherwise, return False. The model can be reused when the following conditions are met: @@ -271,22 +277,26 @@ def use_cache_or_reconstruct( b. In training mode, the graph inputs requiring gradient are the same as the previous ones. """ - ( - need_export_model, - cur_model_info_for_export, - flatten_args, - flatten_kwargs, - cur_args_schema, - cur_kwargs_schema, - ) = GraphTransitionManager._export_check( - prev_exported_model_info=self._model_info_for_export, + + if self._device is None: + device = _utils.get_device_from_module_and_inputs(self._flatten_module._original_module, args, kwargs) + if not self._device or self._device != device: + self._device = device + if not self._device: + raise wrap_exception( + ORTModuleDeviceException, RuntimeError("A device must be specified in the model or inputs!") + ) + + # Extract the schema from the args and kwargs, and compare with the cached one. + # This check ideally is not needed as we already have the above check, but it is added as a safeguard. + flatten_args, cur_args_schema = _io._extract_schema(copy.copy(args), self._device) + flatten_kwargs, cur_kwargs_schema = _io._extract_schema(copy.copy(kwargs), self._device) + + need_export_model = GraphTransitionManager._export_check( prev_model_info_for_export=self._exported_model_info, - args=args, - kwargs=kwargs, - device=self._device, original_model_has_changed=self._original_model_has_changed, - module_forward_func_parameters=self._module_forward_func_parameters, - export_mode=self._export_mode, + cur_args_schema=cur_args_schema, + cur_kwargs_schema=cur_kwargs_schema, logger=self._logger, ) @@ -299,7 +309,18 @@ def use_cache_or_reconstruct( self._flatten_module._num_positionals = len(flatten_args) flatten_inputs = flatten_args + flatten_kwargs - self._set_device_from_module(flatten_inputs, {}) + + # Check graph inputs parsed from the model's forward function signature and current inputs, + # if they are different, we need to re-export the model. + cur_model_info_for_export = _io.parse_inputs_for_onnx_export( + self._module_forward_func_parameters, + args, + kwargs, + True, + self._device, + self._export_mode, + self._export_extra_kwargs, + ) # Start exporting the model by passing the 1-D flatten tensor list containing all args plus kwargs. ( @@ -323,11 +344,28 @@ def use_cache_or_reconstruct( logger=self._logger, ) + # Get the intersection of all user defined input names (parsed from forward function signature) and + # the exported model input names including both user defined names and parameter/buffer names. + # It is possible some user defined input names are not in the exported model input names, if it is not used + # by the model for its compute. + onnx_graph_input_names_user_defined = [ + input_name + for input_name in cur_model_info_for_export.onnx_graph_input_names + if input_name in onnx_graph_input_names + ] + onnx_graph_input_names_require_grad_user_defined = [ + input_name + for input_name in cur_model_info_for_export.onnx_graph_input_names_require_grad + if input_name in onnx_graph_input_names_require_grad + ] + self._exported_model_info = ExportedModelInfo( module_forward_args_schema=cur_args_schema, module_forward_kwargs_schema=cur_kwargs_schema, onnx_graph_input_names=onnx_graph_input_names, onnx_graph_input_names_require_grad=onnx_graph_input_names_require_grad, + onnx_graph_input_names_user_defined=onnx_graph_input_names_user_defined, + onnx_graph_input_names_require_grad_user_defined=onnx_graph_input_names_require_grad_user_defined, exported_model=exported_model, module_forward_output_schema=module_output_schema, ) @@ -357,19 +395,24 @@ def use_cache_or_reconstruct( flatten_module=self._flatten_module, prev_exported_model_info=self._exported_model_info, export_mode=self._export_mode, - cur_model_info_for_export=cur_model_info_for_export, + cur_model_info_for_export=self._model_info_for_export, ) if need_re_processed: - # Update the onnx_graph_input_names_require_grads to make it a new default baseline to compare using new iteration data. + # Update the onnx_graph_input_names_require_grads to make it a new default baseline to compare + # using new iteration data. self._exported_model_info.onnx_graph_input_names_require_grad = updated_onnx_graph_input_requires_grads if need_re_processed: - # At this point, the exported model is ready, and we can start post export processing. + # At this point, the exported model is ready, and we can start post-export processing. self._post_export_processed_model_info = GraphTransitionManager._post_export_process( flatten_module=self._flatten_module, device=self._device, + export_mode=self._export_mode, exported_model_info=self._exported_model_info, model_info_for_export=self._model_info_for_export, + enable_custom_autograd_function=self._enable_custom_autograd_function, + enable_zero_stage3_support=self._enable_zero_stage3_support, + stage3_param_handle=self, logger=self._logger, ) @@ -387,21 +430,16 @@ def use_cache_or_reconstruct( @staticmethod def _export_check( - prev_model_info_for_export: _io.ModelInfoForExport, prev_exported_model_info: ExportedModelInfo, - args: Sequence[ORTModelInputOutputType], - kwargs: Mapping[str, ORTModelInputOutputType], - device: torch.device | None, original_model_has_changed: bool, - module_forward_func_parameters: list[inspect.Parameter], - export_mode: int, + cur_args_schema: ORTModelInputOutputSchemaType, + cur_kwargs_schema: ORTModelInputOutputSchemaType, logger: logging.Logger, ): """Check if the model needs to be exported, if yes, return True. For the following cases, return True: 1. The model has never been exported before. - 2. The model's input names parsed from args and kwargs changed. 3. The model input schema parsed from args and kwargs has changed. """ @@ -409,39 +447,22 @@ def _export_check( need_export_model = need_export_model or original_model_has_changed - # Check graph inputs parsed from the model's forward function signature and current inputs, - # if they are different, we need to re-export the model. - model_info_for_export = _io.parse_inputs_for_onnx_export( - module_forward_func_parameters, args, kwargs, True, device, export_mode - ) - - need_export_model = ( - need_export_model - or prev_model_info_for_export.onnx_graph_input_names != model_info_for_export.onnx_graph_input_names - ) - - # Extract the schema from the args and kwargs, and compare with the cached one. - # This check ideally is not needed as we already have the above check, but it is added as a safeguard. - flatten_args, args_schema = _io._extract_schema(copy.copy(args), device) - flatten_kwargs, kwargs_schema = _io._extract_schema(copy.copy(kwargs), device) - need_export_model = ( need_export_model - or args_schema != prev_exported_model_info.module_forward_args_schema - or kwargs_schema != prev_exported_model_info.module_forward_kwargs_schema + or cur_args_schema != prev_exported_model_info.module_forward_args_schema + or cur_kwargs_schema != prev_exported_model_info.module_forward_kwargs_schema ) - logger.info( - f"_export_check completed - need_export_model: {need_export_model}, model_info_for_export: {model_info_for_export}" - ) + logger.info(f"_export_check completed - need_export_model: {need_export_model}") - return need_export_model, model_info_for_export, flatten_args, flatten_kwargs, args_schema, kwargs_schema + return need_export_model @staticmethod def _reprocess_check( flatten_module, prev_exported_model_info, export_mode, cur_model_info_for_export: _io.ModelInfoForExport ) -> bool: - """Check if the exported model needs to be re-processed, if yes, return True and the updated onnx_graph_input_requires_grads. + """Check if the exported model needs to be re-processed, if yes, + return True and the updated onnx_graph_input_requires_grads. For the following cases, return True: 1. The export mode is TRAINING and the model's input names (including both user input and module parameters) @@ -475,8 +496,12 @@ def _reprocess_check( def _post_export_process( flatten_module, device, + export_mode, exported_model_info: ExportedModelInfo, model_info_for_export: _io.ModelInfoForExport, + enable_custom_autograd_function: bool, + enable_zero_stage3_support: bool, + stage3_param_handle: type, logger: logging.Logger, ): """Post process the exported model, generate the processed model which will be used for initializing graph builder.""" @@ -484,30 +509,32 @@ def _post_export_process( # Deepcopy the exported model, in case modification affects the exported model. # TODO(): Do pre-grad graph modification as needed, for memory efficient gradient management, etc. - # Currently, we don't do any modifications. - post_processed_model = copy.deepcopy(exported_model_info.exported_model) - # Get the intersection of all user defined input names (parsed from forward function signature) and - # the exported model input names including both user defined names and parameter/buffer names. - # It is possible some user defined input names are not in the exported model input names, if it is not used - # by the model for its compute. - onnx_graph_input_names_user_defined = [ - input_name - for input_name in model_info_for_export.onnx_graph_input_names - if input_name in exported_model_info.onnx_graph_input_names - ] - onnx_graph_input_names_require_grad = [ - input_name - for input_name in model_info_for_export.onnx_graph_input_names_require_grad - if input_name in exported_model_info.onnx_graph_input_names_require_grad - ] + if export_mode == torch.onnx.TrainingMode.TRAINING: + if enable_custom_autograd_function: + from ._custom_autograd_function_exporter import post_process_enabling_autograd_function + + post_processed_model = post_process_enabling_autograd_function(post_processed_model) + + if enable_zero_stage3_support: + from ._zero_stage3_compatibility import post_processing_enable_zero_stage3_compat + + post_processed_model = post_processing_enable_zero_stage3_compat( + post_processed_model, + stage3_param_handle._zero_stage3_param_map, + [name for name, _ in flatten_module.named_parameters()], + ) + + # Cannot append pull weight trigger name to input names as following, otherwise, the later check ( + # https://github.com/microsoft/onnxruntime/blob/068300d97eb25e5b52324e7af54a45ed1fa6a4c3/orttraining/orttraining/python/training/ortmodule/_training_manager.py#L466C18-L466C18) + # find input info mismatch, will re-initialize the graph builder. + # self._input_info.require_grad_names.append(STAGE3_PULL_WEIGHT_TRIGGER_NAME) post_export_processed_model_info = PostExportProcessedModelInfo( flatten_module, - device, - onnx_graph_input_names_user_defined, - onnx_graph_input_names_require_grad, + exported_model_info.onnx_graph_input_names_user_defined, + exported_model_info.onnx_graph_input_names_require_grad_user_defined, exported_model_info.onnx_graph_input_names, exported_model_info.onnx_graph_input_names_require_grad, model_info_for_export.onnx_graph_input_dynamic_axes_map, @@ -642,11 +669,6 @@ def _get_exported_model( # Export torch.nn.Module to ONNX f = io.BytesIO() - print("pre_export_graph_info.onnx_graph_input_names: ", model_info_for_export.onnx_graph_input_names) - print( - "pre_export_graph_info.onnx_graph_input_names_require_grad: ", - model_info_for_export.onnx_graph_input_names_require_grad, - ) # Deepcopy inputs, since input values may change after model run. # NOTE: Inputs may contain tensors that have attributes preventing their deepcopy (example grad_fn). @@ -656,8 +678,6 @@ def _get_exported_model( assert len(sample_kwargs_copy) == 0, "Currently, kwargs are not supported for ONNX export." sample_inputs_as_tuple = sample_inputs_copy - print(">>>shapes for the flatten args and kwargs", [v.shape for v in sample_inputs_as_tuple]) - # Ops behaving differently under train/eval mode need to be exported with the # correct training flag to reflect the expected behavior. # For example, the Dropout node in a model is dropped under eval mode. @@ -708,25 +728,6 @@ def _get_exported_model( ) exported_model = onnx.load_model_from_string(f.getvalue()) - if enable_custom_autograd_function: - from ._custom_autograd_function_exporter import post_process_enabling_autograd_function - - exported_model = post_process_enabling_autograd_function(exported_model) - - if enable_zero_stage3_support: - from ._zero_stage3_compatibility import post_processing_enable_zero_stage3_compat - - exported_model = post_processing_enable_zero_stage3_compat( - exported_model, - stage3_param_handle._zero_stage3_param_map, - [name for name, _ in flattened_module.named_parameters()], - ) - - # Cannot append pull weight trigger name to input names as following, otherwise, the later check ( - # https://github.com/microsoft/onnxruntime/blob/068300d97eb25e5b52324e7af54a45ed1fa6a4c3/orttraining/orttraining/python/training/ortmodule/_training_manager.py#L466C18-L466C18) - # find input info mismatch, will re-initialize the graph builder. - # self._input_info.require_grad_names.append(STAGE3_PULL_WEIGHT_TRIGGER_NAME) - # Cache model for future runs if cache_dir: if not os.path.exists(cache_dir): @@ -739,19 +740,6 @@ def _get_exported_model( return exported_model, module_output_schema - def _set_device_from_module(self, inputs, kwargs): - """Get the device from the module and save it to self._device""" - - device = _utils.get_device_from_module(self._flatten_module._original_module) or _utils.get_device_from_inputs( - inputs, kwargs - ) - if not self._device or self._device != device: - self._device = device - if not self._device: - raise wrap_exception( - ORTModuleDeviceException, RuntimeError("A device must be specified in the model or inputs!") - ) - def signal_model_changed(self): """Signals the execution manager to re-export the model on the next forward call""" self._original_model_has_changed = True diff --git a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py index e41bb9fe947b8..08f04d36eeb55 100644 --- a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py @@ -114,9 +114,10 @@ def forward(self, *inputs, **kwargs): # Exporting module to ONNX for the first time - build_graph, post_export_processed_model_info = self._graph_transition_manager.use_cache_or_reconstruct( - inputs, kwargs - ) + ( + build_graph, + post_export_processed_model_info, + ) = self._graph_transition_manager.use_cache_or_reconstruct_post_processed_model(inputs, kwargs) if build_graph: # TODO(): do we need call it for inferencing mode??? self._initialize_graph_builder(post_export_processed_model_info) @@ -137,7 +138,7 @@ def forward(self, *inputs, **kwargs): self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT) is False or not self._execution_agent ): - module_device = _utils.get_device_from_module(self._original_module) + module_device = _utils.get_device_from_module_and_inputs(self._original_module, inputs, kwargs) create_execution_session = ( build_graph @@ -164,7 +165,7 @@ def forward(self, *inputs, **kwargs): self._append_pull_weight_trigger_as_input(kwargs, self._device) prepared_input_map = self._graph_transition_manager._post_export_processed_model_info.construct_inputs( - inputs, kwargs, True + inputs, kwargs, True, self._device ) user_outputs, _ = InferenceManager.execution_session_run_forward( diff --git a/orttraining/orttraining/python/training/ortmodule/_io.py b/orttraining/orttraining/python/training/ortmodule/_io.py index 7eb60c4ab407b..7e7ab81b845eb 100644 --- a/orttraining/orttraining/python/training/ortmodule/_io.py +++ b/orttraining/orttraining/python/training/ortmodule/_io.py @@ -155,13 +155,14 @@ def __init__( onnx_graph_input_shapes: List[List[int]], data_accessor: Optional[List[callable]] = None, export_mode: Optional[int] = None, + export_extra_kwargs: Optional[Dict[str, any]] = None, ): # Value can be either torch.onnx.TrainingMode.TRAINING or torch.onnx.TrainingMode.EVAL self.export_mode = export_mode # Exporter can take extra arguments for ORTModule extensions # It cannot overlap with required/immutable arguments (validated in runtime) - self.export_extra_kwargs = {} + self.export_extra_kwargs = export_extra_kwargs # Input names parsed and then flatten from the model's forward function signature. # This should contains ONLY the user defined input names @@ -199,6 +200,10 @@ def __repr__(self) -> str: return self.__str__() +def _arg_access__with_index_func(arg_index, args, kwargs): + return args[arg_index] + + def parse_inputs_for_onnx_export( all_input_parameters: List[inspect.Parameter], args: Sequence[ORTModelInputOutputType], @@ -206,6 +211,7 @@ def parse_inputs_for_onnx_export( constant_as_tensor: bool, device: torch.device, export_mode: int, + export_extra_kwargs: Optional[Dict[str, any]] = None, ) -> ModelInfoForExport: """Parses through the model inputs and returns _InputInfo. @@ -244,7 +250,7 @@ def _add_dynamic_shape(name, input) -> Dict[str, Dict[int, str]]: def _warn_of_constant_inputs(data): warnings.warn(f"Received input of type {type(data)} is treated as a constant by ORT by default.") - def _add_input(name, input_value, onnx_graph_input_names, cur_func): + def _add_input(name: str, input_value, onnx_graph_input_names: List[str], cur_func: Callable): """Returns number of expanded non none inputs that _add_input processed""" # in case the input is already handled. @@ -272,11 +278,15 @@ def _add_input(name, input_value, onnx_graph_input_names, cur_func): for i, val in enumerate(value): # Name each input with the index appended to the original name of the # argument. + + def _access_func1(i, cur_func, args, kwargs): + return cur_func(args, kwargs)[i] + _add_input( f"{name}_{i}", val, onnx_graph_input_names, - partial(lambda i, args, kwargs: cur_func(args, kwargs)[i], i), + partial(_access_func1, i, cur_func), ) # Return here since the list by itself is not a valid input. @@ -286,11 +296,15 @@ def _add_input(name, input_value, onnx_graph_input_names, cur_func): # If the input is a mapping (like a dict), expand the dict so that # each element of the dict is an input by itself. for key, val in value.items(): + + def _access_func2(key, cur_func, args, kwargs): + return cur_func(args, kwargs)[key] + _add_input( f"{name}_{key}", val, onnx_graph_input_names, - partial(lambda key, args, kwargs: cur_func(args, kwargs)[key], key), + partial(_access_func2, key, cur_func), ) # Return here since the dict by itself is not a valid input. @@ -339,13 +353,8 @@ def _add_input(name, input_value, onnx_graph_input_names, cur_func): name = f"{input_parameter.name}_{var_positional_idx}" var_positional_idx += 1 inp = args[args_i] - _add_input( - name, - inp, - # onnx_graph, - onnx_graph_input_names, - partial(lambda args_i, args, kwargs: args[args_i], args_i), - ) + + _add_input(name, inp, onnx_graph_input_names, partial(_arg_access__with_index_func, args_i)) elif ( input_parameter.kind == inspect.Parameter.POSITIONAL_ONLY or input_parameter.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD @@ -359,22 +368,29 @@ def _add_input(name, input_value, onnx_graph_input_names, cur_func): if input_idx < len(args) and args[input_idx] is not None: inp = args[input_idx] - access_func = partial(lambda input_idx, args, kwargs: args[input_idx], input_idx) + access_func = partial(_arg_access__with_index_func, input_idx) elif name in kwargs and kwargs[name] is not None: inp = kwargs[name] - access_func = partial(lambda name, args, kwargs: kwargs[name], name) + def _access_func5(name, args, kwargs): + return kwargs[name] + + access_func = partial(_access_func5, name) _add_input(name, inp, onnx_graph_input_names, access_func) elif input_parameter.kind == inspect.Parameter.VAR_KEYWORD: # **kwargs is always the last argument of forward() for name, inp in kwargs.items(): + + def _access_func6(name, args, kwargs): + return kwargs[name] + _add_input( name, inp, onnx_graph_input_names, - partial(lambda name, args, kwargs: kwargs[name], name), + partial(_access_func6, name), ) exported_graph = ModelInfoForExport( @@ -384,6 +400,7 @@ def _add_input(name, input_value, onnx_graph_input_names, cur_func): onnx_graph_input_shapes=input_shape, data_accessor=data_accessors, export_mode=export_mode, + export_extra_kwargs=export_extra_kwargs, ) return exported_graph diff --git a/orttraining/orttraining/python/training/ortmodule/_onnx_models.py b/orttraining/orttraining/python/training/ortmodule/_onnx_models.py index d687bc24384ed..4b6011f0786ec 100644 --- a/orttraining/orttraining/python/training/ortmodule/_onnx_models.py +++ b/orttraining/orttraining/python/training/ortmodule/_onnx_models.py @@ -23,8 +23,7 @@ def _save_model(model: onnx.ModelProto, file_path: str): class ONNXModels: """Encapsulates all ORTModule onnx models. - 1. exported_model: Model that is exported by torch.onnx.export - 2. optimized_model: For eval mode it's exported_model with concrete input shapes set if needed, + 1. optimized_model: For eval mode it's exported_model with concrete input shapes set if needed, for training mode, it's an optimized model after the gradients graph has been built. In addition, ORTModule also saves two other models, to the user-provided path: a. the pre_grad_model which is the model before the gradients graph is built. @@ -32,15 +31,8 @@ class ONNXModels: It has further optimizations done by the InferenceSession and is saved by the InferenceSession. """ - exported_model: Optional[onnx.ModelProto] = None optimized_model: Optional[onnx.ModelProto] = None - def save_exported_model(self, path, name_prefix, export_mode): - # save the ortmodule exported model - _save_model( - self.exported_model, os.path.join(path, _get_onnx_file_name(name_prefix, "torch_exported", export_mode)) - ) - def save_optimized_model(self, path, name_prefix, export_mode): # save the ortmodule optimized model _save_model( diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index dce59ec837e36..7db6d91a82b36 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -256,7 +256,7 @@ def forward(self, *inputs, **kwargs): ( build_gradient_graph, post_export_processed_model_info, - ) = self._graph_transition_manager.use_cache_or_reconstruct(inputs, kwargs) + ) = self._graph_transition_manager.use_cache_or_reconstruct_post_processed_model(inputs, kwargs) if build_gradient_graph: self._initialize_graph_builder(post_export_processed_model_info) @@ -277,9 +277,7 @@ def forward(self, *inputs, **kwargs): self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT) is False or not self._execution_agent ): - device = _utils.get_device_from_module(self._original_module) or _utils.get_device_from_inputs( - inputs, kwargs - ) + device = _utils.get_device_from_module_and_inputs(self._original_module, inputs, kwargs) create_execution_session = ( build_gradient_graph or self._device != device @@ -305,18 +303,14 @@ def forward(self, *inputs, **kwargs): if self._runtime_options.enable_zero_stage3_support: self._append_pull_weight_trigger_as_input(kwargs, self._device) - # prepared_input_list = self.construct_inputs(inputs, kwargs) - prepared_input_map = self._graph_transition_manager._post_export_processed_model_info.construct_inputs( - inputs, kwargs, True + inputs, kwargs, True, self._device ) user_outputs = self._forward_class.apply(*prepared_input_map.values()) outputs = self._graph_transition_manager._post_export_processed_model_info.restore_outputs(user_outputs) - # print("outputs: ", outputs) - if ( create_execution_session and self._runtime_options.enable_tuning diff --git a/orttraining/orttraining/python/training/ortmodule/_utils.py b/orttraining/orttraining/python/training/ortmodule/_utils.py index 91825fc492208..2b9a259895793 100644 --- a/orttraining/orttraining/python/training/ortmodule/_utils.py +++ b/orttraining/orttraining/python/training/ortmodule/_utils.py @@ -153,7 +153,15 @@ def get_device_str(device: Union[str, int, torch.device]) -> str: return device -def get_device_from_module(module) -> Optional[torch.device]: +def get_device_from_module_and_inputs(module, inputs, kwargs): + """Get the device from the module and save it to self._device""" + + device = _get_device_from_module(module) or _get_device_from_inputs(inputs, kwargs) + + return device + + +def _get_device_from_module(module) -> Optional[torch.device]: """Returns the first device found in the `module`'s parameters or None Args: @@ -179,7 +187,7 @@ def get_device_from_module(module) -> Optional[torch.device]: return device -def get_device_from_inputs(args, kwargs) -> Optional[torch.device]: +def _get_device_from_inputs(args, kwargs) -> Optional[torch.device]: """Returns device from first PyTorch Tensor within args or kwargs Args: diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index a711386760f47..9afa50cba784d 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -30,7 +30,7 @@ import onnxruntime.training.ortmodule as ortmodule_module from onnxruntime.training.optim import AdamWMode, FusedAdam -from onnxruntime.training.ortmodule import DebugOptions, LogLevel, ORTModule, _fallback, _io, _utils +from onnxruntime.training.ortmodule import DebugOptions, LogLevel, ORTModule, _fallback, _utils from onnxruntime.training.ortmodule._custom_gradient_registry import register_gradient from onnxruntime.training.ortmodule.options import _SkipCheck @@ -447,14 +447,12 @@ def test_forward_call_single_positional_argument(): N, D_in, H, D_out = 64, 784, 500, 10 # noqa: N806 model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) - # ort_model = ORTModule(model) - from onnxruntime.training.ortmodule import DebugOptions, LogLevel, ORTModule - - ort_model = ORTModule(model, DebugOptions(save_onnx=True, log_level=LogLevel.INFO, onnx_prefix="2024_0102_01")) + ort_model = ORTModule(model) # Check that the original forward signature is preserved. assert inspect.signature(model.forward) == inspect.signature(ort_model.forward) x = torch.randn(N, D_in, device=device) + # Make sure model runs without any exception prediction = ort_model(x) assert prediction is not None @@ -624,7 +622,7 @@ def test_torch_nn_module_to_api(original_device, to_argument): x = x.to(to_argument) model(x) assert _utils.get_device_str( - model._torch_module._execution_manager(model._is_training())._graph_transition_manager._device + model._torch_module._execution_manager(model._is_training())._device ) == _utils.get_device_str(torch.device(to_argument)) @@ -839,13 +837,7 @@ def forward(self, input): device = "cuda" pt_model = NeuralNetTranspose(perm).to(device) - from onnxruntime.training.ortmodule import DebugOptions, LogLevel, ORTModule - - ort_model = ORTModule( - copy.deepcopy(pt_model), DebugOptions(save_onnx=True, log_level=LogLevel.INFO, onnx_prefix="2024_0103_01") - ) - - # ort_model = ORTModule(copy.deepcopy(pt_model)) + ort_model = ORTModule(copy.deepcopy(pt_model)) def run_step(model, x): prediction = model(x) @@ -854,11 +846,11 @@ def run_step(model, x): return prediction x = torch.randn(*shape, device=device, requires_grad=True) - # pt_prediction = run_step(pt_model, x) - run_step(ort_model, x) + pt_prediction = run_step(pt_model, x) + ort_prediction = run_step(ort_model, x) - # _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) - # _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model) + _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) + _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model) @pytest.mark.parametrize( @@ -2963,12 +2955,11 @@ def test_forward_data_and_model_on_different_devices(data_device, model_device): runtime_error.value ) else: - # ORT backend - with pytest.raises(_fallback.ORTModuleDeviceException) as runtime_error: + # ORT backend also throw the same exception because PyTorch run failed during export. + with pytest.raises(RuntimeError) as runtime_error: ort_model(x) - assert ( - f"Input argument to forward found on device {torch.device(x.device)}, but expected it to be on module device {ort_model._torch_module._execution_manager(ort_model._is_training())._graph_transition_manager._device}." - in str(runtime_error.value) + assert "Expected all tensors to be on the same device, but found at least two devices" in str( + runtime_error.value ) del os.environ["ORTMODULE_SKIPCHECK_POLICY"] @@ -3645,11 +3636,7 @@ def forward(self, pos_0, pos_1, *args, kw_0=None, kw_1=None, **kwargs): device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 # noqa: N806 model = KwargsNet(input_size=D_in, hidden_size=H, num_classes=D_out).to(device) - # model = ORTModule(model) - - from onnxruntime.training.ortmodule import DebugOptions, LogLevel, ORTModule - - model = ORTModule(model, DebugOptions(save_onnx=True, log_level=LogLevel.INFO, onnx_prefix="2024_0102_02")) + model = ORTModule(model) # Dummy inputs used pos_0 = torch.randn(N, D_in, device=device) @@ -3706,10 +3693,7 @@ def forward(self, a, b, c, d, *args, kw_0=None, **kwargs): # Modeling device = "cuda" model = UnusedNet().to(device) - from onnxruntime.training.ortmodule import DebugOptions, LogLevel, ORTModule - - model = ORTModule(model, DebugOptions(log_level=LogLevel.INFO)) - # model = ORTModule(model) + model = ORTModule(model) # Dummy data one = torch.FloatTensor([1]).to(device) @@ -4164,36 +4148,6 @@ def test_stateless_model_unspecified_device(): _test_helpers.assert_values_are_close(pt_y, ort_y) -# @pytest.mark.parametrize( -# "model", -# [ -# (UnusedBeginParameterNet(784, 500, 400, 10)), -# (UnusedMiddleParameterNet(784, 500, 400, 10)), -# (UnusedEndParameterNet(784, 500, 400, 10)), -# ], -# ) -# def test_unused_parameters_does_not_unnecessarily_reinitialize(model): -# device = "cuda" - -# N, D_in, H1, H2, D_out = 64, 784, 500, 400, 10 -# model = model.to(device) -# ort_model = ORTModule(copy.deepcopy(model)) -# training_manager = ort_model._torch_module._execution_manager(ort_model._is_training()) - -# x = torch.randn(N, D_in, device=device) -# _ = ort_model(x) - -# input_info = _io.parse_inputs_for_onnx_export( -# training_manager._module_parameters, -# training_manager._graph_transition_manager._exported_model_info.exported_model, -# training_manager._input_info.schema, -# x, -# {}, -# ) - -# assert not training_manager._reinitialize_graph_builder(input_info) - - def test_load_state_dict_for_wrapped_ortmodule(): class WrapperModule(torch.nn.Module): def __init__(self, ortmodule): @@ -4668,9 +4622,7 @@ def forward(self, batch, **kwargs): device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 # noqa: F841, N806 pt_model = ListDictKwargsNet(N, D_in).to(device) - ort_model = ORTModule( - copy.deepcopy(pt_model), DebugOptions(save_onnx=True, log_level=LogLevel.INFO, onnx_prefix="kwargsanddict") - ) + ort_model = ORTModule(copy.deepcopy(pt_model), DebugOptions(save_onnx=True, onnx_prefix="kwargsanddict")) x = { "one_value": [torch.randn(N, D_in, device=device)], "two_value": [torch.randn(N, D_in, device=device), torch.randn(N, D_in, device=device)], @@ -5007,7 +4959,9 @@ def test_override_pytorch_exporter_kwargs(): model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) ort_model = ORTModule(model) - ort_model._torch_module._execution_manager(True)._export_extra_kwargs = {"custom_opsets": None} + ort_model._torch_module._execution_manager(True)._graph_transition_manager._export_extra_kwargs = { + "custom_opsets": None + } # Make sure model runs without any exception prediction = ort_model(x) @@ -5024,7 +4978,7 @@ def test_override_pytorch_exporter_kwargs__invalid(): model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) ort_model = ORTModule(model) - ort_model._torch_module._execution_manager(True)._export_extra_kwargs = {"verbose": False} + ort_model._torch_module._execution_manager(True)._graph_transition_manager._export_extra_kwargs = {"verbose": False} with pytest.raises(_fallback.ORTModuleONNXModelException) as type_error: _ = ort_model(x) assert "The following PyTorch exporter arguments cannot be specified: '{'verbose'}'." in str(type_error.value) @@ -5037,7 +4991,9 @@ class ORTModuleExtension(ORTModule): def __init__(self, module, debug_options=None): super().__init__(module, debug_options) for training_mode in [False, True]: - self._torch_module._execution_manager(training_mode)._export_extra_kwargs = {"verbose": None} + self._torch_module._execution_manager(training_mode)._graph_transition_manager._export_extra_kwargs = { + "verbose": None + } N, D_in, H, D_out = 64, 784, 500, 10 # noqa: N806 x = torch.randn(N, D_in, device=device) @@ -5057,7 +5013,9 @@ def __init__(self, module, debug_options=None): super().__init__(module, debug_options) # modify GraphExecutionManager internally for training_mode in [False, True]: - self._torch_module._execution_manager(training_mode)._export_extra_kwargs = {"custom_opsets": None} + self._torch_module._execution_manager( + training_mode + )._graph_transition_manager._model_info_for_export.export_extra_kwargs = {"custom_opsets": None} N, D_in, H, D_out = 64, 784, 500, 10 # noqa: N806 x = torch.randn(N, D_in, device=device) @@ -5288,16 +5246,12 @@ def run_step(model, x): ort_prediction, ort_loss = run_step(ort_model, ort_x) pt_prediction, pt_loss = run_step(pt_model, pt_x) if step == 0: - exported_model = ( - ort_model._torch_module._execution_manager._training_manager._graph_transition_manager._exported_model_info.exported_model - ) - # optimized_model = ort_model._torch_module._execution_manager._training_manager._optimized_model - for onx in [ - exported_model, + for onnx_model in [ + ort_model._torch_module._execution_manager._training_manager._graph_transition_manager._exported_model_info.exported_model, + ort_model._torch_module._execution_manager._training_manager._onnx_models.optimized_model, ]: - # "optimized_model"]: opv = None - for op in onx.opset_import: + for op in onnx_model.opset_import: if not op.domain: opv = op.version assert opv == 13 @@ -5349,7 +5303,11 @@ def test_serialize_ortmodule(): device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 # noqa: N806 pt_model = SerializationNet(D_in, H, D_out).to(device) - ort_model = ORTModule(copy.deepcopy(pt_model)) + + from onnxruntime.training.ortmodule import DebugOptions, LogLevel + + ort_model = ORTModule(copy.deepcopy(pt_model), DebugOptions(log_level=LogLevel.INFO)) + # ort_model = ORTModule(copy.deepcopy(pt_model)) x_1 = torch.randn(N, D_in, device=device) x_2 = copy.deepcopy(x_1) From b2897a3aad9cbfdd1c22bf7f8633262308527281 Mon Sep 17 00:00:00 2001 From: "pengwa@microsoft.com" Date: Thu, 4 Jan 2024 13:48:52 +0000 Subject: [PATCH 05/32] fix --- .../python/training/ortmodule/_graph_transition_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py index 76232de51d63b..a1d38842952c3 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py @@ -293,7 +293,7 @@ def use_cache_or_reconstruct_post_processed_model( flatten_kwargs, cur_kwargs_schema = _io._extract_schema(copy.copy(kwargs), self._device) need_export_model = GraphTransitionManager._export_check( - prev_model_info_for_export=self._exported_model_info, + prev_exported_model_info=self._exported_model_info, original_model_has_changed=self._original_model_has_changed, cur_args_schema=cur_args_schema, cur_kwargs_schema=cur_kwargs_schema, From a01cb88976f602e6183f7f157948ce94518a7ae0 Mon Sep 17 00:00:00 2001 From: "pengwa@microsoft.com" Date: Thu, 4 Jan 2024 13:49:49 +0000 Subject: [PATCH 06/32] minor --- .../python/training/ortmodule/_graph_execution_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 94dbdab7c4184..d3b186154fabc 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -304,7 +304,7 @@ def __getstate__(self): "_onnx_models", "_graph_builder", "_graph_info", - "_graph_transition_manager", + "_graph_transition_manager", # Not pickled as it is re-constructed in __setstate__ "_execution_agent", "_torch_alloc", "_torch_free", From 34cdba4dbc3161bb488d585306d6a2e24394ab21 Mon Sep 17 00:00:00 2001 From: "pengwa@microsoft.com" Date: Thu, 4 Jan 2024 16:43:18 +0000 Subject: [PATCH 07/32] fix --- .../ortmodule/_graph_execution_manager.py | 14 +- .../ortmodule/_graph_transition_manager.py | 224 ++++++++++-------- .../python/training/ortmodule/_io.py | 21 +- .../training/ortmodule/_training_manager.py | 3 - 4 files changed, 131 insertions(+), 131 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index d3b186154fabc..324955e045e63 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -124,16 +124,8 @@ def _initialize_graph_transition_manager(self): self._graph_transition_manager = GraphTransitionManager( flatten_module=self._flattened_module, export_mode=self._export_mode, - save=self._debug_options.save_onnx_models.save, - save_path=self._debug_options.save_onnx_models.path, - save_name_prefix=self._debug_options.save_onnx_models.name_prefix, - deepcopy_before_model_export=self._runtime_options.deepcopy_before_model_export, - torch_exporter_verbose_log=self._debug_options.logging.log_level <= LogLevel.INFO, - enable_zero_stage3_support=self._runtime_options.enable_zero_stage3_support, - onnx_opset_version=self._runtime_options.onnx_opset_version, - enable_custom_autograd_function=self._runtime_options.enable_custom_autograd_function, - enable_symbolic_shape_infer=self._runtime_options.run_symbolic_shape_infer, - exported_model_cache_dir=self._runtime_options.ortmodule_cache_dir, + debug_options=self._debug_options, + runtime_options=self._runtime_options, logger=self._logger, ) @@ -304,7 +296,7 @@ def __getstate__(self): "_onnx_models", "_graph_builder", "_graph_info", - "_graph_transition_manager", # Not pickled as it is re-constructed in __setstate__ + "_graph_transition_manager", # Not pickled as it is re-constructed in __setstate__ "_execution_agent", "_torch_alloc", "_torch_free", diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py index a1d38842952c3..aa9d76b826408 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py @@ -24,13 +24,14 @@ PrimitiveType, unflatten_data_using_schema, ) -from onnxruntime.training.utils.torch_io_helper import _TensorStub from . import _io, _utils from ._fallback import ORTModuleDeviceException, ORTModuleIOError, ORTModuleONNXModelException, wrap_exception +from ._logger import LogLevel from ._onnx_models import _get_onnx_file_name, _save_model from ._utils import check_function_has_param, get_rank from ._zero_stage3_compatibility import stage3_export_context +from .options import DebugOptions, _RuntimeOptions class ExportedModelInfo: @@ -47,6 +48,10 @@ def __init__( exported_model: onnx.ModelProto, module_forward_output_schema: ORTModelInputOutputSchemaType, ): + # Used as a baseline to compare with the current inputs (args/kwargs) to see if the model needs to be re-exported. + self.module_forward_args_schema: ORTModelInputOutputSchemaType | None = module_forward_args_schema + self.module_forward_kwargs_schema: ORTModelInputOutputSchemaType | None = module_forward_kwargs_schema + # Input names parsed and then flatten from the model's forward function signature + buffers + parameters (since we use # keep_initializers_as_inputs=True for model export) # Be noted: all inputs are used by the model for its compute. @@ -54,25 +59,22 @@ def __init__( # A subset of onnx_graph_input_names. # Input names that require gradient parsed and then flatten from the model's forward function signature - # This should contains both the user input names, the buffer names, and the parameter names (since we use + # This should contain both the user input names, the buffer names, and the parameter names (since we use # keep_initializers_as_inputs=True for model export) # Be noted: all inputs are used by the model for its compute. self.onnx_graph_input_names_require_grad: list[str] = onnx_graph_input_names_require_grad - self.onnx_graph_input_names_user_defined = ( - onnx_graph_input_names_user_defined # The ONNX graph input names excluding the parameters, buffers. - ) + # Input names parsed from the model's forward function signature. + # Be noted: all inputs are used by the model for its compute. + # The ONNX graph input names exclude the parameters, and buffers. + self.onnx_graph_input_names_user_defined = onnx_graph_input_names_user_defined - # The ONNX graph input names excluding the parameters, buffers. + # A subset of onnx_graph_input_names_user_defined. self.onnx_graph_input_names_require_grad_user_defined = onnx_graph_input_names_require_grad_user_defined # Exported model proto. self.exported_model: onnx.ModelProto | None = exported_model - # Used as a baseline to compare with the current inputs (args/kwargs) to see if the model needs to be re-exported. - self.module_forward_args_schema: ORTModelInputOutputSchemaType | None = module_forward_args_schema - self.module_forward_kwargs_schema: ORTModelInputOutputSchemaType | None = module_forward_kwargs_schema - # Used for unflattening the outputs from the ORT forward run. self.module_forward_output_schema: ORTModelInputOutputSchemaType | None = module_forward_output_schema @@ -106,6 +108,14 @@ def __init__( ): self._flattened_module = flatten_module + # Input names parsed from the model's forward function signature. + # Be noted: all inputs are used by the model for its compute. + # The ONNX graph input names exclude the parameters, and buffers. + self.onnx_graph_input_names_user_defined = onnx_graph_input_names_user_defined + + # A subset of onnx_graph_input_names_user_defined. + self.onnx_graph_input_names_require_grad_user_defined = onnx_graph_input_names_require_grad_user_defined + # Input names for the pre-gradient-build graph. # This may be different with the one in ExportedGraph since we may modify the graph inputs as needed # for example when memory efficient gradient management is enabled. @@ -120,35 +130,18 @@ def __init__( # e.g. {"input1": {0: "input1_dim0", 1: "input1_dim1"}, "input2": {0: "input2_dim0"}} self.onnx_graph_input_dynamic_axes_map: dict[str, dict[int, str]] = onnx_graph_input_dynamic_axes_map - self.buffer_for_ort_runs: dict[str, torch.Tensor] = OrderedDict() - self.onnx_graph_input_names_user_defined = ( - onnx_graph_input_names_user_defined # The ONNX graph input names excluding the parameters, buffers. - ) - - # The ONNX graph input names excluding the parameters, buffers. - self.onnx_graph_input_names_require_grad_user_defined = onnx_graph_input_names_require_grad_user_defined - self._post_export_processed_model: onnx.ModelProto | None = post_export_processed_model # A function to access the input data from the args and kwargs. - # If it is not None, the length is same as onnx_graph_input_names. + # If it is not None, the length is same as onnx_graph_input_names_user_defined. # For i-th input name, we can use the i-th function to get the input data from args and kwargs. self.data_accessor: list[callable] | None = data_accessor # Used for unflattening the outputs from the ORT forward run. self.module_forward_output_schema: ORTModelInputOutputSchemaType | None = module_forward_output_schema - # Create the buffers for the inputs that are either parameters or buffers in the original module. - # For user inputs, fill with None for now, and will be filled dynamically during the forward run. - parameter_names = {k: v for k, v in self._flattened_module.named_parameters()} - buffer_names = {k: v for k, v in self._flattened_module.named_buffers()} - for input_name in self.onnx_graph_input_names: - if input_name in parameter_names: - self.buffer_for_ort_runs[input_name] = parameter_names[input_name] - elif input_name in buffer_names: - self.buffer_for_ort_runs[input_name] = buffer_names[input_name] - else: - self.buffer_for_ort_runs[input_name] = None + # A buffer to hold the inputs for the ORT forward run. For performance, we reuse the same buffer for each run. + self._buffer_for_ort_runs: dict[str, torch.Tensor] = OrderedDict() def __str__(self): return f"""PostExportProcessedModelInfo class: @@ -157,7 +150,7 @@ def __str__(self): \tonnx_graph_input_dynamic_axes_map: {self.onnx_graph_input_dynamic_axes_map} \tonnx_graph_input_names_user_defined: {self.onnx_graph_input_names_user_defined} \tonnx_graph_input_names_require_grad_user_defined: {self.onnx_graph_input_names_require_grad_user_defined} - \tbuffer_for_ort_runs.keys(): {self.buffer_for_ort_runs.keys()} + \tbuffer_for_ort_runs.keys(): {self._buffer_for_ort_runs.keys()} """ def __repro__(self): @@ -175,28 +168,39 @@ def construct_inputs( The inputs are constructed in the order they appear in the model's forward function signature """ + # First time construct the buffer for the ORT forward run. + if len(self._buffer_for_ort_runs) == 0: + # Create the buffers for the inputs that are either parameters or buffers in the original module. + # For user inputs, fill with None for now, and will be filled dynamically during the forward run. + parameter_names = {k: v for k, v in self._flattened_module.named_parameters()} + buffer_names = {k: v for k, v in self._flattened_module.named_buffers()} + for input_name in self.onnx_graph_input_names: + if input_name in parameter_names: + self._buffer_for_ort_runs[input_name] = parameter_names[input_name] + elif input_name in buffer_names: + self._buffer_for_ort_runs[input_name] = buffer_names[input_name] + else: + self._buffer_for_ort_runs[input_name] = None + for name in self.onnx_graph_input_names_user_defined: if name in self.data_accessor: - assert name in self.buffer_for_ort_runs, f"{name} is not in buffer_for_ort_runs" + assert name in self._buffer_for_ort_runs, f"{name} is not in buffer_for_ort_runs" data = self.data_accessor[name](args, kwargs) if PrimitiveType.is_primitive_type(data) and constant_as_tensor: data = PrimitiveType.get_tensor(data, device) - self.buffer_for_ort_runs[name] = data + self._buffer_for_ort_runs[name] = data else: raise wrap_exception( ORTModuleONNXModelException, RuntimeError(f"Input is present in ONNX graph but not provided: {name}."), ) - return self.buffer_for_ort_runs + return self._buffer_for_ort_runs def restore_outputs(self, ort_flatten_outputs: list[torch.Tensor]): """Restores the outputs from the ORT forward run, back to the original data structure""" try: - # Need to distinguish between a single output and a tuple (having a single tensor) - if len(ort_flatten_outputs) == 1 and self.module_forward_output_schema is _TensorStub: - return ort_flatten_outputs[0] return unflatten_data_using_schema(ort_flatten_outputs, self.module_forward_output_schema) except TypeError as e: raise wrap_exception( @@ -212,38 +216,20 @@ def __init__( self, flatten_module: torch.nn.Module, export_mode: int, - save: bool, - save_path: str, - save_name_prefix: str, - deepcopy_before_model_export: bool, - torch_exporter_verbose_log: bool, - enable_zero_stage3_support: bool, # TODO(): implement as a plugin - onnx_opset_version: int, - enable_custom_autograd_function: bool, # TODO(): implement as a plugin - enable_symbolic_shape_infer: bool, - exported_model_cache_dir: str, + debug_options: DebugOptions, + runtime_options: _RuntimeOptions, logger: logging.Logger, ): self._device = _utils._get_device_from_module(flatten_module) self._export_mode = export_mode - # Debug options - self._save = save - self._save_path = save_path - self._save_name_prefix = save_name_prefix - - # Runtime options - self._deepcopy_before_model_export = deepcopy_before_model_export - self._torch_exporter_verbose_log = torch_exporter_verbose_log - self._enable_zero_stage3_support = enable_zero_stage3_support - self._onnx_opset_version = onnx_opset_version - self._enable_custom_autograd_function = enable_custom_autograd_function - self._run_symbolic_shape_infer = enable_symbolic_shape_infer - self._ortmodule_cache_dir = exported_model_cache_dir + self._debug_options = debug_options + self._runtime_options = runtime_options self._export_extra_kwargs = {} self._logger = logger + self._torch_exporter_verbose_log = self._debug_options.log_level < LogLevel.WARNING # A signal to indicate if the original model has changed and need a re-export. self._original_model_has_changed = False @@ -274,7 +260,7 @@ def use_cache_or_reconstruct_post_processed_model( Return True if the model can be reused, otherwise, return False. The model can be reused when the following conditions are met: a. The model has been exported before, and the inputs (args/outputs) schemas are the same as the previous ones. - b. In training mode, the graph inputs requiring gradient are the same as the previous ones. + b. If it is in training mode, the graph inputs requiring gradient are the same as the previous ones. """ @@ -287,8 +273,7 @@ def use_cache_or_reconstruct_post_processed_model( ORTModuleDeviceException, RuntimeError("A device must be specified in the model or inputs!") ) - # Extract the schema from the args and kwargs, and compare with the cached one. - # This check ideally is not needed as we already have the above check, but it is added as a safeguard. + # Extract the schema from the args and kwargs, and compare it with the pre-exported one if already exported. flatten_args, cur_args_schema = _io._extract_schema(copy.copy(args), self._device) flatten_kwargs, cur_kwargs_schema = _io._extract_schema(copy.copy(kwargs), self._device) @@ -301,17 +286,37 @@ def use_cache_or_reconstruct_post_processed_model( ) if need_export_model: - # Set the information used to unflatten the inputs during the flatten module forward run. - # Must be set before calling exporting the model. + # Note related to the _io.FlattenedModule export!!! + # + # The _io.FlattenedModule serves as a module wrapper designed to support tuple inputs and outputs for + # PyTorch run during ONNX export. (Remember the PyTorch exporter handles tuple inputs and outputs better.) + # Internally, it facilitates the acceptance of tuple inputs and generation of tuple outputs by invoking + # the original module's forward function. The workflow involves the following steps: + + # 1. Prior to export, both args and kwargs are flattened into a 1-D tensor list, and a schema for the + # flattened args and kwargs is generated. This schema is essential for the subsequent unflattening + # process. + + # 2. The flattened inputs (args + kwargs) are passed to the _io.FlattenedModule's forward run. + + # 3. The args schema and kwargs schema, etc are conveyed to the _io.FlattenedModule by setting the + # corresponding attributes. + + # 4. Within the _io.FlattenedModule's forward run, the inputs are unflattened to the original args and + # kwargs using the associated schemas, and then they are passed to the original module's forward function. + + # 5. Upon the completion of the forward function, the outputs from the original module are flattened and + # returned to the caller. + + # 6. The 1-D flattened output tensors retain the same order as the outputs from the ONNX Runtime (ORT) + # forward run. To facilitate unflattening during subsequent ORT runs, the output schema is saved as + # an attribute named `_output_schema` in the _io.FlattenedModule. + flatten_inputs = flatten_args + flatten_kwargs self._flatten_module._device = self._device self._flatten_module._args_schema = cur_args_schema self._flatten_module._kwargs_schema = cur_kwargs_schema self._flatten_module._num_positionals = len(flatten_args) - flatten_inputs = flatten_args + flatten_kwargs - - # Check graph inputs parsed from the model's forward function signature and current inputs, - # if they are different, we need to re-export the model. cur_model_info_for_export = _io.parse_inputs_for_onnx_export( self._module_forward_func_parameters, args, @@ -322,31 +327,30 @@ def use_cache_or_reconstruct_post_processed_model( self._export_extra_kwargs, ) - # Start exporting the model by passing the 1-D flatten tensor list containing all args plus kwargs. ( exported_model, - module_output_schema, + module_output_schema, # Retrieved from _io.FlattenedModule's _output_schema onnx_graph_input_names, onnx_graph_input_names_require_grad, ) = GraphTransitionManager._export_model( flattened_module=self._flatten_module, model_info_for_export=cur_model_info_for_export, flatten_module_inputs=flatten_inputs, - run_symbolic_shape_infer=self._run_symbolic_shape_infer, - deepcopy_before_model_export=self._deepcopy_before_model_export, + run_symbolic_shape_infer=self._runtime_options.run_symbolic_shape_infer, + deepcopy_before_model_export=self._runtime_options.deepcopy_before_model_export, device=self._device, - ortmodule_cache_dir=self._ortmodule_cache_dir, - enable_custom_autograd_function=self._enable_custom_autograd_function, - enable_zero_stage3_support=self._enable_zero_stage3_support, - onnx_opset_version=self._onnx_opset_version, + ortmodule_cache_dir=self._runtime_options.ortmodule_cache_dir, + enable_custom_autograd_function=self._runtime_options.enable_custom_autograd_function, + enable_zero_stage3_support=self._runtime_options.enable_zero_stage3_support, + onnx_opset_version=self._runtime_options.onnx_opset_version, torch_exporter_verbose_log=self._torch_exporter_verbose_log, stage3_param_handle=self, logger=self._logger, ) - # Get the intersection of all user defined input names (parsed from forward function signature) and - # the exported model input names including both user defined names and parameter/buffer names. - # It is possible some user defined input names are not in the exported model input names, if it is not used + # Get the intersection of all user-defined input names (parsed from forward function signature) and + # the exported model input names including both user-defined input names and training parameter/buffer names. + # It is possible some user-defined input names are not in the exported model input names, if it is not used # by the model for its compute. onnx_graph_input_names_user_defined = [ input_name @@ -376,12 +380,14 @@ def use_cache_or_reconstruct_post_processed_model( self._original_model_has_changed = False # Save the exported model - if self._save: + if self._debug_options.save_onnx_models.save: _save_model( self._exported_model_info.exported_model, os.path.join( - self._save_path, - _get_onnx_file_name(self._save_name_prefix, "torch_exported", self._export_mode), + self._debug_options.save_onnx_models.path, + _get_onnx_file_name( + self._debug_options.save_onnx_models.name_prefix, "torch_exported", self._export_mode + ), ), ) @@ -393,9 +399,11 @@ def use_cache_or_reconstruct_post_processed_model( else: need_re_processed, updated_onnx_graph_input_requires_grads = GraphTransitionManager._reprocess_check( flatten_module=self._flatten_module, - prev_exported_model_info=self._exported_model_info, + exported_model_info=self._exported_model_info, export_mode=self._export_mode, - cur_model_info_for_export=self._model_info_for_export, + model_info_for_export=self._model_info_for_export, + args=args, + kwargs=kwargs, ) if need_re_processed: # Update the onnx_graph_input_names_require_grads to make it a new default baseline to compare @@ -406,23 +414,24 @@ def use_cache_or_reconstruct_post_processed_model( # At this point, the exported model is ready, and we can start post-export processing. self._post_export_processed_model_info = GraphTransitionManager._post_export_process( flatten_module=self._flatten_module, - device=self._device, export_mode=self._export_mode, exported_model_info=self._exported_model_info, model_info_for_export=self._model_info_for_export, - enable_custom_autograd_function=self._enable_custom_autograd_function, - enable_zero_stage3_support=self._enable_zero_stage3_support, + enable_custom_autograd_function=self._runtime_options.enable_custom_autograd_function, + enable_zero_stage3_support=self._runtime_options.enable_zero_stage3_support, stage3_param_handle=self, logger=self._logger, ) # Save the post_processed model - if self._save: + if self._debug_options.save_onnx_models.save: _save_model( self._post_export_processed_model_info._post_export_processed_model, os.path.join( - self._save_path, - _get_onnx_file_name(self._save_name_prefix, "post_processed", self._export_mode), + self._debug_options.save_onnx_models.path, + _get_onnx_file_name( + self._debug_options.save_onnx_models.name_prefix, "post_processed", self._export_mode + ), ), ) @@ -438,8 +447,9 @@ def _export_check( ): """Check if the model needs to be exported, if yes, return True. - For the following cases, return True: + If either of the following conditions is met, return True: 1. The model has never been exported before. + 2. The original_model_has_changed is True. 3. The model input schema parsed from args and kwargs has changed. """ @@ -459,7 +469,12 @@ def _export_check( @staticmethod def _reprocess_check( - flatten_module, prev_exported_model_info, export_mode, cur_model_info_for_export: _io.ModelInfoForExport + flatten_module: _io._FlattenedModule, + exported_model_info: ExportedModelInfo, + export_mode: int, + model_info_for_export: _io.ModelInfoForExport, + args: Sequence[ORTModelInputOutputType], + kwargs: Mapping[str, ORTModelInputOutputType], ) -> bool: """Check if the exported model needs to be re-processed, if yes, return True and the updated onnx_graph_input_requires_grads. @@ -469,7 +484,7 @@ def _reprocess_check( requiring gradient change. """ if export_mode == torch.onnx.TrainingMode.TRAINING: - # If inputs requiring gradient change from forward to the next, the module_gradient_graph_builder + # If inputs requiring gradient change from forward to the next, the gradient graph builder # needs to be reinitialized so it can compute the backward output for the new inputs that require_grad # Reinitialize graph builder if the inputs or initializers requiring gradient have changed. @@ -478,15 +493,20 @@ def _reprocess_check( onnx_graph_input_requires_grads = [] parameter_names = {k: v for k, v in flatten_module.named_parameters()} - for input_name in prev_exported_model_info.onnx_graph_input_names: + for input_name in exported_model_info.onnx_graph_input_names: if input_name in parameter_names and parameter_names[input_name].requires_grad: onnx_graph_input_requires_grads.append(input_name) else: - # If not in parameter list, then it would come from user defined inputs. - if input_name in cur_model_info_for_export.onnx_graph_input_names_require_grad: + # If not in the parameter list, then it would come from user-defined inputs. + assert ( + input_name in model_info_for_export.data_accessor + ), f"{input_name} is not in model_info_for_export.onnx_graph_input_names_user_defined" + # We assume the data accessor should be the same as the one used for the previous export, because + # there is args and kwargs schema check during export check phase. + if model_info_for_export.data_accessor[input_name](args, kwargs).requires_grad: onnx_graph_input_requires_grads.append(input_name) - if onnx_graph_input_requires_grads == prev_exported_model_info.onnx_graph_input_names_require_grad: + if onnx_graph_input_requires_grads == exported_model_info.onnx_graph_input_names_require_grad: return False, [] return True, onnx_graph_input_requires_grads @@ -495,7 +515,6 @@ def _reprocess_check( @staticmethod def _post_export_process( flatten_module, - device, export_mode, exported_model_info: ExportedModelInfo, model_info_for_export: _io.ModelInfoForExport, @@ -508,7 +527,7 @@ def _post_export_process( # Deepcopy the exported model, in case modification affects the exported model. - # TODO(): Do pre-grad graph modification as needed, for memory efficient gradient management, etc. + # TODO(): Do pre-grad graph modification as needed, for memory-efficient gradient management, etc. post_processed_model = copy.deepcopy(exported_model_info.exported_model) if export_mode == torch.onnx.TrainingMode.TRAINING: @@ -526,11 +545,6 @@ def _post_export_process( [name for name, _ in flatten_module.named_parameters()], ) - # Cannot append pull weight trigger name to input names as following, otherwise, the later check ( - # https://github.com/microsoft/onnxruntime/blob/068300d97eb25e5b52324e7af54a45ed1fa6a4c3/orttraining/orttraining/python/training/ortmodule/_training_manager.py#L466C18-L466C18) - # find input info mismatch, will re-initialize the graph builder. - # self._input_info.require_grad_names.append(STAGE3_PULL_WEIGHT_TRIGGER_NAME) - post_export_processed_model_info = PostExportProcessedModelInfo( flatten_module, exported_model_info.onnx_graph_input_names_user_defined, diff --git a/orttraining/orttraining/python/training/ortmodule/_io.py b/orttraining/orttraining/python/training/ortmodule/_io.py index 7e7ab81b845eb..607af8104056a 100644 --- a/orttraining/orttraining/python/training/ortmodule/_io.py +++ b/orttraining/orttraining/python/training/ortmodule/_io.py @@ -200,10 +200,14 @@ def __repr__(self) -> str: return self.__str__() -def _arg_access__with_index_func(arg_index, args, kwargs): +def _arg_access_with_index_func(arg_index, args, kwargs): return args[arg_index] +def _kwarg_access_with_name_func(name, args, kwargs): + return kwargs[name] + + def parse_inputs_for_onnx_export( all_input_parameters: List[inspect.Parameter], args: Sequence[ORTModelInputOutputType], @@ -354,7 +358,7 @@ def _access_func2(key, cur_func, args, kwargs): var_positional_idx += 1 inp = args[args_i] - _add_input(name, inp, onnx_graph_input_names, partial(_arg_access__with_index_func, args_i)) + _add_input(name, inp, onnx_graph_input_names, partial(_arg_access_with_index_func, args_i)) elif ( input_parameter.kind == inspect.Parameter.POSITIONAL_ONLY or input_parameter.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD @@ -368,29 +372,22 @@ def _access_func2(key, cur_func, args, kwargs): if input_idx < len(args) and args[input_idx] is not None: inp = args[input_idx] - access_func = partial(_arg_access__with_index_func, input_idx) + access_func = partial(_arg_access_with_index_func, input_idx) elif name in kwargs and kwargs[name] is not None: inp = kwargs[name] - def _access_func5(name, args, kwargs): - return kwargs[name] - - access_func = partial(_access_func5, name) + access_func = partial(_kwarg_access_with_name_func, name) _add_input(name, inp, onnx_graph_input_names, access_func) elif input_parameter.kind == inspect.Parameter.VAR_KEYWORD: # **kwargs is always the last argument of forward() for name, inp in kwargs.items(): - - def _access_func6(name, args, kwargs): - return kwargs[name] - _add_input( name, inp, onnx_graph_input_names, - partial(_access_func6, name), + partial(_kwarg_access_with_name_func, name), ) exported_graph = ModelInfoForExport( diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index 7db6d91a82b36..37010f0829edf 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -206,9 +206,6 @@ def backward(ctx, *grad_outputs): self._runtime_inspector.memory_ob.inspect_memory(Phase.POST_BACKWARD) - # print("transferred_backward_outputs: ", transferred_backward_outputs) - # print("self._gradient_map: ", self._gradient_map) - return tuple(transferred_backward_outputs[idx] if idx != -1 else None for idx in self._gradient_map) return _ORTModuleFunction From 8d34f43ec286497268d8b07a6d47406dce315c7c Mon Sep 17 00:00:00 2001 From: "pengwa@microsoft.com" Date: Fri, 5 Jan 2024 02:43:32 +0000 Subject: [PATCH 08/32] fixes --- .../ortmodule/_graph_execution_manager.py | 4 +- .../ortmodule/_graph_transition_manager.py | 30 +++++++++------ .../python/training/ortmodule/_logger.py | 38 ++++++++++++++----- .../python/training/utils/torch_to_onnx.py | 6 --- .../python/orttraining_test_ortmodule_api.py | 20 +++++----- 5 files changed, 58 insertions(+), 40 deletions(-) delete mode 100644 orttraining/orttraining/python/training/utils/torch_to_onnx.py diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 324955e045e63..32d8fd6b0003b 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -52,7 +52,6 @@ def __init__( """Manages construction and execution of ONNX graphs""" super().__init__(module._original_module) - super(GraphExecutionInterface, self).__init__() # IMPORTANT: Debug and Fallback must the configured first self._debug_options = debug_options @@ -83,7 +82,7 @@ def __init__( self._runtime_inspector = RuntimeInspector(self._logger, self._original_module) self._runtime_inspector.memory_ob.enable_memory_stats_by_step(self._runtime_options.print_memory_stat_by_step) - # Tracker for ORTModule model export, session creation overhead. + # Tracker for session creation overhead. self.time_tracker = _logger.TimeTracker() self.is_rocm_pytorch = bool(torch.version.hip is not None and ROCM_HOME is not None) @@ -126,6 +125,7 @@ def _initialize_graph_transition_manager(self): export_mode=self._export_mode, debug_options=self._debug_options, runtime_options=self._runtime_options, + time_tracker=self.time_tracker, logger=self._logger, ) diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py index aa9d76b826408..186292b1fd896 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py @@ -27,7 +27,7 @@ from . import _io, _utils from ._fallback import ORTModuleDeviceException, ORTModuleIOError, ORTModuleONNXModelException, wrap_exception -from ._logger import LogLevel +from ._logger import LogLevel, ORTModuleInitPhase, SuppressLogs, TimeTracker, TrackTimeForStaticFunction from ._onnx_models import _get_onnx_file_name, _save_model from ._utils import check_function_has_param, get_rank from ._zero_stage3_compatibility import stage3_export_context @@ -218,6 +218,7 @@ def __init__( export_mode: int, debug_options: DebugOptions, runtime_options: _RuntimeOptions, + time_tracker: TimeTracker, logger: logging.Logger, ): self._device = _utils._get_device_from_module(flatten_module) @@ -229,7 +230,9 @@ def __init__( self._export_extra_kwargs = {} self._logger = logger - self._torch_exporter_verbose_log = self._debug_options.log_level < LogLevel.WARNING + + # Tracker for ORTModule model export. + self._time_tracker = time_tracker # A signal to indicate if the original model has changed and need a re-export. self._original_model_has_changed = False @@ -343,8 +346,9 @@ def use_cache_or_reconstruct_post_processed_model( enable_custom_autograd_function=self._runtime_options.enable_custom_autograd_function, enable_zero_stage3_support=self._runtime_options.enable_zero_stage3_support, onnx_opset_version=self._runtime_options.onnx_opset_version, - torch_exporter_verbose_log=self._torch_exporter_verbose_log, stage3_param_handle=self, + debug_options=self._debug_options, + time_tracker=self._time_tracker, logger=self._logger, ) @@ -494,17 +498,18 @@ def _reprocess_check( onnx_graph_input_requires_grads = [] parameter_names = {k: v for k, v in flatten_module.named_parameters()} for input_name in exported_model_info.onnx_graph_input_names: - if input_name in parameter_names and parameter_names[input_name].requires_grad: - onnx_graph_input_requires_grads.append(input_name) - else: - # If not in the parameter list, then it would come from user-defined inputs. + if input_name in exported_model_info.onnx_graph_input_names_user_defined: assert ( input_name in model_info_for_export.data_accessor - ), f"{input_name} is not in model_info_for_export.onnx_graph_input_names_user_defined" + ), f"{input_name} model_info_for_export.data_accessor" # We assume the data accessor should be the same as the one used for the previous export, because # there is args and kwargs schema check during export check phase. if model_info_for_export.data_accessor[input_name](args, kwargs).requires_grad: onnx_graph_input_requires_grads.append(input_name) + else: + assert input_name in parameter_names, f"{input_name} not exist parameter_names" + if parameter_names[input_name].requires_grad: + onnx_graph_input_requires_grads.append(input_name) if onnx_graph_input_requires_grads == exported_model_info.onnx_graph_input_names_require_grad: return False, [] @@ -563,10 +568,11 @@ def _post_export_process( return post_export_processed_model_info - # @_logger.TrackTime(_logger.ORTModuleInitPhase.EXPORT) - # @_logger.SuppressLogs(_logger.ORTModuleInitPhase.EXPORT, is_ort_filter=False) @staticmethod + @TrackTimeForStaticFunction(ORTModuleInitPhase.EXPORT) + @SuppressLogs(ORTModuleInitPhase.EXPORT, is_ort_filter=False) def _export_model( + *, flattened_module: torch.nn.Module, model_info_for_export: _io.ModelInfoForExport, flatten_module_inputs: Sequence[ORTModelInputOutputType], @@ -577,14 +583,16 @@ def _export_model( enable_custom_autograd_function: bool, enable_zero_stage3_support: bool, onnx_opset_version: int, - torch_exporter_verbose_log: bool, stage3_param_handle: type, + debug_options: DebugOptions, + time_tracker: TimeTracker, logger: logging.Logger, ) -> tuple[onnx.ModelProto, ORTModelInputOutputSchemaType, list[str], list[str]]: # Record random states here and restore later in case any of them gets changed during the export, # e.g., some sympy functions in symbolic_shape_infer will change Python's random state. random_states = _utils.get_random_states() + torch_exporter_verbose_log = debug_options.log_level < LogLevel.WARNING from onnxruntime.training.utils.hooks._subscriber_manager import no_increase_global_step with no_increase_global_step(): diff --git a/orttraining/orttraining/python/training/ortmodule/_logger.py b/orttraining/orttraining/python/training/ortmodule/_logger.py index a01db28374b8d..6cfd947117d5d 100644 --- a/orttraining/orttraining/python/training/ortmodule/_logger.py +++ b/orttraining/orttraining/python/training/ortmodule/_logger.py @@ -165,6 +165,24 @@ def wrapper(graph_execution_manager, *args, **kwargs): return wrapper +class TrackTimeForStaticFunction: + """A function decorator to track time spent in different phases of ORT backend first-time initialization.""" + + def __init__(self, phase: ORTModuleInitPhase): + self.phase = phase + + def __call__(self, func: Callable): + def wrapper(*args, **kwargs): + if "time_tracker" not in kwargs: + raise RuntimeError("The function to be tracked must have a 'time_tracker' kwarg.") + kwargs["time_tracker"].start(self.phase) + result = func(*args, **kwargs) + kwargs["time_tracker"].end(self.phase) + return result + + return wrapper + + @contextmanager def _suppress_os_stream_output(enable=True, on_exit: Optional[Callable] = None): """Suppress output from being printed to stdout and stderr. @@ -255,25 +273,25 @@ def __init__(self, phase: ORTModuleInitPhase, is_ort_filter=True): self.is_ort_filter = is_ort_filter def __call__(self, func: Callable): - def wrapper(graph_execution_manager, *args, **kwargs): - if not hasattr(graph_execution_manager, "_logger"): - raise RuntimeError("The class of the function to be tracked must have a '_logger' attribute.") + def wrapper(*args, **kwargs): + if "logger" not in kwargs: + raise RuntimeError("The function to be tracked must have a 'logger' kwarg.") - if not hasattr(graph_execution_manager, "_debug_options"): - raise RuntimeError("The class of the function to be tracked must have a '_debug_options' attribute.") + if "debug_options" not in kwargs: + raise RuntimeError("The function to be tracked must have a 'debug_options' kwarg.") with _suppress_os_stream_output( - enable=graph_execution_manager._debug_options.log_level >= LogLevel.DEVINFO, + enable=kwargs["debug_options"].log_level >= LogLevel.DEVINFO, on_exit=partial( _log_with_filter, - graph_execution_manager._logger, - graph_execution_manager._debug_options.onnxruntime_log_filter + kwargs["logger"], + kwargs["debug_options"].onnxruntime_log_filter if self.is_ort_filter - else graph_execution_manager._debug_options.torch_exporter_filter, + else kwargs["debug_options"].torch_exporter_filter, self.phase.to_string(), ), ): - result = func(graph_execution_manager, *args, **kwargs) + result = func(*args, **kwargs) return result return wrapper diff --git a/orttraining/orttraining/python/training/utils/torch_to_onnx.py b/orttraining/orttraining/python/training/utils/torch_to_onnx.py deleted file mode 100644 index f5c1fb13cbbc2..0000000000000 --- a/orttraining/orttraining/python/training/utils/torch_to_onnx.py +++ /dev/null @@ -1,6 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- - -from __future__ import annotations diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 9afa50cba784d..cd0bfc54275f0 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -2955,11 +2955,12 @@ def test_forward_data_and_model_on_different_devices(data_device, model_device): runtime_error.value ) else: - # ORT backend also throw the same exception because PyTorch run failed during export. - with pytest.raises(RuntimeError) as runtime_error: + # ORT backend + with pytest.raises(_fallback.ORTModuleDeviceException) as runtime_error: ort_model(x) - assert "Expected all tensors to be on the same device, but found at least two devices" in str( - runtime_error.value + assert ( + f"Input argument to forward found on device {torch.device(x.device)}, but expected it to be on module device {ort_model._torch_module._execution_manager(ort_model._is_training())._device}." + in str(runtime_error.value) ) del os.environ["ORTMODULE_SKIPCHECK_POLICY"] @@ -5013,9 +5014,9 @@ def __init__(self, module, debug_options=None): super().__init__(module, debug_options) # modify GraphExecutionManager internally for training_mode in [False, True]: - self._torch_module._execution_manager( - training_mode - )._graph_transition_manager._model_info_for_export.export_extra_kwargs = {"custom_opsets": None} + self._torch_module._execution_manager(training_mode)._graph_transition_manager._export_extra_kwargs = { + "custom_opsets": None + } N, D_in, H, D_out = 64, 784, 500, 10 # noqa: N806 x = torch.randn(N, D_in, device=device) @@ -5304,10 +5305,7 @@ def test_serialize_ortmodule(): N, D_in, H, D_out = 64, 784, 500, 10 # noqa: N806 pt_model = SerializationNet(D_in, H, D_out).to(device) - from onnxruntime.training.ortmodule import DebugOptions, LogLevel - - ort_model = ORTModule(copy.deepcopy(pt_model), DebugOptions(log_level=LogLevel.INFO)) - # ort_model = ORTModule(copy.deepcopy(pt_model)) + ort_model = ORTModule(copy.deepcopy(pt_model)) x_1 = torch.randn(N, D_in, device=device) x_2 = copy.deepcopy(x_1) From d990e0f2eff49d28913b57260a040cdd58f534cd Mon Sep 17 00:00:00 2001 From: "pengwa@microsoft.com" Date: Fri, 5 Jan 2024 07:55:06 +0000 Subject: [PATCH 09/32] fix --- .../test/python/orttraining_test_ortmodule_onnx_ops.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py index 4f0925c5c855b..a0150ea9dede2 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py @@ -68,7 +68,9 @@ def run_step(model, x): self.assert_values_are_close(ort_prediction, pt_prediction, **kwargs) self.assert_gradients_match_and_reset_gradient(ort_model, pt_model, **kwargs) - onnx_graph_inf = ort_model._torch_module._execution_manager._training_manager._onnx_models.exported_model + onnx_graph_inf = ( + ort_model._torch_module._execution_manager._training_manager._graph_transition_manager._exported_model_info.exported_model + ) onnx_graph_train = ort_model._torch_module._execution_manager._training_manager._onnx_models.optimized_model if debug: with open("debug_%s_ortmodule_infer.onnx" % name, "wb") as f: From 29c8a98200c041923512aceec56d20e73f930cbf Mon Sep 17 00:00:00 2001 From: "pengwa@microsoft.com" Date: Mon, 8 Jan 2024 06:51:16 +0000 Subject: [PATCH 10/32] fix --- .../ortmodule/_graph_transition_manager.py | 68 +++++----- .../python/training/ortmodule/_io.py | 122 ++++++++++++++---- 2 files changed, 134 insertions(+), 56 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py index 186292b1fd896..15fa5faaa8190 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py @@ -104,7 +104,7 @@ def __init__( onnx_graph_input_dynamic_axes_map: dict[str, dict[int, str]], module_forward_output_schema: ORTModelInputOutputSchemaType, post_export_processed_model: onnx.ModelProto, - data_accessor: list[callable], + onnx_graph_input_data_accessor: dict[str, callable], ): self._flattened_module = flatten_module @@ -135,7 +135,7 @@ def __init__( # A function to access the input data from the args and kwargs. # If it is not None, the length is same as onnx_graph_input_names_user_defined. # For i-th input name, we can use the i-th function to get the input data from args and kwargs. - self.data_accessor: list[callable] | None = data_accessor + self.onnx_graph_input_data_accessor: dict[str, callable] | None = onnx_graph_input_data_accessor # Used for unflattening the outputs from the ORT forward run. self.module_forward_output_schema: ORTModelInputOutputSchemaType | None = module_forward_output_schema @@ -183,9 +183,9 @@ def construct_inputs( self._buffer_for_ort_runs[input_name] = None for name in self.onnx_graph_input_names_user_defined: - if name in self.data_accessor: + if name in self.onnx_graph_input_data_accessor: assert name in self._buffer_for_ort_runs, f"{name} is not in buffer_for_ort_runs" - data = self.data_accessor[name](args, kwargs) + data = self.onnx_graph_input_data_accessor[name](args, kwargs) if PrimitiveType.is_primitive_type(data) and constant_as_tensor: data = PrimitiveType.get_tensor(data, device) self._buffer_for_ort_runs[name] = data @@ -277,14 +277,22 @@ def use_cache_or_reconstruct_post_processed_model( ) # Extract the schema from the args and kwargs, and compare it with the pre-exported one if already exported. - flatten_args, cur_args_schema = _io._extract_schema(copy.copy(args), self._device) - flatten_kwargs, cur_kwargs_schema = _io._extract_schema(copy.copy(kwargs), self._device) + + cur_model_info_for_export = _io.parse_inputs_for_onnx_export( + self._module_forward_func_parameters, + args, + kwargs, + True, + self._device, + self._export_mode, + self._export_extra_kwargs, + ) need_export_model = GraphTransitionManager._export_check( prev_exported_model_info=self._exported_model_info, original_model_has_changed=self._original_model_has_changed, - cur_args_schema=cur_args_schema, - cur_kwargs_schema=cur_kwargs_schema, + cur_args_schema=cur_model_info_for_export.onnx_graph_input_arg_schema, + cur_kwargs_schema=cur_model_info_for_export.onnx_graph_input_kwarg_schema, logger=self._logger, ) @@ -314,21 +322,19 @@ def use_cache_or_reconstruct_post_processed_model( # 6. The 1-D flattened output tensors retain the same order as the outputs from the ONNX Runtime (ORT) # forward run. To facilitate unflattening during subsequent ORT runs, the output schema is saved as # an attribute named `_output_schema` in the _io.FlattenedModule. - flatten_inputs = flatten_args + flatten_kwargs + + copied_args = copy.copy(args) + copied_kwargs = copy.copy(kwargs) + flatten_inputs = [ + data_accessor(copied_args, copied_kwargs) + for _, data_accessor in cur_model_info_for_export.onnx_graph_input_data_accessor.items() + ] self._flatten_module._device = self._device - self._flatten_module._args_schema = cur_args_schema - self._flatten_module._kwargs_schema = cur_kwargs_schema - self._flatten_module._num_positionals = len(flatten_args) - - cur_model_info_for_export = _io.parse_inputs_for_onnx_export( - self._module_forward_func_parameters, - args, - kwargs, - True, - self._device, - self._export_mode, - self._export_extra_kwargs, - ) + self._flatten_module._args_schema = cur_model_info_for_export.onnx_graph_input_arg_schema + self._flatten_module._kwargs_schema = cur_model_info_for_export.onnx_graph_input_kwarg_schema + self._flatten_module._num_positionals = cur_model_info_for_export.num_positional_args + + self._logger.info(f"do_export started, model info for export: {cur_model_info_for_export}") ( exported_model, @@ -368,8 +374,8 @@ def use_cache_or_reconstruct_post_processed_model( ] self._exported_model_info = ExportedModelInfo( - module_forward_args_schema=cur_args_schema, - module_forward_kwargs_schema=cur_kwargs_schema, + module_forward_args_schema=cur_model_info_for_export.onnx_graph_input_arg_schema, + module_forward_kwargs_schema=cur_model_info_for_export.onnx_graph_input_kwarg_schema, onnx_graph_input_names=onnx_graph_input_names, onnx_graph_input_names_require_grad=onnx_graph_input_names_require_grad, onnx_graph_input_names_user_defined=onnx_graph_input_names_user_defined, @@ -500,11 +506,11 @@ def _reprocess_check( for input_name in exported_model_info.onnx_graph_input_names: if input_name in exported_model_info.onnx_graph_input_names_user_defined: assert ( - input_name in model_info_for_export.data_accessor - ), f"{input_name} model_info_for_export.data_accessor" + input_name in model_info_for_export.onnx_graph_input_data_accessor + ), f"{input_name} model_info_for_export.onnx_graph_input_data_accessor" # We assume the data accessor should be the same as the one used for the previous export, because # there is args and kwargs schema check during export check phase. - if model_info_for_export.data_accessor[input_name](args, kwargs).requires_grad: + if model_info_for_export.onnx_graph_input_data_accessor[input_name](args, kwargs).requires_grad: onnx_graph_input_requires_grads.append(input_name) else: assert input_name in parameter_names, f"{input_name} not exist parameter_names" @@ -559,7 +565,7 @@ def _post_export_process( model_info_for_export.onnx_graph_input_dynamic_axes_map, exported_model_info.module_forward_output_schema, post_processed_model, - model_info_for_export.data_accessor, + model_info_for_export.onnx_graph_input_data_accessor, ) logger.info( @@ -664,14 +670,16 @@ def _get_exported_model( ) ( output_names, - dynamic_axes, + output_dynamic_axes, module_output_schema, ) = _io.parse_outputs_for_onnx_export_and_extract_schema( flattened_module, flatten_module_inputs, logger, need_deep_copy ) # Combine the dynamic axes from inputs and outputs - dynamic_axes.update(model_info_for_export.onnx_graph_input_dynamic_axes_map) + dynamic_axes = copy.deepcopy(model_info_for_export.onnx_graph_input_dynamic_axes_map) + + dynamic_axes.update(output_dynamic_axes) logger.info("Exporting the PyTorch model to ONNX...") diff --git a/orttraining/orttraining/python/training/ortmodule/_io.py b/orttraining/orttraining/python/training/ortmodule/_io.py index 607af8104056a..6ace6e93ff1a3 100644 --- a/orttraining/orttraining/python/training/ortmodule/_io.py +++ b/orttraining/orttraining/python/training/ortmodule/_io.py @@ -21,6 +21,7 @@ extract_data_and_schema, unflatten_data_using_schema, ) +from onnxruntime.training.utils.torch_io_helper import _TensorStub from ._fallback import ORTModuleIOError, wrap_exception @@ -153,7 +154,10 @@ def __init__( onnx_graph_input_names_require_grad: List[str], onnx_graph_input_dynamic_axes_map: Dict[str, Dict[int, str]], onnx_graph_input_shapes: List[List[int]], - data_accessor: Optional[List[callable]] = None, + onnx_graph_input_data_accessor: Optional[Dict[str, callable]] = None, + onnx_graph_input_arg_schema: Optional[Dict[str, ORTModelInputOutputSchemaType]] = None, + onnx_graph_input_kwarg_schema: Optional[Dict[str, ORTModelInputOutputSchemaType]] = None, + num_positional_args: int = 0, export_mode: Optional[int] = None, export_extra_kwargs: Optional[Dict[str, any]] = None, ): @@ -182,10 +186,22 @@ def __init__( self.onnx_graph_input_shapes: List[List[int]] = onnx_graph_input_shapes + # The input args schema for the original model's forward function. + # Only contains the schema for those inputs used by the model for its compute (e.g. as the inputs + # of the export model). + self.onnx_graph_input_arg_schema: Dict[str, ORTModelInputOutputSchemaType] = onnx_graph_input_arg_schema + + # The input kwargs schema for the original model's forward function. + # Only contains the schema for those inputs used by the model for its compute (e.g. as the inputs + # of the export model). + self.onnx_graph_input_kwarg_schema: Dict[str, ORTModelInputOutputSchemaType] = onnx_graph_input_kwarg_schema + + self.num_positional_args: int = num_positional_args + # A function to access the input data from the args and kwargs. # If it is not None, the length is same as onnx_graph_input_names. # For i-th input name, we can use the i-th function to get the input data from args and kwargs. - self.data_accessor: Optional[List[callable]] = data_accessor + self.onnx_graph_input_data_accessor: Optional[Dict[str, callable]] = onnx_graph_input_data_accessor def __str__(self) -> str: return f"""ModelInfoForExport class: @@ -208,6 +224,12 @@ def _kwarg_access_with_name_func(name, args, kwargs): return kwargs[name] +class SkipRetValue: + """A placeholder class to indicate that the return value of a function should be skipped""" + + pass + + def parse_inputs_for_onnx_export( all_input_parameters: List[inspect.Parameter], args: Sequence[ORTModelInputOutputType], @@ -243,7 +265,7 @@ def parse_inputs_for_onnx_export( """ - data_accessors: Dict[str, Callable] = OrderedDict() + tensor_idx = [-1] def _add_dynamic_shape(name, input) -> Dict[str, Dict[int, str]]: dynamic_axes[name] = {} @@ -258,62 +280,84 @@ def _add_input(name: str, input_value, onnx_graph_input_names: List[str], cur_fu """Returns number of expanded non none inputs that _add_input processed""" # in case the input is already handled. - if name in input_names: # or input_value is None or isinstance(input_value, str): - # Drop all None and string inputs and return 0. - return + if name in visited_input_names: + return SkipRetValue() - # InputInfo should contain all the names irrespective of whether they are - # a part of the onnx graph or not. - input_names.append(name) + visited_input_names.append(name) value = input_value if value is None: _warn_of_constant_inputs(value) + return value elif isinstance(value, str): _warn_of_constant_inputs(value) + return value elif PrimitiveType.is_primitive_type(value): if constant_as_tensor: value = PrimitiveType.get_tensor(value, device) else: _warn_of_constant_inputs(value) + return value elif isinstance(value, abc.Sequence): + sequence_type = type(value) + stubbed_schema = [] + # If the input is a sequence (like a list), expand the list so that # each element of the list is an input by itself. for i, val in enumerate(value): # Name each input with the index appended to the original name of the # argument. - def _access_func1(i, cur_func, args, kwargs): + def _access_func(i, cur_func, args, kwargs): return cur_func(args, kwargs)[i] - _add_input( + input_schema = _add_input( f"{name}_{i}", val, onnx_graph_input_names, - partial(_access_func1, i, cur_func), + partial(_access_func, i, cur_func), ) + if not isinstance(input_schema, SkipRetValue): + stubbed_schema.append(input_schema) + # Return here since the list by itself is not a valid input. # All the elements of the list have already been added as inputs individually. - return + + try: + # namedtuple can be created by passing the list sequence to method _make + stubbed_schema = sequence_type._make(stubbed_schema) + except AttributeError: + # If attribute error is encountered, create the sequence directly + stubbed_schema = sequence_type(stubbed_schema) + return stubbed_schema + elif isinstance(value, abc.Mapping): + dict_type = type(value) + stubbed_schema = OrderedDict() + # If the input is a mapping (like a dict), expand the dict so that # each element of the dict is an input by itself. for key, val in value.items(): - def _access_func2(key, cur_func, args, kwargs): + def _access_func(key, cur_func, args, kwargs): return cur_func(args, kwargs)[key] - _add_input( + input_schema = _add_input( f"{name}_{key}", val, onnx_graph_input_names, - partial(_access_func2, key, cur_func), + partial(_access_func, key, cur_func), ) + if not isinstance(input_schema, SkipRetValue): + stubbed_schema[key] = input_schema + # Return here since the dict by itself is not a valid input. # All the elements of the dict have already been added as inputs individually. - return + + stubbed_schema = dict_type(**stubbed_schema) + return stubbed_schema if isinstance(value, torch.Tensor): onnx_graph_input_names.append(name) @@ -322,12 +366,25 @@ def _access_func2(key, cur_func, args, kwargs): input_names_require_grad.append(name) dynamic_axes.update(_add_dynamic_shape(name, value)) input_shape.append(list(value.size())) + tensor_idx[0] += 1 + return _TensorStub( + tensor_idx[0], + dtype=str(value.dtype), + shape_dims=len(value.size()), + name=name, + ) + + visited_input_names: List[str] = [] onnx_graph_input_names: List[str] = [] - input_names: List[str] = [] dynamic_axes: Dict[str, Dict[int, str]] = {} input_names_require_grad: List[str] = [] input_shape: List[List[int]] = [] + input_arg_schema: Dict[str, ORTModelInputOutputSchemaType] = OrderedDict() + input_kwarg_schema: Dict[str, ORTModelInputOutputSchemaType] = OrderedDict() + data_accessors: Dict[str, Callable] = OrderedDict() + num_positional_args: int = 0 + var_positional_idx = 0 # Be noted, all_input_parameters is a list of inspect.Parameters parsed from the original module's forward method. @@ -356,9 +413,11 @@ def _access_func2(key, cur_func, args, kwargs): for args_i in range(input_idx, len(args)): name = f"{input_parameter.name}_{var_positional_idx}" var_positional_idx += 1 + num_positional_args += 1 inp = args[args_i] - - _add_input(name, inp, onnx_graph_input_names, partial(_arg_access_with_index_func, args_i)) + schema = _add_input(name, inp, onnx_graph_input_names, partial(_arg_access_with_index_func, args_i)) + if not isinstance(schema, SkipRetValue): + input_arg_schema[name] = schema elif ( input_parameter.kind == inspect.Parameter.POSITIONAL_ONLY or input_parameter.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD @@ -369,33 +428,44 @@ def _access_func2(key, cur_func, args, kwargs): inp = None input_idx += var_positional_idx # noqa: PLW2901 access_func = None + schema_to_write = None if input_idx < len(args) and args[input_idx] is not None: inp = args[input_idx] - + num_positional_args += 1 access_func = partial(_arg_access_with_index_func, input_idx) - + schema_to_write = input_arg_schema elif name in kwargs and kwargs[name] is not None: inp = kwargs[name] - access_func = partial(_kwarg_access_with_name_func, name) + schema_to_write = input_kwarg_schema + else: + continue + + schema = _add_input(name, inp, onnx_graph_input_names, access_func) + if not isinstance(schema, SkipRetValue): + schema_to_write[name] = schema - _add_input(name, inp, onnx_graph_input_names, access_func) elif input_parameter.kind == inspect.Parameter.VAR_KEYWORD: # **kwargs is always the last argument of forward() for name, inp in kwargs.items(): - _add_input( + schema = _add_input( name, inp, onnx_graph_input_names, partial(_kwarg_access_with_name_func, name), ) + if not isinstance(schema, SkipRetValue): + input_kwarg_schema[name] = schema exported_graph = ModelInfoForExport( onnx_graph_input_names=onnx_graph_input_names, onnx_graph_input_names_require_grad=input_names_require_grad, onnx_graph_input_dynamic_axes_map=dynamic_axes, onnx_graph_input_shapes=input_shape, - data_accessor=data_accessors, + onnx_graph_input_data_accessor=data_accessors, + onnx_graph_input_arg_schema=input_arg_schema, + onnx_graph_input_kwarg_schema=input_kwarg_schema, + num_positional_args=num_positional_args, export_mode=export_mode, export_extra_kwargs=export_extra_kwargs, ) From 44f9f3f635e6ea74f1e3a5f19ed2a01a70d72211 Mon Sep 17 00:00:00 2001 From: "pengwa@microsoft.com" Date: Mon, 8 Jan 2024 08:33:14 +0000 Subject: [PATCH 11/32] fixes --- .../ortmodule/_graph_transition_manager.py | 44 +++++++++++-------- .../python/training/ortmodule/_io.py | 23 +++++----- 2 files changed, 37 insertions(+), 30 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py index 15fa5faaa8190..d52ce009a5d01 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py @@ -329,6 +329,7 @@ def use_cache_or_reconstruct_post_processed_model( data_accessor(copied_args, copied_kwargs) for _, data_accessor in cur_model_info_for_export.onnx_graph_input_data_accessor.items() ] + self._flatten_module._device = self._device self._flatten_module._args_schema = cur_model_info_for_export.onnx_graph_input_arg_schema self._flatten_module._kwargs_schema = cur_model_info_for_export.onnx_graph_input_kwarg_schema @@ -345,7 +346,6 @@ def use_cache_or_reconstruct_post_processed_model( flattened_module=self._flatten_module, model_info_for_export=cur_model_info_for_export, flatten_module_inputs=flatten_inputs, - run_symbolic_shape_infer=self._runtime_options.run_symbolic_shape_infer, deepcopy_before_model_export=self._runtime_options.deepcopy_before_model_export, device=self._device, ortmodule_cache_dir=self._runtime_options.ortmodule_cache_dir, @@ -429,6 +429,7 @@ def use_cache_or_reconstruct_post_processed_model( model_info_for_export=self._model_info_for_export, enable_custom_autograd_function=self._runtime_options.enable_custom_autograd_function, enable_zero_stage3_support=self._runtime_options.enable_zero_stage3_support, + run_symbolic_shape_infer=self._runtime_options.run_symbolic_shape_infer, stage3_param_handle=self, logger=self._logger, ) @@ -531,6 +532,7 @@ def _post_export_process( model_info_for_export: _io.ModelInfoForExport, enable_custom_autograd_function: bool, enable_zero_stage3_support: bool, + run_symbolic_shape_infer: bool, stage3_param_handle: type, logger: logging.Logger, ): @@ -541,12 +543,17 @@ def _post_export_process( # TODO(): Do pre-grad graph modification as needed, for memory-efficient gradient management, etc. post_processed_model = copy.deepcopy(exported_model_info.exported_model) - if export_mode == torch.onnx.TrainingMode.TRAINING: - if enable_custom_autograd_function: - from ._custom_autograd_function_exporter import post_process_enabling_autograd_function + if enable_custom_autograd_function: + from ._custom_autograd_function_exporter import post_process_enabling_autograd_function + + post_processed_model = post_process_enabling_autograd_function(post_processed_model) - post_processed_model = post_process_enabling_autograd_function(post_processed_model) + if run_symbolic_shape_infer: + # MUST call symbolic shape inference after custom autograd function post-processing is done, + # Otherwise, there is no ctx output for PythonOp. + post_processed_model = GraphTransitionManager._infer_shapes(post_processed_model) + if export_mode == torch.onnx.TrainingMode.TRAINING: if enable_zero_stage3_support: from ._zero_stage3_compatibility import post_processing_enable_zero_stage3_compat @@ -574,6 +581,20 @@ def _post_export_process( return post_export_processed_model_info + @staticmethod + def _infer_shapes(model: onnx.ModelProto) -> onnx.ModelProto: + """Infer shapes for the exported model.""" + # Record random states here and restore later in case any of them gets changed during the export, + # e.g., some sympy functions in symbolic_shape_infer will change Python's random state. + random_states = _utils.get_random_states() + + model = SymbolicShapeInference.infer_shapes(model, auto_merge=True, guess_output_rank=True) + + # Restore the recorded random states + _utils.set_random_states(random_states) + + return model + @staticmethod @TrackTimeForStaticFunction(ORTModuleInitPhase.EXPORT) @SuppressLogs(ORTModuleInitPhase.EXPORT, is_ort_filter=False) @@ -582,7 +603,6 @@ def _export_model( flattened_module: torch.nn.Module, model_info_for_export: _io.ModelInfoForExport, flatten_module_inputs: Sequence[ORTModelInputOutputType], - run_symbolic_shape_infer: bool, deepcopy_before_model_export: bool, device: torch.device, ortmodule_cache_dir: str, @@ -594,10 +614,6 @@ def _export_model( time_tracker: TimeTracker, logger: logging.Logger, ) -> tuple[onnx.ModelProto, ORTModelInputOutputSchemaType, list[str], list[str]]: - # Record random states here and restore later in case any of them gets changed during the export, - # e.g., some sympy functions in symbolic_shape_infer will change Python's random state. - random_states = _utils.get_random_states() - torch_exporter_verbose_log = debug_options.log_level < LogLevel.WARNING from onnxruntime.training.utils.hooks._subscriber_manager import no_increase_global_step @@ -625,14 +641,6 @@ def _export_model( if input.name in parameter_names or input.name in model_info_for_export.onnx_graph_input_names_require_grad ] - if run_symbolic_shape_infer: - exported_model = SymbolicShapeInference.infer_shapes( - exported_model, auto_merge=True, guess_output_rank=True - ) - - # Restore the recorded random states - _utils.set_random_states(random_states) - return exported_model, module_output_schema, onnx_graph_input_names, onnx_graph_input_names_require_grad @staticmethod diff --git a/orttraining/orttraining/python/training/ortmodule/_io.py b/orttraining/orttraining/python/training/ortmodule/_io.py index 6ace6e93ff1a3..bd7e50b15fb34 100644 --- a/orttraining/orttraining/python/training/ortmodule/_io.py +++ b/orttraining/orttraining/python/training/ortmodule/_io.py @@ -374,14 +374,16 @@ def _access_func(key, cur_func, args, kwargs): name=name, ) + raise TypeError(f"ORTModule does not support input type {type(value)} for input {name}") + visited_input_names: List[str] = [] onnx_graph_input_names: List[str] = [] dynamic_axes: Dict[str, Dict[int, str]] = {} input_names_require_grad: List[str] = [] input_shape: List[List[int]] = [] - input_arg_schema: Dict[str, ORTModelInputOutputSchemaType] = OrderedDict() - input_kwarg_schema: Dict[str, ORTModelInputOutputSchemaType] = OrderedDict() + input_arg_schema: ORTModelInputOutputSchemaType = [] + input_kwarg_schema: ORTModelInputOutputSchemaType = OrderedDict() data_accessors: Dict[str, Callable] = OrderedDict() num_positional_args: int = 0 @@ -417,7 +419,7 @@ def _access_func(key, cur_func, args, kwargs): inp = args[args_i] schema = _add_input(name, inp, onnx_graph_input_names, partial(_arg_access_with_index_func, args_i)) if not isinstance(schema, SkipRetValue): - input_arg_schema[name] = schema + input_arg_schema.append(schema) elif ( input_parameter.kind == inspect.Parameter.POSITIONAL_ONLY or input_parameter.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD @@ -428,22 +430,19 @@ def _access_func(key, cur_func, args, kwargs): inp = None input_idx += var_positional_idx # noqa: PLW2901 access_func = None - schema_to_write = None if input_idx < len(args) and args[input_idx] is not None: inp = args[input_idx] num_positional_args += 1 access_func = partial(_arg_access_with_index_func, input_idx) - schema_to_write = input_arg_schema + schema = _add_input(name, inp, onnx_graph_input_names, access_func) + if not isinstance(schema, SkipRetValue): + input_arg_schema.append(schema) elif name in kwargs and kwargs[name] is not None: inp = kwargs[name] access_func = partial(_kwarg_access_with_name_func, name) - schema_to_write = input_kwarg_schema - else: - continue - - schema = _add_input(name, inp, onnx_graph_input_names, access_func) - if not isinstance(schema, SkipRetValue): - schema_to_write[name] = schema + schema = _add_input(name, inp, onnx_graph_input_names, access_func) + if not isinstance(schema, SkipRetValue): + input_kwarg_schema[name] = schema elif input_parameter.kind == inspect.Parameter.VAR_KEYWORD: # **kwargs is always the last argument of forward() From b33fd9393229581f5dfd63d44e0b018068474050 Mon Sep 17 00:00:00 2001 From: "pengwa@microsoft.com" Date: Mon, 8 Jan 2024 08:35:50 +0000 Subject: [PATCH 12/32] fix ci --- .../test/python/orttraining_test_ortmodule_autograd.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py index bd4fce2cde144..a4eb9c3b36faf 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py @@ -1415,10 +1415,12 @@ def check_pythonop_training_mode(model, is_eval_mode): ## make sure the ort's PythonOp's training_mode is correct if is_eval_mode: onnx_nodes = ( - model._torch_module._execution_manager._inference_manager._onnx_models.exported_model.graph.node + model._torch_module._execution_manager._inference_manager._graph_transition_manager._exported_model_info.exported_model.graph.node ) else: - onnx_nodes = model._torch_module._execution_manager._training_manager._onnx_models.exported_model.graph.node + onnx_nodes = ( + model._torch_module._execution_manager._training_manager._graph_transition_manager._exported_model_info.exported_model.graph.node + ) found_pythonop = False for node in onnx_nodes: From a07f21c517b21d659d34f3639b64c4dc8e30b93c Mon Sep 17 00:00:00 2001 From: "pengwa@microsoft.com" Date: Mon, 8 Jan 2024 08:42:19 +0000 Subject: [PATCH 13/32] fix --- .../ortmodule/_graph_transition_manager.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py index d52ce009a5d01..d752fd4e93a22 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py @@ -584,15 +584,8 @@ def _post_export_process( @staticmethod def _infer_shapes(model: onnx.ModelProto) -> onnx.ModelProto: """Infer shapes for the exported model.""" - # Record random states here and restore later in case any of them gets changed during the export, - # e.g., some sympy functions in symbolic_shape_infer will change Python's random state. - random_states = _utils.get_random_states() model = SymbolicShapeInference.infer_shapes(model, auto_merge=True, guess_output_rank=True) - - # Restore the recorded random states - _utils.set_random_states(random_states) - return model @staticmethod @@ -614,6 +607,10 @@ def _export_model( time_tracker: TimeTracker, logger: logging.Logger, ) -> tuple[onnx.ModelProto, ORTModelInputOutputSchemaType, list[str], list[str]]: + # Record random states here and restore later in case any of them gets changed during the export, + # e.g., some sympy functions in symbolic_shape_infer will change Python's random state. + random_states = _utils.get_random_states() + torch_exporter_verbose_log = debug_options.log_level < LogLevel.WARNING from onnxruntime.training.utils.hooks._subscriber_manager import no_increase_global_step @@ -641,6 +638,9 @@ def _export_model( if input.name in parameter_names or input.name in model_info_for_export.onnx_graph_input_names_require_grad ] + # Restore the recorded random states + _utils.set_random_states(random_states) + return exported_model, module_output_schema, onnx_graph_input_names, onnx_graph_input_names_require_grad @staticmethod From 2168ea75be4f2b0cf33d784f01a463173a9ef450 Mon Sep 17 00:00:00 2001 From: "Peng Wang(AI FWK)" Date: Wed, 21 Feb 2024 09:00:00 +0000 Subject: [PATCH 14/32] refine based on review comments --- .../ortmodule/_graph_execution_manager.py | 2 +- .../ortmodule/_graph_transition_manager.py | 27 ++++++++++++++----- .../training/ortmodule/_inference_manager.py | 2 +- .../training/ortmodule/_training_manager.py | 2 +- .../python/orttraining_test_ortmodule_api.py | 2 +- 5 files changed, 25 insertions(+), 10 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 32d8fd6b0003b..f70b251099967 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -69,7 +69,7 @@ def __init__( self._export_mode = export_mode self._graph_transition_manager: Optional[GraphTransitionManager] = None - # Model after inference optimization && gradient building. + # Model after inference optimization and then gradient building. self._graph_builder = None self._graph_info = None diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py index d752fd4e93a22..0c06043117a01 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py @@ -35,7 +35,14 @@ class ExportedModelInfo: - """Encapsulates the information of the exported model.""" + """Encapsulates the information of the exported model. + + After ONNX model export, the model info is collected and encapsulated in this class, including: + 1. The ONNX graph inputs + 2. Graph input requiring gradient information. + 3. The model's forward function signature and args/kwargs schema. + + """ def __init__( self, @@ -59,7 +66,7 @@ def __init__( # A subset of onnx_graph_input_names. # Input names that require gradient parsed and then flatten from the model's forward function signature - # This should contain both the user input names, the buffer names, and the parameter names (since we use + # This should contain both the user-defined input names, the buffer names, and the parameter names (since we use # keep_initializers_as_inputs=True for model export) # Be noted: all inputs are used by the model for its compute. self.onnx_graph_input_names_require_grad: list[str] = onnx_graph_input_names_require_grad @@ -87,12 +94,20 @@ def __str__(self): \tmodule_forward_output_schema: {self.module_forward_output_schema} """ - def __repro__(self): + def __repr__(self): return self.__str__() class PostExportProcessedModelInfo: - """Encapsulates the information of the post-export processed model.""" + """Encapsulates the information of the post-export processed model. + + After ONNX model post-export processing, the model info is collected and encapsulated in this class, including: + 1. The ONNX graph input names, dynamic axes, and input data accessor functions. + 2. Graph input requiring gradient information. + 3. The interface to construct the inputs for the ORT forward run, from original given inputs running for PyTorch. + 4. The interface to restore the outputs from the ORT forward run, back to the original data structure. + + """ def __init__( self, @@ -153,7 +168,7 @@ def __str__(self): \tbuffer_for_ort_runs.keys(): {self._buffer_for_ort_runs.keys()} """ - def __repro__(self): + def __repr__(self): return self.__str__() def construct_inputs( @@ -255,7 +270,7 @@ def __init__( # Model info after export and post export processing. self._post_export_processed_model_info = None - def use_cache_or_reconstruct_post_processed_model( + def get_post_processed_model( self, args: Sequence[ORTModelInputOutputType], kwargs: Mapping[str, ORTModelInputOutputType] ) -> tuple[bool, PostExportProcessedModelInfo]: """Check if the post-export processed ONNX model can be reused, otherwise, reconstruct the model. diff --git a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py index 08f04d36eeb55..b603e58fa2ef0 100644 --- a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py @@ -117,7 +117,7 @@ def forward(self, *inputs, **kwargs): ( build_graph, post_export_processed_model_info, - ) = self._graph_transition_manager.use_cache_or_reconstruct_post_processed_model(inputs, kwargs) + ) = self._graph_transition_manager.get_post_processed_model(inputs, kwargs) if build_graph: # TODO(): do we need call it for inferencing mode??? self._initialize_graph_builder(post_export_processed_model_info) diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index 37010f0829edf..f2fddd4050e4b 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -253,7 +253,7 @@ def forward(self, *inputs, **kwargs): ( build_gradient_graph, post_export_processed_model_info, - ) = self._graph_transition_manager.use_cache_or_reconstruct_post_processed_model(inputs, kwargs) + ) = self._graph_transition_manager.get_post_processed_model(inputs, kwargs) if build_gradient_graph: self._initialize_graph_builder(post_export_processed_model_info) diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index cd0bfc54275f0..a9bb07ab4a5bc 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -3566,7 +3566,7 @@ def test_forward_dynamic_kwargs(): @pytest.mark.parametrize( "forward_function", [ # Only pos_X, pos_X as positionals - # lambda model, pos_0, pos_1, kw_0, kw_1, args, kwargs: model(pos_0, pos_1), + lambda model, pos_0, pos_1, kw_0, kw_1, args, kwargs: model(pos_0, pos_1), # Only pos_X, pos_X as keywords lambda model, pos_0, pos_1, kw_0, kw_1, args, kwargs: model(pos_0=pos_0, pos_1=pos_1), # pos_X + *args, pos_X as positionals From 958c837037d1e5d9aa07454416927a61a19c22c7 Mon Sep 17 00:00:00 2001 From: "Peng Wang(AI FWK)" Date: Thu, 22 Feb 2024 09:53:46 +0000 Subject: [PATCH 15/32] fix merge --- .../ortmodule/_graph_transition_manager.py | 97 +++++++++++++++---- .../ortmodule/_mem_efficient_grad_mgmt.py | 2 +- 2 files changed, 80 insertions(+), 19 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py index 0c06043117a01..430254c520973 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py @@ -22,6 +22,7 @@ ORTModelInputOutputSchemaType, ORTModelInputOutputType, PrimitiveType, + onnx_dtype_to_pytorch_dtype, unflatten_data_using_schema, ) @@ -38,9 +39,12 @@ class ExportedModelInfo: """Encapsulates the information of the exported model. After ONNX model export, the model info is collected and encapsulated in this class, including: - 1. The ONNX graph inputs + 1. The ONNX graph input names. 2. Graph input requiring gradient information. - 3. The model's forward function signature and args/kwargs schema. + 3. The model's forward function signature and args/kwargs schema, used as a cache key to compare with the current + inputs to see if the model needs to be re-exported. + + This data structure is returned by the GraphTransitionManager._export_model method. """ @@ -120,6 +124,7 @@ def __init__( module_forward_output_schema: ORTModelInputOutputSchemaType, post_export_processed_model: onnx.ModelProto, onnx_graph_input_data_accessor: dict[str, callable], + enable_mem_efficient_grad_management: bool, ): self._flattened_module = flatten_module @@ -152,11 +157,13 @@ def __init__( # For i-th input name, we can use the i-th function to get the input data from args and kwargs. self.onnx_graph_input_data_accessor: dict[str, callable] | None = onnx_graph_input_data_accessor + self._enable_mem_efficient_grad_management = enable_mem_efficient_grad_management + # Used for unflattening the outputs from the ORT forward run. self.module_forward_output_schema: ORTModelInputOutputSchemaType | None = module_forward_output_schema # A buffer to hold the inputs for the ORT forward run. For performance, we reuse the same buffer for each run. - self._buffer_for_ort_runs: dict[str, torch.Tensor] = OrderedDict() + self._buffer_for_ort_runs: dict[str, torch.Tensor] | None = None def __str__(self): return f"""PostExportProcessedModelInfo class: @@ -165,7 +172,7 @@ def __str__(self): \tonnx_graph_input_dynamic_axes_map: {self.onnx_graph_input_dynamic_axes_map} \tonnx_graph_input_names_user_defined: {self.onnx_graph_input_names_user_defined} \tonnx_graph_input_names_require_grad_user_defined: {self.onnx_graph_input_names_require_grad_user_defined} - \tbuffer_for_ort_runs.keys(): {self._buffer_for_ort_runs.keys()} + \tbuffer_for_ort_runs.keys(): {self._buffer_for_ort_runs.keys() if self._buffer_for_ort_runs else None} """ def __repr__(self): @@ -182,12 +189,27 @@ def construct_inputs( The inputs are constructed in the order they appear in the model's forward function signature """ + from ._mem_efficient_grad_mgmt import ( + MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME, + MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE, + MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE, + ) # First time construct the buffer for the ORT forward run. - if len(self._buffer_for_ort_runs) == 0: + if self._buffer_for_ort_runs is None: + self._buffer_for_ort_runs = OrderedDict() + # Create the buffers for the inputs that are either parameters or buffers in the original module. # For user inputs, fill with None for now, and will be filled dynamically during the forward run. - parameter_names = {k: v for k, v in self._flattened_module.named_parameters()} + + if self._enable_mem_efficient_grad_management: + from ._mem_efficient_grad_mgmt import get_params_not_connected_to_pull_param_trigger + + parameter_names = get_params_not_connected_to_pull_param_trigger( + self._flattened_module.named_parameters(), self._post_export_processed_model + ) + else: + parameter_names = {k: v for k, v in self._flattened_module.named_parameters()} buffer_names = {k: v for k, v in self._flattened_module.named_buffers()} for input_name in self.onnx_graph_input_names: if input_name in parameter_names: @@ -198,6 +220,14 @@ def construct_inputs( self._buffer_for_ort_runs[input_name] = None for name in self.onnx_graph_input_names_user_defined: + if self._enable_mem_efficient_grad_management and name == MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME: + self._buffer_for_ort_runs[name] = torch.zeros( + MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE, + dtype=onnx_dtype_to_pytorch_dtype(MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE), + device=device, + ).requires_grad_() + continue + if name in self.onnx_graph_input_data_accessor: assert name in self._buffer_for_ort_runs, f"{name} is not in buffer_for_ort_runs" data = self.onnx_graph_input_data_accessor[name](args, kwargs) @@ -316,11 +346,11 @@ def get_post_processed_model( # # The _io.FlattenedModule serves as a module wrapper designed to support tuple inputs and outputs for # PyTorch run during ONNX export. (Remember the PyTorch exporter handles tuple inputs and outputs better.) - # Internally, it facilitates the acceptance of tuple inputs and generation of tuple outputs by invoking + # Internally, it facilitates the acceptance of tuple inputs and the generation of tuple outputs by invoking # the original module's forward function. The workflow involves the following steps: - # 1. Prior to export, both args and kwargs are flattened into a 1-D tensor list, and a schema for the - # flattened args and kwargs is generated. This schema is essential for the subsequent unflattening + # 1. Prior to export, both args and kwargs are flattened into a 1-D tensor list, and schemas for the + # flattened args and kwargs are generated. This schemas are essential for the subsequent un-flattening # process. # 2. The flattened inputs (args + kwargs) are passed to the _io.FlattenedModule's forward run. @@ -328,14 +358,14 @@ def get_post_processed_model( # 3. The args schema and kwargs schema, etc are conveyed to the _io.FlattenedModule by setting the # corresponding attributes. - # 4. Within the _io.FlattenedModule's forward run, the inputs are unflattened to the original args and + # 4. Within the _io.FlattenedModule's forward run, the inputs are un-flattened to the original args and # kwargs using the associated schemas, and then they are passed to the original module's forward function. # 5. Upon the completion of the forward function, the outputs from the original module are flattened and # returned to the caller. # 6. The 1-D flattened output tensors retain the same order as the outputs from the ONNX Runtime (ORT) - # forward run. To facilitate unflattening during subsequent ORT runs, the output schema is saved as + # forward run. To facilitate un-flattening during subsequent ORT runs, the output schema is saved as # an attribute named `_output_schema` in the _io.FlattenedModule. copied_args = copy.copy(args) @@ -446,6 +476,8 @@ def get_post_processed_model( enable_zero_stage3_support=self._runtime_options.enable_zero_stage3_support, run_symbolic_shape_infer=self._runtime_options.run_symbolic_shape_infer, stage3_param_handle=self, + enable_mem_efficient_grad_management=self._export_mode != torch.onnx.TrainingMode.EVAL + and self._runtime_options.enable_mem_efficient_grad_management, logger=self._logger, ) @@ -465,7 +497,7 @@ def get_post_processed_model( @staticmethod def _export_check( - prev_exported_model_info: ExportedModelInfo, + prev_exported_model_info: ExportedModelInfo | None, original_model_has_changed: bool, cur_args_schema: ORTModelInputOutputSchemaType, cur_kwargs_schema: ORTModelInputOutputSchemaType, @@ -549,13 +581,12 @@ def _post_export_process( enable_zero_stage3_support: bool, run_symbolic_shape_infer: bool, stage3_param_handle: type, + enable_mem_efficient_grad_management: bool, logger: logging.Logger, ): """Post process the exported model, generate the processed model which will be used for initializing graph builder.""" # Deepcopy the exported model, in case modification affects the exported model. - - # TODO(): Do pre-grad graph modification as needed, for memory-efficient gradient management, etc. post_processed_model = copy.deepcopy(exported_model_info.exported_model) if enable_custom_autograd_function: @@ -578,16 +609,46 @@ def _post_export_process( [name for name, _ in flatten_module.named_parameters()], ) + onnx_graph_input_names_user_defined = copy.deepcopy(exported_model_info.onnx_graph_input_names_user_defined) + onnx_graph_input_names_require_grad_user_defined = copy.deepcopy( + exported_model_info.onnx_graph_input_names_require_grad_user_defined + ) + onnx_graph_input_names = copy.deepcopy(exported_model_info.onnx_graph_input_names) + onnx_graph_input_names_require_grad = copy.deepcopy(exported_model_info.onnx_graph_input_names_require_grad) + if enable_mem_efficient_grad_management: + from ._mem_efficient_grad_mgmt import post_processing_enable_mem_efficient_training + + # Override the options if model is not modified. + ( + enable_mem_efficient_grad_management, + post_processed_model, + ) = post_processing_enable_mem_efficient_training(post_processed_model, flatten_module.named_parameters()) + + if enable_custom_autograd_function: + from ._mem_efficient_grad_mgmt import MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME + + # Add mem efficient grad trigger name to require_grad_names, so that it will be included in the gradient graph. + onnx_graph_input_names_user_defined.append(MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME) + onnx_graph_input_names_require_grad_user_defined.append(MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME) + onnx_graph_input_names.append(MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME) + onnx_graph_input_names_require_grad.append(MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME) + + if run_symbolic_shape_infer: + post_processed_model = SymbolicShapeInference.infer_shapes( + post_processed_model, auto_merge=True, guess_output_rank=True + ) + post_export_processed_model_info = PostExportProcessedModelInfo( flatten_module, - exported_model_info.onnx_graph_input_names_user_defined, - exported_model_info.onnx_graph_input_names_require_grad_user_defined, - exported_model_info.onnx_graph_input_names, - exported_model_info.onnx_graph_input_names_require_grad, + onnx_graph_input_names_user_defined, + onnx_graph_input_names_require_grad_user_defined, + onnx_graph_input_names, + onnx_graph_input_names_require_grad, model_info_for_export.onnx_graph_input_dynamic_axes_map, exported_model_info.module_forward_output_schema, post_processed_model, model_info_for_export.onnx_graph_input_data_accessor, + enable_mem_efficient_grad_management, ) logger.info( diff --git a/orttraining/orttraining/python/training/ortmodule/_mem_efficient_grad_mgmt.py b/orttraining/orttraining/python/training/ortmodule/_mem_efficient_grad_mgmt.py index 4663afdaa94a0..61e226307bef2 100644 --- a/orttraining/orttraining/python/training/ortmodule/_mem_efficient_grad_mgmt.py +++ b/orttraining/orttraining/python/training/ortmodule/_mem_efficient_grad_mgmt.py @@ -32,7 +32,7 @@ def get_params_not_connected_to_pull_param_trigger( ): # Be noted, some parameters might not in graph input because they are not used in forward, so we filtered them also. onnx_initializer_names = {p.name for p in exported_model.graph.input} - return [v for k, v in named_params if not v.requires_grad and k in onnx_initializer_names] + return {k: v for k, v in named_params if not v.requires_grad and k in onnx_initializer_names} def post_processing_enable_mem_efficient_training( From 2d5314140ff406513daed4dedd1a904b6f294185 Mon Sep 17 00:00:00 2001 From: "Peng Wang(AI FWK)" Date: Thu, 22 Feb 2024 11:25:25 +0000 Subject: [PATCH 16/32] fix --- .../ortmodule/_graph_execution_manager.py | 2 +- .../ortmodule/_graph_transition_manager.py | 52 +++++++++---------- .../training/ortmodule/_training_manager.py | 2 +- 3 files changed, 27 insertions(+), 29 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 5993a41ef75ba..494bc6faf392a 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -343,7 +343,7 @@ def _enable_conditional_optimizations( if self._runtime_options.enable_sparse_optimizer: detected_device = _utils.get_device_from_module_and_inputs(self._original_module, inputs, kwargs) - if self._runtime_options.enable_zero_stage3_support or self._mem_efficient_grad_management_is_enabled: + if self._runtime_options.enable_zero_stage3_support: self._append_pull_weight_trigger_as_input(kwargs, detected_device) prepared_input_map = self._graph_transition_manager._post_export_processed_model_info.construct_inputs( diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py index 430254c520973..c281bc085d8b7 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py @@ -66,22 +66,24 @@ def __init__( # Input names parsed and then flatten from the model's forward function signature + buffers + parameters (since we use # keep_initializers_as_inputs=True for model export) # Be noted: all inputs are used by the model for its compute. - self.onnx_graph_input_names: list[str] = onnx_graph_input_names + self.onnx_graph_input_names: list[str] = copy.deepcopy(onnx_graph_input_names) # A subset of onnx_graph_input_names. # Input names that require gradient parsed and then flatten from the model's forward function signature # This should contain both the user-defined input names, the buffer names, and the parameter names (since we use # keep_initializers_as_inputs=True for model export) # Be noted: all inputs are used by the model for its compute. - self.onnx_graph_input_names_require_grad: list[str] = onnx_graph_input_names_require_grad + self.onnx_graph_input_names_require_grad: list[str] = copy.deepcopy(onnx_graph_input_names_require_grad) # Input names parsed from the model's forward function signature. # Be noted: all inputs are used by the model for its compute. # The ONNX graph input names exclude the parameters, and buffers. - self.onnx_graph_input_names_user_defined = onnx_graph_input_names_user_defined + self.onnx_graph_input_names_user_defined = copy.deepcopy(onnx_graph_input_names_user_defined) # A subset of onnx_graph_input_names_user_defined. - self.onnx_graph_input_names_require_grad_user_defined = onnx_graph_input_names_require_grad_user_defined + self.onnx_graph_input_names_require_grad_user_defined = copy.deepcopy( + onnx_graph_input_names_require_grad_user_defined + ) # Exported model proto. self.exported_model: onnx.ModelProto | None = exported_model @@ -131,19 +133,30 @@ def __init__( # Input names parsed from the model's forward function signature. # Be noted: all inputs are used by the model for its compute. # The ONNX graph input names exclude the parameters, and buffers. - self.onnx_graph_input_names_user_defined = onnx_graph_input_names_user_defined + self.onnx_graph_input_names_user_defined = copy.deepcopy(onnx_graph_input_names_user_defined) # A subset of onnx_graph_input_names_user_defined. - self.onnx_graph_input_names_require_grad_user_defined = onnx_graph_input_names_require_grad_user_defined + self.onnx_graph_input_names_require_grad_user_defined = copy.deepcopy( + onnx_graph_input_names_require_grad_user_defined + ) # Input names for the pre-gradient-build graph. # This may be different with the one in ExportedGraph since we may modify the graph inputs as needed # for example when memory efficient gradient management is enabled. - self.onnx_graph_input_names: list[str] = onnx_graph_input_names + self.onnx_graph_input_names: list[str] = copy.deepcopy(onnx_graph_input_names) # A subset of onnx_graph_input_names. # Input names that require gradients for the pre-gradient-build graph. - self.onnx_graph_input_names_require_grad: list[str] = onnx_graph_input_names_require_grad + self.onnx_graph_input_names_require_grad: list[str] = copy.deepcopy(onnx_graph_input_names_require_grad) + + if enable_mem_efficient_grad_management: + from ._mem_efficient_grad_mgmt import MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME + + # Add mem efficient grad trigger name to require_grad_names, so that it will be included in the gradient graph. + self.onnx_graph_input_names_user_defined.append(MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME) + self.onnx_graph_input_names_require_grad_user_defined.append(MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME) + self.onnx_graph_input_names.append(MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME) + self.onnx_graph_input_names_require_grad.append(MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME) # Create symbolic names for each dimension of the graph input (e.g. onnx_graph_input_names). # The key is the input name, the value is a dict of {dim_index: symbolic_dim_name} @@ -609,12 +622,6 @@ def _post_export_process( [name for name, _ in flatten_module.named_parameters()], ) - onnx_graph_input_names_user_defined = copy.deepcopy(exported_model_info.onnx_graph_input_names_user_defined) - onnx_graph_input_names_require_grad_user_defined = copy.deepcopy( - exported_model_info.onnx_graph_input_names_require_grad_user_defined - ) - onnx_graph_input_names = copy.deepcopy(exported_model_info.onnx_graph_input_names) - onnx_graph_input_names_require_grad = copy.deepcopy(exported_model_info.onnx_graph_input_names_require_grad) if enable_mem_efficient_grad_management: from ._mem_efficient_grad_mgmt import post_processing_enable_mem_efficient_training @@ -624,15 +631,6 @@ def _post_export_process( post_processed_model, ) = post_processing_enable_mem_efficient_training(post_processed_model, flatten_module.named_parameters()) - if enable_custom_autograd_function: - from ._mem_efficient_grad_mgmt import MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME - - # Add mem efficient grad trigger name to require_grad_names, so that it will be included in the gradient graph. - onnx_graph_input_names_user_defined.append(MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME) - onnx_graph_input_names_require_grad_user_defined.append(MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME) - onnx_graph_input_names.append(MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME) - onnx_graph_input_names_require_grad.append(MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME) - if run_symbolic_shape_infer: post_processed_model = SymbolicShapeInference.infer_shapes( post_processed_model, auto_merge=True, guess_output_rank=True @@ -640,10 +638,10 @@ def _post_export_process( post_export_processed_model_info = PostExportProcessedModelInfo( flatten_module, - onnx_graph_input_names_user_defined, - onnx_graph_input_names_require_grad_user_defined, - onnx_graph_input_names, - onnx_graph_input_names_require_grad, + exported_model_info.onnx_graph_input_names_user_defined, + exported_model_info.onnx_graph_input_names_require_grad_user_defined, + exported_model_info.onnx_graph_input_names, + exported_model_info.onnx_graph_input_names_require_grad, model_info_for_export.onnx_graph_input_dynamic_axes_map, exported_model_info.module_forward_output_schema, post_processed_model, diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index c5e649b57475d..ceb988b08b9fd 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -299,7 +299,7 @@ def forward(self, *inputs, **kwargs): self._gradient_accumulation_manager.maybe_update_cache_before_run() - if self._runtime_options.enable_zero_stage3_support or self._mem_efficient_grad_management_is_enabled: + if self._runtime_options.enable_zero_stage3_support: self._append_pull_weight_trigger_as_input(kwargs, self._device) prepared_input_map = self._graph_transition_manager._post_export_processed_model_info.construct_inputs( From 8078d369127ac53a272d21df491bcfa7260c37e6 Mon Sep 17 00:00:00 2001 From: "Peng Wang(AI FWK)" Date: Fri, 23 Feb 2024 02:47:44 +0000 Subject: [PATCH 17/32] fix --- .../python/training/ortmodule/_io.py | 22 ++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_io.py b/orttraining/orttraining/python/training/ortmodule/_io.py index bd7e50b15fb34..3a834c5c5b1c6 100644 --- a/orttraining/orttraining/python/training/ortmodule/_io.py +++ b/orttraining/orttraining/python/training/ortmodule/_io.py @@ -265,7 +265,8 @@ def parse_inputs_for_onnx_export( """ - tensor_idx = [-1] + arg_tensor_idx = [-1] + kwarg_tensor_idx = [-1] def _add_dynamic_shape(name, input) -> Dict[str, Dict[int, str]]: dynamic_axes[name] = {} @@ -276,7 +277,9 @@ def _add_dynamic_shape(name, input) -> Dict[str, Dict[int, str]]: def _warn_of_constant_inputs(data): warnings.warn(f"Received input of type {type(data)} is treated as a constant by ORT by default.") - def _add_input(name: str, input_value, onnx_graph_input_names: List[str], cur_func: Callable): + def _add_input( + name: str, input_value, onnx_graph_input_names: List[str], cur_func: Callable, tensor_idx: List[int] + ): """Returns number of expanded non none inputs that _add_input processed""" # in case the input is already handled. @@ -316,6 +319,7 @@ def _access_func(i, cur_func, args, kwargs): val, onnx_graph_input_names, partial(_access_func, i, cur_func), + tensor_idx, ) if not isinstance(input_schema, SkipRetValue): @@ -348,6 +352,7 @@ def _access_func(key, cur_func, args, kwargs): val, onnx_graph_input_names, partial(_access_func, key, cur_func), + tensor_idx, ) if not isinstance(input_schema, SkipRetValue): @@ -417,7 +422,13 @@ def _access_func(key, cur_func, args, kwargs): var_positional_idx += 1 num_positional_args += 1 inp = args[args_i] - schema = _add_input(name, inp, onnx_graph_input_names, partial(_arg_access_with_index_func, args_i)) + schema = _add_input( + name, + inp, + onnx_graph_input_names, + partial(_arg_access_with_index_func, args_i), + arg_tensor_idx, + ) if not isinstance(schema, SkipRetValue): input_arg_schema.append(schema) elif ( @@ -434,13 +445,13 @@ def _access_func(key, cur_func, args, kwargs): inp = args[input_idx] num_positional_args += 1 access_func = partial(_arg_access_with_index_func, input_idx) - schema = _add_input(name, inp, onnx_graph_input_names, access_func) + schema = _add_input(name, inp, onnx_graph_input_names, access_func, arg_tensor_idx) if not isinstance(schema, SkipRetValue): input_arg_schema.append(schema) elif name in kwargs and kwargs[name] is not None: inp = kwargs[name] access_func = partial(_kwarg_access_with_name_func, name) - schema = _add_input(name, inp, onnx_graph_input_names, access_func) + schema = _add_input(name, inp, onnx_graph_input_names, access_func, kwarg_tensor_idx) if not isinstance(schema, SkipRetValue): input_kwarg_schema[name] = schema @@ -452,6 +463,7 @@ def _access_func(key, cur_func, args, kwargs): inp, onnx_graph_input_names, partial(_kwarg_access_with_name_func, name), + kwarg_tensor_idx, ) if not isinstance(schema, SkipRetValue): input_kwarg_schema[name] = schema From 970525b6aab641a19b8dfe3fec9c2335aa587252 Mon Sep 17 00:00:00 2001 From: "Peng Wang(AI FWK)" Date: Mon, 26 Feb 2024 12:57:33 +0000 Subject: [PATCH 18/32] fix all tests --- .../ortmodule/_graph_transition_manager.py | 44 ++++++++++++++++--- .../python/training/ortmodule/_io.py | 31 ++++++++++--- 2 files changed, 62 insertions(+), 13 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py index c281bc085d8b7..6edb9bfb93b89 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py @@ -126,6 +126,7 @@ def __init__( module_forward_output_schema: ORTModelInputOutputSchemaType, post_export_processed_model: onnx.ModelProto, onnx_graph_input_data_accessor: dict[str, callable], + onnx_graph_input_const_as_tensor: dict[str, torch.device], enable_mem_efficient_grad_management: bool, ): self._flattened_module = flatten_module @@ -170,6 +171,8 @@ def __init__( # For i-th input name, we can use the i-th function to get the input data from args and kwargs. self.onnx_graph_input_data_accessor: dict[str, callable] | None = onnx_graph_input_data_accessor + self.onnx_graph_input_const_as_tensor: dict[str, torch.device] | None = onnx_graph_input_const_as_tensor + self._enable_mem_efficient_grad_management = enable_mem_efficient_grad_management # Used for unflattening the outputs from the ORT forward run. @@ -244,7 +247,7 @@ def construct_inputs( if name in self.onnx_graph_input_data_accessor: assert name in self._buffer_for_ort_runs, f"{name} is not in buffer_for_ort_runs" data = self.onnx_graph_input_data_accessor[name](args, kwargs) - if PrimitiveType.is_primitive_type(data) and constant_as_tensor: + if name in self.onnx_graph_input_const_as_tensor: data = PrimitiveType.get_tensor(data, device) self._buffer_for_ort_runs[name] = data else: @@ -383,17 +386,36 @@ def get_post_processed_model( copied_args = copy.copy(args) copied_kwargs = copy.copy(kwargs) - flatten_inputs = [ - data_accessor(copied_args, copied_kwargs) - for _, data_accessor in cur_model_info_for_export.onnx_graph_input_data_accessor.items() - ] + flatten_inputs = [] + + # This looks a bit duplicated with `extract_data_and_schema` function, but this might be better to + # defined as a specialized logic which is the counter-part of `parse_inputs_for_onnx_export`, which handles + # args and kwargs separately. + for name, data_accessor in cur_model_info_for_export.onnx_graph_input_data_accessor.items(): + print("find data accessor for name: ", name) + d = data_accessor(copied_args, copied_kwargs) + if name in cur_model_info_for_export.onnx_graph_input_const_as_tensor: + flatten_inputs.append( + PrimitiveType.get_tensor( + d, + cur_model_info_for_export.onnx_graph_input_const_as_tensor[name], + ) + ) + print("pass 1") + else: + if isinstance(d, torch.Tensor): + flatten_inputs.append(d) + print("pass 2") + else: + print("pass 3") + # Ignore all other non-tensor inputs. self._flatten_module._device = self._device self._flatten_module._args_schema = cur_model_info_for_export.onnx_graph_input_arg_schema self._flatten_module._kwargs_schema = cur_model_info_for_export.onnx_graph_input_kwarg_schema self._flatten_module._num_positionals = cur_model_info_for_export.num_positional_args - self._logger.info(f"do_export started, model info for export: {cur_model_info_for_export}") + self._logger.warning(f"do_export started, model info for export: {cur_model_info_for_export}") ( exported_model, @@ -534,7 +556,14 @@ def _export_check( or cur_kwargs_schema != prev_exported_model_info.module_forward_kwargs_schema ) - logger.info(f"_export_check completed - need_export_model: {need_export_model}") + print( + f"cur_args_schema: {cur_args_schema}, prev_exported_model_info.module_forward_args_schema: {prev_exported_model_info.module_forward_args_schema if prev_exported_model_info else None}" + ) + print( + f"cur_kwargs_schema: {cur_kwargs_schema}, prev_exported_model_info.module_forward_kwargs_schema: {prev_exported_model_info.module_forward_kwargs_schema if prev_exported_model_info else None}" + ) + + logger.warning(f"_export_check completed - need_export_model: {need_export_model}") return need_export_model @@ -646,6 +675,7 @@ def _post_export_process( exported_model_info.module_forward_output_schema, post_processed_model, model_info_for_export.onnx_graph_input_data_accessor, + model_info_for_export.onnx_graph_input_const_as_tensor, enable_mem_efficient_grad_management, ) diff --git a/orttraining/orttraining/python/training/ortmodule/_io.py b/orttraining/orttraining/python/training/ortmodule/_io.py index 3a834c5c5b1c6..efb3cd9f7954a 100644 --- a/orttraining/orttraining/python/training/ortmodule/_io.py +++ b/orttraining/orttraining/python/training/ortmodule/_io.py @@ -132,6 +132,7 @@ def __init__(self, original_module: torch.nn.Module): def forward(self, *args): new_args = unflatten_data_using_schema(args[: self._num_positionals], self._args_schema) + new_kwargs = unflatten_data_using_schema(args[self._num_positionals :], self._kwargs_schema) original_outputs = self._original_module(*new_args, **new_kwargs) @@ -155,6 +156,7 @@ def __init__( onnx_graph_input_dynamic_axes_map: Dict[str, Dict[int, str]], onnx_graph_input_shapes: List[List[int]], onnx_graph_input_data_accessor: Optional[Dict[str, callable]] = None, + onnx_graph_input_const_as_tensor: Optional[Dict[str, torch.device]] = None, onnx_graph_input_arg_schema: Optional[Dict[str, ORTModelInputOutputSchemaType]] = None, onnx_graph_input_kwarg_schema: Optional[Dict[str, ORTModelInputOutputSchemaType]] = None, num_positional_args: int = 0, @@ -203,6 +205,8 @@ def __init__( # For i-th input name, we can use the i-th function to get the input data from args and kwargs. self.onnx_graph_input_data_accessor: Optional[Dict[str, callable]] = onnx_graph_input_data_accessor + self.onnx_graph_input_const_as_tensor: Optional[Dict[str, torch.device]] = onnx_graph_input_const_as_tensor + def __str__(self) -> str: return f"""ModelInfoForExport class: \tExport mode: {self.export_mode} @@ -210,7 +214,10 @@ def __str__(self) -> str: \tInput names: {self.onnx_graph_input_names} \tInput names require grad: {self.onnx_graph_input_names_require_grad} \tInput dynamic axes: {self.onnx_graph_input_dynamic_axes_map} - \tInput shapes: {self.onnx_graph_input_shapes}""" + \tInput shapes: {self.onnx_graph_input_shapes} + \tInput args schema: {self.onnx_graph_input_arg_schema} + \tInput kwargs schema: {self.onnx_graph_input_kwarg_schema} + \tNum input args: {self.num_positional_args}""" def __repr__(self) -> str: return self.__str__() @@ -289,16 +296,24 @@ def _add_input( visited_input_names.append(name) value = input_value + primitive_dtype = None if value is None: _warn_of_constant_inputs(value) + data_accessors[name] = cur_func return value elif isinstance(value, str): _warn_of_constant_inputs(value) + data_accessors[name] = cur_func return value elif PrimitiveType.is_primitive_type(value): if constant_as_tensor: + # This has special handling for bool type to string conversion. + primitive_dtype = PrimitiveType.get_primitive_dtype(value) value = PrimitiveType.get_tensor(value, device) + const_to_tensor_inputs[name] = device + else: + data_accessors[name] = cur_func _warn_of_constant_inputs(value) return value elif isinstance(value, abc.Sequence): @@ -374,7 +389,7 @@ def _access_func(key, cur_func, args, kwargs): tensor_idx[0] += 1 return _TensorStub( tensor_idx[0], - dtype=str(value.dtype), + dtype=primitive_dtype if primitive_dtype else str(value.dtype), # special handle for bool primitive shape_dims=len(value.size()), name=name, ) @@ -390,6 +405,7 @@ def _access_func(key, cur_func, args, kwargs): input_arg_schema: ORTModelInputOutputSchemaType = [] input_kwarg_schema: ORTModelInputOutputSchemaType = OrderedDict() data_accessors: Dict[str, Callable] = OrderedDict() + const_to_tensor_inputs: Dict[str, torch.device] = OrderedDict() num_positional_args: int = 0 var_positional_idx = 0 @@ -420,8 +436,8 @@ def _access_func(key, cur_func, args, kwargs): for args_i in range(input_idx, len(args)): name = f"{input_parameter.name}_{var_positional_idx}" var_positional_idx += 1 - num_positional_args += 1 inp = args[args_i] + pre_tensor_idx = arg_tensor_idx[0] schema = _add_input( name, inp, @@ -429,6 +445,7 @@ def _access_func(key, cur_func, args, kwargs): partial(_arg_access_with_index_func, args_i), arg_tensor_idx, ) + num_positional_args += arg_tensor_idx[0] - pre_tensor_idx if not isinstance(schema, SkipRetValue): input_arg_schema.append(schema) elif ( @@ -441,14 +458,15 @@ def _access_func(key, cur_func, args, kwargs): inp = None input_idx += var_positional_idx # noqa: PLW2901 access_func = None - if input_idx < len(args) and args[input_idx] is not None: + if input_idx < len(args): inp = args[input_idx] - num_positional_args += 1 access_func = partial(_arg_access_with_index_func, input_idx) + pre_tensor_idx = arg_tensor_idx[0] schema = _add_input(name, inp, onnx_graph_input_names, access_func, arg_tensor_idx) + num_positional_args += arg_tensor_idx[0] - pre_tensor_idx if not isinstance(schema, SkipRetValue): input_arg_schema.append(schema) - elif name in kwargs and kwargs[name] is not None: + elif name in kwargs: inp = kwargs[name] access_func = partial(_kwarg_access_with_name_func, name) schema = _add_input(name, inp, onnx_graph_input_names, access_func, kwarg_tensor_idx) @@ -474,6 +492,7 @@ def _access_func(key, cur_func, args, kwargs): onnx_graph_input_dynamic_axes_map=dynamic_axes, onnx_graph_input_shapes=input_shape, onnx_graph_input_data_accessor=data_accessors, + onnx_graph_input_const_as_tensor=const_to_tensor_inputs, onnx_graph_input_arg_schema=input_arg_schema, onnx_graph_input_kwarg_schema=input_kwarg_schema, num_positional_args=num_positional_args, From f45c4b44e2a9e7bec97d027d7427d893123aa1cc Mon Sep 17 00:00:00 2001 From: "Peng Wang(AI FWK)" Date: Mon, 26 Feb 2024 13:00:21 +0000 Subject: [PATCH 19/32] minors --- .../training/ortmodule/_graph_transition_manager.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py index 6edb9bfb93b89..fbbc88403a2b0 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py @@ -392,7 +392,6 @@ def get_post_processed_model( # defined as a specialized logic which is the counter-part of `parse_inputs_for_onnx_export`, which handles # args and kwargs separately. for name, data_accessor in cur_model_info_for_export.onnx_graph_input_data_accessor.items(): - print("find data accessor for name: ", name) d = data_accessor(copied_args, copied_kwargs) if name in cur_model_info_for_export.onnx_graph_input_const_as_tensor: flatten_inputs.append( @@ -401,13 +400,10 @@ def get_post_processed_model( cur_model_info_for_export.onnx_graph_input_const_as_tensor[name], ) ) - print("pass 1") else: if isinstance(d, torch.Tensor): flatten_inputs.append(d) - print("pass 2") - else: - print("pass 3") + # Ignore all other non-tensor inputs. self._flatten_module._device = self._device @@ -415,7 +411,7 @@ def get_post_processed_model( self._flatten_module._kwargs_schema = cur_model_info_for_export.onnx_graph_input_kwarg_schema self._flatten_module._num_positionals = cur_model_info_for_export.num_positional_args - self._logger.warning(f"do_export started, model info for export: {cur_model_info_for_export}") + self._logger.info(f"do_export started, model info for export: {cur_model_info_for_export}") ( exported_model, From 2c69654abe788694c197ccd3d54deb5c4366eeab Mon Sep 17 00:00:00 2001 From: "Peng Wang(AI FWK)" Date: Mon, 26 Feb 2024 13:01:48 +0000 Subject: [PATCH 20/32] minor --- .../training/ortmodule/_graph_transition_manager.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py index fbbc88403a2b0..ddc79bf064722 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py @@ -552,14 +552,7 @@ def _export_check( or cur_kwargs_schema != prev_exported_model_info.module_forward_kwargs_schema ) - print( - f"cur_args_schema: {cur_args_schema}, prev_exported_model_info.module_forward_args_schema: {prev_exported_model_info.module_forward_args_schema if prev_exported_model_info else None}" - ) - print( - f"cur_kwargs_schema: {cur_kwargs_schema}, prev_exported_model_info.module_forward_kwargs_schema: {prev_exported_model_info.module_forward_kwargs_schema if prev_exported_model_info else None}" - ) - - logger.warning(f"_export_check completed - need_export_model: {need_export_model}") + logger.info(f"_export_check completed - need_export_model: {need_export_model}") return need_export_model From c4880c146aeef5948b16a1ef0bf3228f525c256d Mon Sep 17 00:00:00 2001 From: "Peng Wang(AI FWK)" Date: Tue, 27 Feb 2024 04:13:25 +0000 Subject: [PATCH 21/32] fix test --- .../ortmodule/_graph_transition_manager.py | 3 ++- .../python/training/ortmodule/_io.py | 4 ++-- .../python/orttraining_test_ortmodule_api.py | 23 +++++++++++++------ 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py index ddc79bf064722..ee0f6f982d356 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py @@ -346,6 +346,7 @@ def get_post_processed_model( True, self._device, self._export_mode, + self._logger, self._export_extra_kwargs, ) @@ -389,7 +390,7 @@ def get_post_processed_model( flatten_inputs = [] # This looks a bit duplicated with `extract_data_and_schema` function, but this might be better to - # defined as a specialized logic which is the counter-part of `parse_inputs_for_onnx_export`, which handles + # defined as a specialized logic that is the counter-part of `parse_inputs_for_onnx_export`, which handles # args and kwargs separately. for name, data_accessor in cur_model_info_for_export.onnx_graph_input_data_accessor.items(): d = data_accessor(copied_args, copied_kwargs) diff --git a/orttraining/orttraining/python/training/ortmodule/_io.py b/orttraining/orttraining/python/training/ortmodule/_io.py index efb3cd9f7954a..37efb6faaef8a 100644 --- a/orttraining/orttraining/python/training/ortmodule/_io.py +++ b/orttraining/orttraining/python/training/ortmodule/_io.py @@ -6,7 +6,6 @@ import copy import gc import inspect -import warnings from collections import OrderedDict, abc from functools import partial from logging import Logger @@ -244,6 +243,7 @@ def parse_inputs_for_onnx_export( constant_as_tensor: bool, device: torch.device, export_mode: int, + logger: Logger, export_extra_kwargs: Optional[Dict[str, any]] = None, ) -> ModelInfoForExport: """Parses through the model inputs and returns _InputInfo. @@ -282,7 +282,7 @@ def _add_dynamic_shape(name, input) -> Dict[str, Dict[int, str]]: return dynamic_axes def _warn_of_constant_inputs(data): - warnings.warn(f"Received input of type {type(data)} is treated as a constant by ORT by default.") + logger.info(f"Received input of type {type(data)} is treated as a constant by ORT by default.") def _add_input( name: str, input_value, onnx_graph_input_names: List[str], cur_func: Callable, tensor_idx: List[int] diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index bc57715d5a484..712afd6c52cb5 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -4239,14 +4239,23 @@ def test_hf_save_pretrained(): assert p1.data.ne(p2.data).sum() == 0 -def test_ortmodule_string_inputs_are_ignored(): +def test_ortmodule_string_inputs_are_ignored(caplog): pt_model = MyStrNet() - target_str = "Received input of type which may be treated as a constant by ORT by default." - with pytest.warns(UserWarning, match=target_str): - ort_model = ORTModule(copy.deepcopy(pt_model), DebugOptions(log_level=LogLevel.INFO)) - x = torch.randn(1, 2) - out = ort_model(x, "hello") - _test_helpers.assert_values_are_close(out, x + 1) + target_str = "Received input of type is treated as a constant by ORT by default." + + ort_model = ORTModule(copy.deepcopy(pt_model), DebugOptions(log_level=LogLevel.INFO)) + x = torch.randn(1, 2) + out = ort_model(x, "hello") + _test_helpers.assert_values_are_close(out, x + 1) + + found_log = False + for record in caplog.records: + msg = record.getMessage() + if target_str in msg: + found_log = True + break + + assert found_log, f"Expected to find log message '{target_str}' in the logs, but didn't find it." def test_ortmodule_list_input(): From 898ead8d4f180f7444d257acb9b4d1a79f4eeb07 Mon Sep 17 00:00:00 2001 From: "Peng Wang(AI FWK)" Date: Tue, 27 Feb 2024 08:57:36 +0000 Subject: [PATCH 22/32] fix tests orttraining/orttraining/test/python/orttraining_test_ortmodule_fallback.py --- .../ortmodule/_graph_transition_manager.py | 1 - .../python/training/ortmodule/_io.py | 2 +- .../python/training/ortmodule/_utils.py | 7 +++- .../orttraining_test_ortmodule_fallback.py | 39 ++++++++++++------- 4 files changed, 31 insertions(+), 18 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py index ee0f6f982d356..94c81b70d8c60 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py @@ -338,7 +338,6 @@ def get_post_processed_model( ) # Extract the schema from the args and kwargs, and compare it with the pre-exported one if already exported. - cur_model_info_for_export = _io.parse_inputs_for_onnx_export( self._module_forward_func_parameters, args, diff --git a/orttraining/orttraining/python/training/ortmodule/_io.py b/orttraining/orttraining/python/training/ortmodule/_io.py index 37efb6faaef8a..0053bb8186692 100644 --- a/orttraining/orttraining/python/training/ortmodule/_io.py +++ b/orttraining/orttraining/python/training/ortmodule/_io.py @@ -394,7 +394,7 @@ def _access_func(key, cur_func, args, kwargs): name=name, ) - raise TypeError(f"ORTModule does not support input type {type(value)} for input {name}") + raise ORTModuleIOError(f"ORTModule does not support input type {type(value)} for input {name}") visited_input_names: List[str] = [] diff --git a/orttraining/orttraining/python/training/ortmodule/_utils.py b/orttraining/orttraining/python/training/ortmodule/_utils.py index 2b9a259895793..c81259a05d477 100644 --- a/orttraining/orttraining/python/training/ortmodule/_utils.py +++ b/orttraining/orttraining/python/training/ortmodule/_utils.py @@ -200,9 +200,12 @@ def _get_device_from_inputs(args, kwargs) -> Optional[torch.device]: device = None if args: - device = torch.device(args[0].device) + if args[0] and hasattr(args[0], "device"): + device = torch.device(args[0].device) elif kwargs: - device = torch.device(next(iter(kwargs.values())).device) + v = next(iter(kwargs.values())) + if v and hasattr(v, "device"): + device = torch.device(v.device) return device diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_fallback.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_fallback.py index 34453c89157a8..4e0fcafecffe5 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_fallback.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_fallback.py @@ -49,6 +49,7 @@ def test_ortmodule_fallback_forward(is_training, fallback_enabled, matching_poli class Point: x: int y: int + device: str = "cpu" # Otherwise, no device can be found from inputs, and the test will fail earlier. class UnsupportedInputModel(torch.nn.Module): def __init__(self): @@ -78,11 +79,17 @@ def forward(self, point): else: with pytest.raises(_fallback.ORTModuleFallbackException) as type_error: ort_model(inputs) - assert "ORTModule fails to extract schema from data" in str(type_error.value) + assert ( + "ORTModule does not support input type .Point'> for input point" + in str(type_error.value) + ) else: with pytest.raises(_fallback.ORTModuleFallbackException) as type_error: ort_model(inputs) - assert "ORTModule fails to extract schema from data" in str(type_error.value) + assert ( + "ORTModule does not support input type .Point'> for input point" + in str(type_error.value) + ) @pytest.mark.parametrize( @@ -250,11 +257,17 @@ def test_ortmodule_fallback_output(is_training, fallback_enabled, matching_polic else: with pytest.raises(_fallback.ORTModuleIOError) as runtime_error: ort_model(x, y, z) - assert "ORTModule does not support the following model output type" in str(runtime_error.value) + assert ( + "ORTModule fails to extract schema from data: Unsupported flatten data type: " + in str(runtime_error.value) + ) else: with pytest.raises(_fallback.ORTModuleIOError) as runtime_error: ort_model(x, y, z) - assert "ORTModule does not support the following model output type" in str(runtime_error.value) + assert ( + "ORTModule fails to extract schema from data: Unsupported flatten data type: " + in str(runtime_error.value) + ) @pytest.mark.parametrize( @@ -302,20 +315,18 @@ def __init__(self, x): with pytest.raises(_fallback.ORTModuleIOError) as ex_info: _ = ort_model(torch.randn(1, 2), CustomClass(1)) assert ( - "ORTModule fails to extract schema from data: " - "Unsupported flatten data type: " - ".CustomClass'>" in str(ex_info.value) + "ORTModule does not support input type " + ".CustomClass'> " + "for input custom_class_obj" in str(ex_info.value) ) else: with pytest.raises(_fallback.ORTModuleIOError) as ex_info: _ = ort_model(torch.randn(1, 2), CustomClass(1)) - assert ( - "ORTModule fails to extract schema from data: " - "Unsupported flatten data type: " - ".CustomClass'>" in str(ex_info.value) - ) + assert ( + "ORTModule does not support input type " + ".CustomClass'> " + "for input custom_class_obj" in str(ex_info.value) + ) @pytest.mark.parametrize( From a1d1afef38adab2432a3d64dc352f4a33c0b8c83 Mon Sep 17 00:00:00 2001 From: "Peng Wang(AI FWK)" Date: Tue, 27 Feb 2024 08:59:45 +0000 Subject: [PATCH 23/32] yes, another minor fix --- orttraining/orttraining/python/training/ortmodule/_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_utils.py b/orttraining/orttraining/python/training/ortmodule/_utils.py index c81259a05d477..9fbc7ed267ea4 100644 --- a/orttraining/orttraining/python/training/ortmodule/_utils.py +++ b/orttraining/orttraining/python/training/ortmodule/_utils.py @@ -200,11 +200,11 @@ def _get_device_from_inputs(args, kwargs) -> Optional[torch.device]: device = None if args: - if args[0] and hasattr(args[0], "device"): + if args[0] is not None and hasattr(args[0], "device"): device = torch.device(args[0].device) elif kwargs: v = next(iter(kwargs.values())) - if v and hasattr(v, "device"): + if v is not None and hasattr(v, "device"): device = torch.device(v.device) return device From 1958aede5a6212a5165c5c9ade3f45cbdbbf9855 Mon Sep 17 00:00:00 2001 From: "Peng Wang(AI FWK)" Date: Tue, 27 Feb 2024 11:09:26 +0000 Subject: [PATCH 24/32] fix memory efficient grad mangement --- .../ortmodule/_graph_transition_manager.py | 103 +++++++++++------- .../python/training/ortmodule/_io.py | 8 +- .../ortmodule/_mem_efficient_grad_mgmt.py | 21 ++-- 3 files changed, 75 insertions(+), 57 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py index 94c81b70d8c60..e50d018cef115 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py @@ -125,7 +125,7 @@ def __init__( onnx_graph_input_dynamic_axes_map: dict[str, dict[int, str]], module_forward_output_schema: ORTModelInputOutputSchemaType, post_export_processed_model: onnx.ModelProto, - onnx_graph_input_data_accessor: dict[str, callable], + onnx_graph_input_data_accessor_user_defined: dict[str, callable], onnx_graph_input_const_as_tensor: dict[str, torch.device], enable_mem_efficient_grad_management: bool, ): @@ -150,15 +150,6 @@ def __init__( # Input names that require gradients for the pre-gradient-build graph. self.onnx_graph_input_names_require_grad: list[str] = copy.deepcopy(onnx_graph_input_names_require_grad) - if enable_mem_efficient_grad_management: - from ._mem_efficient_grad_mgmt import MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME - - # Add mem efficient grad trigger name to require_grad_names, so that it will be included in the gradient graph. - self.onnx_graph_input_names_user_defined.append(MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME) - self.onnx_graph_input_names_require_grad_user_defined.append(MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME) - self.onnx_graph_input_names.append(MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME) - self.onnx_graph_input_names_require_grad.append(MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME) - # Create symbolic names for each dimension of the graph input (e.g. onnx_graph_input_names). # The key is the input name, the value is a dict of {dim_index: symbolic_dim_name} # e.g. {"input1": {0: "input1_dim0", 1: "input1_dim1"}, "input2": {0: "input2_dim0"}} @@ -169,7 +160,9 @@ def __init__( # A function to access the input data from the args and kwargs. # If it is not None, the length is same as onnx_graph_input_names_user_defined. # For i-th input name, we can use the i-th function to get the input data from args and kwargs. - self.onnx_graph_input_data_accessor: dict[str, callable] | None = onnx_graph_input_data_accessor + self.onnx_graph_input_data_accessor_user_defined: dict[ + str, callable + ] | None = onnx_graph_input_data_accessor_user_defined self.onnx_graph_input_const_as_tensor: dict[str, torch.device] | None = onnx_graph_input_const_as_tensor @@ -218,22 +211,18 @@ def construct_inputs( # Create the buffers for the inputs that are either parameters or buffers in the original module. # For user inputs, fill with None for now, and will be filled dynamically during the forward run. - if self._enable_mem_efficient_grad_management: - from ._mem_efficient_grad_mgmt import get_params_not_connected_to_pull_param_trigger - - parameter_names = get_params_not_connected_to_pull_param_trigger( - self._flattened_module.named_parameters(), self._post_export_processed_model - ) - else: - parameter_names = {k: v for k, v in self._flattened_module.named_parameters()} + parameter_names = {k: v for k, v in self._flattened_module.named_parameters()} buffer_names = {k: v for k, v in self._flattened_module.named_buffers()} + for input_name in self.onnx_graph_input_names: if input_name in parameter_names: self._buffer_for_ort_runs[input_name] = parameter_names[input_name] elif input_name in buffer_names: self._buffer_for_ort_runs[input_name] = buffer_names[input_name] else: - self._buffer_for_ort_runs[input_name] = None + self._buffer_for_ort_runs[ + input_name + ] = None # Fill None for user input first, will be overridden later. for name in self.onnx_graph_input_names_user_defined: if self._enable_mem_efficient_grad_management and name == MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME: @@ -244,9 +233,9 @@ def construct_inputs( ).requires_grad_() continue - if name in self.onnx_graph_input_data_accessor: + if name in self.onnx_graph_input_data_accessor_user_defined: assert name in self._buffer_for_ort_runs, f"{name} is not in buffer_for_ort_runs" - data = self.onnx_graph_input_data_accessor[name](args, kwargs) + data = self.onnx_graph_input_data_accessor_user_defined[name](args, kwargs) if name in self.onnx_graph_input_const_as_tensor: data = PrimitiveType.get_tensor(data, device) self._buffer_for_ort_runs[name] = data @@ -391,7 +380,7 @@ def get_post_processed_model( # This looks a bit duplicated with `extract_data_and_schema` function, but this might be better to # defined as a specialized logic that is the counter-part of `parse_inputs_for_onnx_export`, which handles # args and kwargs separately. - for name, data_accessor in cur_model_info_for_export.onnx_graph_input_data_accessor.items(): + for name, data_accessor in cur_model_info_for_export.onnx_graph_input_data_accessor_user_defined.items(): d = data_accessor(copied_args, copied_kwargs) if name in cur_model_info_for_export.onnx_graph_input_const_as_tensor: flatten_inputs.append( @@ -585,11 +574,13 @@ def _reprocess_check( for input_name in exported_model_info.onnx_graph_input_names: if input_name in exported_model_info.onnx_graph_input_names_user_defined: assert ( - input_name in model_info_for_export.onnx_graph_input_data_accessor - ), f"{input_name} model_info_for_export.onnx_graph_input_data_accessor" + input_name in model_info_for_export.onnx_graph_input_data_accessor_user_defined + ), f"{input_name} model_info_for_export.onnx_graph_input_data_accessor_user_defined" # We assume the data accessor should be the same as the one used for the previous export, because # there is args and kwargs schema check during export check phase. - if model_info_for_export.onnx_graph_input_data_accessor[input_name](args, kwargs).requires_grad: + if model_info_for_export.onnx_graph_input_data_accessor_user_defined[input_name]( + args, kwargs + ).requires_grad: onnx_graph_input_requires_grads.append(input_name) else: assert input_name in parameter_names, f"{input_name} not exist parameter_names" @@ -640,30 +631,62 @@ def _post_export_process( [name for name, _ in flatten_module.named_parameters()], ) + onnx_graph_input_names_user_defined = copy.deepcopy(exported_model_info.onnx_graph_input_names_user_defined) + onnx_graph_input_names_require_grad_user_defined = copy.deepcopy( + exported_model_info.onnx_graph_input_names_require_grad_user_defined + ) + onnx_graph_input_names = copy.deepcopy(exported_model_info.onnx_graph_input_names) + onnx_graph_input_names_require_grad = copy.deepcopy(exported_model_info.onnx_graph_input_names_require_grad) + if enable_mem_efficient_grad_management: - from ._mem_efficient_grad_mgmt import post_processing_enable_mem_efficient_training + # Remove those trainable parameters from graph input, as they will be retrieved from weight pull node. + from ._mem_efficient_grad_mgmt import get_params_connected_to_pull_param_trigger - # Override the options if model is not modified. - ( - enable_mem_efficient_grad_management, - post_processed_model, - ) = post_processing_enable_mem_efficient_training(post_processed_model, flatten_module.named_parameters()) + # MUST call this before post_processing_enable_mem_efficient_training, otherwise, the onnx graph input + # will be modified. + parameter_not_as_graph_input_names = get_params_connected_to_pull_param_trigger( + flatten_module.named_parameters(), post_processed_model + ) + + if len(parameter_not_as_graph_input_names) > 0: + for k in parameter_not_as_graph_input_names: + if k in onnx_graph_input_names: + onnx_graph_input_names.remove(k) + + if k in onnx_graph_input_names_require_grad: + onnx_graph_input_names_require_grad.remove(k) + + from ._mem_efficient_grad_mgmt import MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME + + # Add mem efficient grad trigger name to require_grad_names, so that it will be included in the gradient graph. + onnx_graph_input_names_user_defined.append(MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME) + onnx_graph_input_names_require_grad_user_defined.append(MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME) + onnx_graph_input_names.append(MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME) + onnx_graph_input_names_require_grad.append(MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME) - if run_symbolic_shape_infer: - post_processed_model = SymbolicShapeInference.infer_shapes( - post_processed_model, auto_merge=True, guess_output_rank=True + from ._mem_efficient_grad_mgmt import post_processing_enable_mem_efficient_training + + # Override the options if model is not modified. + + post_processed_model = post_processing_enable_mem_efficient_training( + post_processed_model, flatten_module.named_parameters(), parameter_not_as_graph_input_names ) + if run_symbolic_shape_infer: + post_processed_model = SymbolicShapeInference.infer_shapes( + post_processed_model, auto_merge=True, guess_output_rank=True + ) + post_export_processed_model_info = PostExportProcessedModelInfo( flatten_module, - exported_model_info.onnx_graph_input_names_user_defined, - exported_model_info.onnx_graph_input_names_require_grad_user_defined, - exported_model_info.onnx_graph_input_names, - exported_model_info.onnx_graph_input_names_require_grad, + onnx_graph_input_names_user_defined, + onnx_graph_input_names_require_grad_user_defined, + onnx_graph_input_names, + onnx_graph_input_names_require_grad, model_info_for_export.onnx_graph_input_dynamic_axes_map, exported_model_info.module_forward_output_schema, post_processed_model, - model_info_for_export.onnx_graph_input_data_accessor, + model_info_for_export.onnx_graph_input_data_accessor_user_defined, model_info_for_export.onnx_graph_input_const_as_tensor, enable_mem_efficient_grad_management, ) diff --git a/orttraining/orttraining/python/training/ortmodule/_io.py b/orttraining/orttraining/python/training/ortmodule/_io.py index 0053bb8186692..279a175e34498 100644 --- a/orttraining/orttraining/python/training/ortmodule/_io.py +++ b/orttraining/orttraining/python/training/ortmodule/_io.py @@ -154,7 +154,7 @@ def __init__( onnx_graph_input_names_require_grad: List[str], onnx_graph_input_dynamic_axes_map: Dict[str, Dict[int, str]], onnx_graph_input_shapes: List[List[int]], - onnx_graph_input_data_accessor: Optional[Dict[str, callable]] = None, + onnx_graph_input_data_accessor_user_defined: Optional[Dict[str, callable]] = None, onnx_graph_input_const_as_tensor: Optional[Dict[str, torch.device]] = None, onnx_graph_input_arg_schema: Optional[Dict[str, ORTModelInputOutputSchemaType]] = None, onnx_graph_input_kwarg_schema: Optional[Dict[str, ORTModelInputOutputSchemaType]] = None, @@ -202,7 +202,9 @@ def __init__( # A function to access the input data from the args and kwargs. # If it is not None, the length is same as onnx_graph_input_names. # For i-th input name, we can use the i-th function to get the input data from args and kwargs. - self.onnx_graph_input_data_accessor: Optional[Dict[str, callable]] = onnx_graph_input_data_accessor + self.onnx_graph_input_data_accessor_user_defined: Optional[Dict[str, callable]] = ( + onnx_graph_input_data_accessor_user_defined + ) self.onnx_graph_input_const_as_tensor: Optional[Dict[str, torch.device]] = onnx_graph_input_const_as_tensor @@ -491,7 +493,7 @@ def _access_func(key, cur_func, args, kwargs): onnx_graph_input_names_require_grad=input_names_require_grad, onnx_graph_input_dynamic_axes_map=dynamic_axes, onnx_graph_input_shapes=input_shape, - onnx_graph_input_data_accessor=data_accessors, + onnx_graph_input_data_accessor_user_defined=data_accessors, onnx_graph_input_const_as_tensor=const_to_tensor_inputs, onnx_graph_input_arg_schema=input_arg_schema, onnx_graph_input_kwarg_schema=input_kwarg_schema, diff --git a/orttraining/orttraining/python/training/ortmodule/_mem_efficient_grad_mgmt.py b/orttraining/orttraining/python/training/ortmodule/_mem_efficient_grad_mgmt.py index 61e226307bef2..7153447c64ada 100644 --- a/orttraining/orttraining/python/training/ortmodule/_mem_efficient_grad_mgmt.py +++ b/orttraining/orttraining/python/training/ortmodule/_mem_efficient_grad_mgmt.py @@ -38,20 +38,13 @@ def get_params_not_connected_to_pull_param_trigger( def post_processing_enable_mem_efficient_training( exported_model: ModelProto, named_params: dict[str, torch.nn.parameter.Parameter], -) -> tuple[bool, ModelProto]: - """This function is used to enable zero stage3 compatibility. - - Args: - exported_model (ModelProto): The exported model. - named_params (Optional[Dict[str, torch.nn.parameter.Parameter]]): The full parameter map. - - Returns: - tuple[bool, ModelProto]: A tuple of bool and ModelProto. The bool indicates whether the model is modified. - - """ - trainable_named_params = get_params_connected_to_pull_param_trigger(named_params, exported_model) + trainable_named_params: dict[str, torch.nn.parameter.Parameter], +) -> ModelProto: + """This function is used to enable memory efficient gradient management.""" + # trainable_named_params = get_params_connected_to_pull_param_trigger(named_params, exported_model) if len(trainable_named_params) == 0: - return False, exported_model + return exported_model + # return False, exported_model # Create weight retrieving function using trainable_named_params. param_pull_trigger_func_class = _create_param_trigger_function(trainable_named_params) @@ -144,7 +137,7 @@ def _get_param_pull_trigger_name(param_name: str) -> str: exported_model.graph.input.insert(offset, inputs[0]) exported_model.graph.node.insert(0, weight_pull_node) - return True, exported_model + return exported_model _PARAM_FUNCTION_INDEX = [0] From 49cf04148d6370ec9aa4e3015441a77a4fca8b84 Mon Sep 17 00:00:00 2001 From: "Peng Wang(AI FWK)" Date: Tue, 27 Feb 2024 11:17:20 +0000 Subject: [PATCH 25/32] minor --- orttraining/orttraining/python/training/ortmodule/_io.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_io.py b/orttraining/orttraining/python/training/ortmodule/_io.py index 279a175e34498..b6e1406524113 100644 --- a/orttraining/orttraining/python/training/ortmodule/_io.py +++ b/orttraining/orttraining/python/training/ortmodule/_io.py @@ -202,9 +202,9 @@ def __init__( # A function to access the input data from the args and kwargs. # If it is not None, the length is same as onnx_graph_input_names. # For i-th input name, we can use the i-th function to get the input data from args and kwargs. - self.onnx_graph_input_data_accessor_user_defined: Optional[Dict[str, callable]] = ( - onnx_graph_input_data_accessor_user_defined - ) + self.onnx_graph_input_data_accessor_user_defined: Optional[ + Dict[str, callable] + ] = onnx_graph_input_data_accessor_user_defined self.onnx_graph_input_const_as_tensor: Optional[Dict[str, torch.device]] = onnx_graph_input_const_as_tensor From c4ebb6e9e8f7c29b93da854ded74ec605ad214bd Mon Sep 17 00:00:00 2001 From: "Peng Wang(AI FWK)" Date: Wed, 12 Jun 2024 11:16:43 +0000 Subject: [PATCH 26/32] lint --- .../ortmodule/_graph_execution_manager.py | 13 +--- .../ortmodule/_graph_transition_manager.py | 64 +++++++++++++++++-- .../python/training/ortmodule/_io.py | 8 +-- .../python/training/ortmodule/_logger.py | 8 ++- 4 files changed, 68 insertions(+), 25 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 8be1923eef9bc..f89fd48b06289 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -7,10 +7,9 @@ import logging import os from abc import ABC, abstractmethod # noqa: F401 -from typing import Dict, List, Optional, OrderedDict, Tuple from functools import partial from hashlib import md5 as hash_fn -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, OrderedDict, Tuple import onnx import torch @@ -18,12 +17,6 @@ import onnxruntime from onnxruntime.capi import _pybind_state as C - - - - -from ._graph_transition_manager import GraphTransitionManager, PostExportProcessedModelInfo - from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference from onnxruntime.training.utils import ORTModelInputOutputSchemaType, PTable, onnx_dtype_to_pytorch_dtype @@ -38,6 +31,7 @@ ) from ._gradient_accumulation_manager import GradientAccumulationManager from ._graph_execution_interface import GraphExecutionInterface +from ._graph_transition_manager import GraphTransitionManager, PostExportProcessedModelInfo from ._io import _FlattenedModule, _InputInfo from ._logger import LogColor from ._runtime_inspector import FlagAndPrintDensity, RuntimeInspector @@ -474,7 +468,6 @@ def _detect_from_inputs(self, inputs: Tuple, kwargs: Dict): if self._runtime_options.enable_zero_stage3_support or self._mem_efficient_grad_management_is_enabled: self._append_pull_weight_trigger_as_input(kwargs, detected_device) - if ( self._runtime_inspector.memory_ob.is_enabled() and not self._runtime_inspector.memory_ob.symbolic_dim_collecting_completed @@ -489,8 +482,6 @@ def _detect_from_inputs(self, inputs: Tuple, kwargs: Dict): if self._mem_efficient_grad_management_is_enabled: from ._mem_efficient_grad_mgmt import get_params_not_connected_to_pull_param_trigger - - if self._runtime_inspector._sceloss_module_to_ignore_density_map: self._runtime_options.label_sparsity_ratio = ",".join( [f"{k}:{v:.0f}%" for k, v in self._runtime_inspector._sceloss_module_to_ignore_density_map.items()] diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py index e50d018cef115..9b7bfa0582567 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py @@ -160,9 +160,9 @@ def __init__( # A function to access the input data from the args and kwargs. # If it is not None, the length is same as onnx_graph_input_names_user_defined. # For i-th input name, we can use the i-th function to get the input data from args and kwargs. - self.onnx_graph_input_data_accessor_user_defined: dict[ - str, callable - ] | None = onnx_graph_input_data_accessor_user_defined + self.onnx_graph_input_data_accessor_user_defined: dict[str, callable] | None = ( + onnx_graph_input_data_accessor_user_defined + ) self.onnx_graph_input_const_as_tensor: dict[str, torch.device] | None = onnx_graph_input_const_as_tensor @@ -220,9 +220,9 @@ def construct_inputs( elif input_name in buffer_names: self._buffer_for_ort_runs[input_name] = buffer_names[input_name] else: - self._buffer_for_ort_runs[ - input_name - ] = None # Fill None for user input first, will be overridden later. + self._buffer_for_ort_runs[input_name] = ( + None # Fill None for user input first, will be overridden later. + ) for name in self.onnx_graph_input_names_user_defined: if self._enable_mem_efficient_grad_management and name == MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME: @@ -723,6 +723,11 @@ def _export_model( time_tracker: TimeTracker, logger: logging.Logger, ) -> tuple[onnx.ModelProto, ORTModelInputOutputSchemaType, list[str], list[str]]: + + # Add hooks to check the sparsity of the embedding and label inputs during the export. + embedding_hook_handles = self._add_check_embedding_sparsity_hook() + label_hook_handles = self._add_check_label_sparsity_hook() + # Record random states here and restore later in case any of them gets changed during the export, # e.g., some sympy functions in symbolic_shape_infer will change Python's random state. random_states = _utils.get_random_states() @@ -757,6 +762,13 @@ def _export_model( # Restore the recorded random states _utils.set_random_states(random_states) + # Clean up all hooks. + for hook in embedding_hook_handles: + hook.remove() + + for hook in label_hook_handles: + hook.remove() + return exported_model, module_output_schema, onnx_graph_input_names, onnx_graph_input_names_require_grad @staticmethod @@ -873,6 +885,46 @@ def _get_exported_model( **model_info_for_export.export_extra_kwargs, ) except Exception as e: + message = _utils.get_exception_as_string(e) + + # Special handling when Huggingface transformers gradient checkpoint usage pattern found. + # For new versions of PyTorch 2, tracing torch.utils.checkpoint.checkpoint will be failed like this: + # File "microsoft/phi-2/b10c3eba545ad279e7208ee3a5d644566f001670/modeling_phi.py", line 919, in forward + # layer_outputs = self._gradient_checkpointing_func( + # File "/site-packages/torch/_compile.py", line 24, in inner + # return torch._dynamo.disable(fn, recursive)(*args, **kwargs) + # File "/site-packages/torch/_dynamo/eval_frame.py", line 470, in _fn + # raise RuntimeError( + # RuntimeError: Detected that you are using FX to torch.jit.trace a dynamo-optimized function. This is not supported at the moment. + if ( + "_gradient_checkpointing_func" in message + and "Detected that you are using FX to torch.jit.trace a dynamo-optimized function" in message + ): + is_ckpt_activation_allowed = int(os.getenv("ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT", "0")) == 1 + notes = ( + " Your model is running with gradient checkpointing, yet the PyTorch exporter\n" + " failed during tracing the graph. Try to enable ORTModule's\n" + " gradient checkpointing (a.k.a. Transformer layerwise subgraph recompute)\n" + " using `export ORTMODULE_MEMORY_OPT_LEVEL=1` for similar or even better memory efficiency.\n" + ) + if is_ckpt_activation_allowed: + # If the user allows the gradient checkpointing export, we should inform the user to disable it, + # to make layerwise recompute work. + notes += ( + " We also notice your setting `export ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT=1`,\n" + " which enables gradient checkpointing torch.autograd.Functions(s) to export.\n" + " To enable ORTModule's layerwise recompute, it needs to be turned OFF by\n" + " `export ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT=0`.\n" + ) + + self._logger.error( + f"{LogColor.RED}\n" + "******************************** IMPORTANT NOTE *******************************\n" + f"{notes}" + "*******************************************************************************\n" + f"{LogColor.ENDC}\n" + ) + raise wrap_exception( # noqa: B904 ORTModuleONNXModelException, RuntimeError( diff --git a/orttraining/orttraining/python/training/ortmodule/_io.py b/orttraining/orttraining/python/training/ortmodule/_io.py index b6e1406524113..8ad3d0df3e4fa 100644 --- a/orttraining/orttraining/python/training/ortmodule/_io.py +++ b/orttraining/orttraining/python/training/ortmodule/_io.py @@ -202,9 +202,9 @@ def __init__( # A function to access the input data from the args and kwargs. # If it is not None, the length is same as onnx_graph_input_names. # For i-th input name, we can use the i-th function to get the input data from args and kwargs. - self.onnx_graph_input_data_accessor_user_defined: Optional[ - Dict[str, callable] - ] = onnx_graph_input_data_accessor_user_defined + self.onnx_graph_input_data_accessor_user_defined: Optional[Dict[str, callable]] = ( + onnx_graph_input_data_accessor_user_defined + ) self.onnx_graph_input_const_as_tensor: Optional[Dict[str, torch.device]] = onnx_graph_input_const_as_tensor @@ -235,8 +235,6 @@ def _kwarg_access_with_name_func(name, args, kwargs): class SkipRetValue: """A placeholder class to indicate that the return value of a function should be skipped""" - pass - def parse_inputs_for_onnx_export( all_input_parameters: List[inspect.Parameter], diff --git a/orttraining/orttraining/python/training/ortmodule/_logger.py b/orttraining/orttraining/python/training/ortmodule/_logger.py index 6cfd947117d5d..4d54e8e59fb50 100644 --- a/orttraining/orttraining/python/training/ortmodule/_logger.py +++ b/orttraining/orttraining/python/training/ortmodule/_logger.py @@ -285,9 +285,11 @@ def wrapper(*args, **kwargs): on_exit=partial( _log_with_filter, kwargs["logger"], - kwargs["debug_options"].onnxruntime_log_filter - if self.is_ort_filter - else kwargs["debug_options"].torch_exporter_filter, + ( + kwargs["debug_options"].onnxruntime_log_filter + if self.is_ort_filter + else kwargs["debug_options"].torch_exporter_filter + ), self.phase.to_string(), ), ): From a52690573c9559163b23dac2049636a64bd43806 Mon Sep 17 00:00:00 2001 From: "Peng Wang(AI FWK)" Date: Wed, 12 Jun 2024 11:52:44 +0000 Subject: [PATCH 27/32] fixes --- .../ortmodule/_graph_execution_manager.py | 90 ------------------- .../ortmodule/_graph_transition_manager.py | 79 +++++++++++++++- 2 files changed, 78 insertions(+), 91 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index f89fd48b06289..8674f47566488 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -57,7 +57,6 @@ def __init__( debug_options: DebugOptions, export_mode: int, fallback_manager: _FallbackManager, - export_mode: int, logger: logging.Logger, ): """Manages construction and execution of ONNX graphs. @@ -86,7 +85,6 @@ def __init__( self._flattened_module = module self._onnx_models = _onnx_models.ONNXModels() - self._export_mode = export_mode self._graph_transition_manager: Optional[GraphTransitionManager] = None # Model after inference optimization and then gradient building. @@ -105,17 +103,6 @@ def __init__( # To be instantiated in the concrete implementation of GraphExecutionManager self._export_mode = export_mode - # Exporter can take extra arguments for ORTModule extensions - # It cannot overlap with required/immutable arguments (validated in runtime) - self._export_extra_kwargs = {} - - # Input and output infos (including schema) for exported model. - self._input_info: Optional[_InputInfo] = None - self._module_output_schema: Optional[ORTModelInputOutputSchemaType] = None - - # Device where the model is placed. - self._device: Optional[torch.device] = _utils.get_device_from_module(module) - # Forward function input parameters of the original module. self._module_parameters: List[inspect.Parameter] = list( inspect.signature(self._original_module.forward).parameters.values() @@ -372,83 +359,6 @@ def _device(self): # Graph transition manager is responsible for detecting and managing the device to use. return self._graph_transition_manager._device - def _add_check_embedding_sparsity_hook(self): - """ - Add hook to check embedding sparsity and enable padding elimination if applicable. - 1. Iterate through all modules to find Embedding modules with padding_idx >= 0. - 2. Register forward pre hook to the Embedding module and the hook will check sparsity of the embedding input. - 3. If the sparsity is below a threshold, enable padding elimination by adding FlagAndPrintDensity after the - output. GraphTransformer of PaddingElimination will check the FlagAndPrintDensity and do the actual - padding elimination graph modification. - 4. Return the hook handles for later removal. - - """ - if not self._runtime_options.enable_embedding_sparse_optimizer or self._device.type != "cuda": - return [] - - def _embedding_hook(name, module, args): - ebd_input = args[0] - if ebd_input is None or not isinstance(ebd_input, torch.Tensor): - self._logger.warning("Embedding input is not a tensor.") - return None - - valid_token = torch.count_nonzero(ebd_input - module.padding_idx) - total_token = ebd_input.numel() - embed_density = float(valid_token) / float(total_token) * 100 - - if embed_density < 90: - self._logger.info("Embedding sparsity-based optimization is ON for density: %.0f%%", embed_density) - self._runtime_inspector._embedding_module_to_padding_density_map[name] = embed_density - return FlagAndPrintDensity.apply(args[0], module.padding_idx, "embedding") - else: - self._logger.info("Embedding sparsity-based optimization is OFF for density: %.0f%%", embed_density) - return None - - embedding_hook_handles = [] - for name, sub_module in self._flattened_module.named_modules(): - if isinstance(sub_module, torch.nn.modules.sparse.Embedding): - if sub_module.padding_idx is not None and sub_module.padding_idx >= 0: - embedding_hook_handles.append(sub_module.register_forward_pre_hook(partial(_embedding_hook, name))) - - return embedding_hook_handles - - def _add_check_label_sparsity_hook(self): - """ - Add hook to check label sparsity and enable sceloss compute optimization if applicable. - 1. Register forward pre hook to the sceloss module in the model and the hook will check sparsity of the label input. - 2. If the sparsity is below a threshold, enable sceloss compute optimization by adding FlagAndPrintDensity after the - output. GraphTransformer of InsertGatherBeforeSceLoss will check the FlagAndPrintDensity and do the actual - sceloss compute optimization graph modification. - - """ - if not self._runtime_options.enable_label_sparse_optimizer: - return None - - def _label_hook(name, module, args): - label_input = args[1] - if label_input is None or not isinstance(label_input, torch.Tensor): - self._logger.warning("Label input is not a tensor.") - return None - - valid_token = torch.count_nonzero(label_input - module.ignore_index) - total_token = label_input.numel() - label_density = float(valid_token) / float(total_token) * 100 - - if label_density < 90: - self._logger.info("Label sparsity-based optimization is ON for density: %.0f%%", label_density) - self._runtime_inspector._sceloss_module_to_ignore_density_map[name] = label_density - return (args[0], FlagAndPrintDensity.apply(args[1], module.ignore_index, "label")) - else: - self._logger.info("Label sparsity-based optimization is OFF for density: %.0f%%", label_density) - return None - - label_check_hook_handles = [] - for name, sub_module in self._flattened_module.named_modules(): - if isinstance(sub_module, torch.nn.modules.loss.CrossEntropyLoss): - label_check_hook_handles.append(sub_module.register_forward_pre_hook(partial(_label_hook, name))) - - return label_check_hook_handles - @_logger.TrackTime(_logger.ORTModuleInitPhase.DETECTION) def _detect_from_inputs(self, inputs: Tuple, kwargs: Dict): """ diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py index 9b7bfa0582567..ccee108aa875f 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py @@ -28,7 +28,7 @@ from . import _io, _utils from ._fallback import ORTModuleDeviceException, ORTModuleIOError, ORTModuleONNXModelException, wrap_exception -from ._logger import LogLevel, ORTModuleInitPhase, SuppressLogs, TimeTracker, TrackTimeForStaticFunction +from ._logger import LogColor, LogLevel, ORTModuleInitPhase, SuppressLogs, TimeTracker, TrackTimeForStaticFunction from ._onnx_models import _get_onnx_file_name, _save_model from ._utils import check_function_has_param, get_rank from ._zero_stage3_compatibility import stage3_export_context @@ -949,3 +949,80 @@ def _get_exported_model( def signal_model_changed(self): """Signals the execution manager to re-export the model on the next forward call""" self._original_model_has_changed = True + + def _add_check_embedding_sparsity_hook(self): + """ + Add hook to check embedding sparsity and enable padding elimination if applicable. + 1. Iterate through all modules to find Embedding modules with padding_idx >= 0. + 2. Register forward pre hook to the Embedding module and the hook will check sparsity of the embedding input. + 3. If the sparsity is below a threshold, enable padding elimination by adding FlagAndPrintDensity after the + output. GraphTransformer of PaddingElimination will check the FlagAndPrintDensity and do the actual + padding elimination graph modification. + 4. Return the hook handles for later removal. + + """ + if not self._runtime_options.enable_embedding_sparse_optimizer or self._device.type != "cuda": + return [] + + def _embedding_hook(name, module, args): + ebd_input = args[0] + if ebd_input is None or not isinstance(ebd_input, torch.Tensor): + self._logger.warning("Embedding input is not a tensor.") + return None + + valid_token = torch.count_nonzero(ebd_input - module.padding_idx) + total_token = ebd_input.numel() + embed_density = float(valid_token) / float(total_token) * 100 + + if embed_density < 90: + self._logger.info("Embedding sparsity-based optimization is ON for density: %.0f%%", embed_density) + self._runtime_inspector._embedding_module_to_padding_density_map[name] = embed_density + return FlagAndPrintDensity.apply(args[0], module.padding_idx, "embedding") + else: + self._logger.info("Embedding sparsity-based optimization is OFF for density: %.0f%%", embed_density) + return None + + embedding_hook_handles = [] + for name, sub_module in self._flattened_module.named_modules(): + if isinstance(sub_module, torch.nn.modules.sparse.Embedding): + if sub_module.padding_idx is not None and sub_module.padding_idx >= 0: + embedding_hook_handles.append(sub_module.register_forward_pre_hook(partial(_embedding_hook, name))) + + return embedding_hook_handles + + def _add_check_label_sparsity_hook(self): + """ + Add hook to check label sparsity and enable sceloss compute optimization if applicable. + 1. Register forward pre hook to the sceloss module in the model and the hook will check sparsity of the label input. + 2. If the sparsity is below a threshold, enable sceloss compute optimization by adding FlagAndPrintDensity after the + output. GraphTransformer of InsertGatherBeforeSceLoss will check the FlagAndPrintDensity and do the actual + sceloss compute optimization graph modification. + + """ + if not self._runtime_options.enable_label_sparse_optimizer: + return None + + def _label_hook(name, module, args): + label_input = args[1] + if label_input is None or not isinstance(label_input, torch.Tensor): + self._logger.warning("Label input is not a tensor.") + return None + + valid_token = torch.count_nonzero(label_input - module.ignore_index) + total_token = label_input.numel() + label_density = float(valid_token) / float(total_token) * 100 + + if label_density < 90: + self._logger.info("Label sparsity-based optimization is ON for density: %.0f%%", label_density) + self._runtime_inspector._sceloss_module_to_ignore_density_map[name] = label_density + return (args[0], FlagAndPrintDensity.apply(args[1], module.ignore_index, "label")) + else: + self._logger.info("Label sparsity-based optimization is OFF for density: %.0f%%", label_density) + return None + + label_check_hook_handles = [] + for name, sub_module in self._flattened_module.named_modules(): + if isinstance(sub_module, torch.nn.modules.loss.CrossEntropyLoss): + label_check_hook_handles.append(sub_module.register_forward_pre_hook(partial(_label_hook, name))) + + return label_check_hook_handles From cc3871a65490fc7530473360a0514493efb66a8f Mon Sep 17 00:00:00 2001 From: "Peng Wang(AI FWK)" Date: Thu, 13 Jun 2024 04:25:41 +0000 Subject: [PATCH 28/32] fix lints --- .../ortmodule/_graph_execution_manager.py | 45 ++++---------- .../ortmodule/_graph_transition_manager.py | 60 +++++++++++++------ .../ortmodule/_mem_efficient_grad_mgmt.py | 8 --- 3 files changed, 55 insertions(+), 58 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 8674f47566488..bc3981ff67010 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -7,9 +7,7 @@ import logging import os from abc import ABC, abstractmethod # noqa: F401 -from functools import partial -from hashlib import md5 as hash_fn -from typing import Dict, List, Optional, OrderedDict, Tuple +from typing import Dict, List, Optional, Tuple import onnx import torch @@ -17,25 +15,16 @@ import onnxruntime from onnxruntime.capi import _pybind_state as C -from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference -from onnxruntime.training.utils import ORTModelInputOutputSchemaType, PTable, onnx_dtype_to_pytorch_dtype - -from . import _are_deterministic_algorithms_enabled, _io, _logger, _onnx_models, _utils, export_context -from ._fallback import ( - ORTModuleDeviceException, - ORTModuleONNXModelException, - ORTModuleTorchModelException, - _FallbackManager, - _FallbackPolicy, - wrap_exception, -) +from onnxruntime.training.utils import PTable, onnx_dtype_to_pytorch_dtype + +from . import _are_deterministic_algorithms_enabled, _logger, _onnx_models, _utils +from ._fallback import ORTModuleTorchModelException, _FallbackManager, _FallbackPolicy, wrap_exception from ._gradient_accumulation_manager import GradientAccumulationManager from ._graph_execution_interface import GraphExecutionInterface from ._graph_transition_manager import GraphTransitionManager, PostExportProcessedModelInfo -from ._io import _FlattenedModule, _InputInfo -from ._logger import LogColor -from ._runtime_inspector import FlagAndPrintDensity, RuntimeInspector -from ._utils import check_function_has_param, get_rank +from ._io import _FlattenedModule +from ._runtime_inspector import RuntimeInspector +from ._utils import get_rank from .options import DebugOptions, LogLevel, _MemoryOptimizationLevel, _RuntimeOptions from .torch_cpp_extensions.cpu.aten_op_executor import load_aten_op_executor_cpp_extension @@ -103,16 +92,6 @@ def __init__( # To be instantiated in the concrete implementation of GraphExecutionManager self._export_mode = export_mode - # Forward function input parameters of the original module. - self._module_parameters: List[inspect.Parameter] = list( - inspect.signature(self._original_module.forward).parameters.values() - ) - - # TODO: remove after PyTorch ONNX exporter supports VAR_KEYWORD parameters. - for input_parameter in self._module_parameters: - if input_parameter.kind == inspect.Parameter.VAR_KEYWORD: - self._logger.info("The model's forward method has **kwargs parameter which has EXPERIMENTAL support!") - self.is_rocm_pytorch = bool(torch.version.hip is not None and ROCM_HOME is not None) # WIP feature to enable caching in Gradient accumulation scenario. @@ -163,6 +142,7 @@ def _initialize_graph_transition_manager(self): debug_options=self._debug_options, runtime_options=self._runtime_options, time_tracker=self.time_tracker, + runtime_inspector=self._runtime_inspector, logger=self._logger, ) @@ -382,16 +362,15 @@ def _detect_from_inputs(self, inputs: Tuple, kwargs: Dict): self._runtime_inspector.memory_ob.is_enabled() and not self._runtime_inspector.memory_ob.symbolic_dim_collecting_completed ): + prepared_input_map = self._graph_transition_manager._post_export_processed_model_info.construct_inputs( + inputs, kwargs, True, self._device + ) self._runtime_inspector.memory_ob.collect_symbolic_dim_values( self._graph_transition_manager._post_export_processed_model_info.onnx_graph_input_dynamic_axes_map, prepared_input_map, ) self._runtime_inspector.memory_ob.symbolic_dim_collecting_completed = True - param_to_append_as_onnx_graph_inputs = [] - if self._mem_efficient_grad_management_is_enabled: - from ._mem_efficient_grad_mgmt import get_params_not_connected_to_pull_param_trigger - if self._runtime_inspector._sceloss_module_to_ignore_density_map: self._runtime_options.label_sparsity_ratio = ",".join( [f"{k}:{v:.0f}%" for k, v in self._runtime_inspector._sceloss_module_to_ignore_density_map.items()] diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py index ccee108aa875f..a1aec1307c67e 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py @@ -11,6 +11,7 @@ import logging import os from collections import OrderedDict +from functools import partial from hashlib import md5 as hash_fn from typing import Mapping, Sequence @@ -30,6 +31,7 @@ from ._fallback import ORTModuleDeviceException, ORTModuleIOError, ORTModuleONNXModelException, wrap_exception from ._logger import LogColor, LogLevel, ORTModuleInitPhase, SuppressLogs, TimeTracker, TrackTimeForStaticFunction from ._onnx_models import _get_onnx_file_name, _save_model +from ._runtime_inspector import FlagAndPrintDensity, RuntimeInspector from ._utils import check_function_has_param, get_rank from ._zero_stage3_compatibility import stage3_export_context from .options import DebugOptions, _RuntimeOptions @@ -269,6 +271,7 @@ def __init__( debug_options: DebugOptions, runtime_options: _RuntimeOptions, time_tracker: TimeTracker, + runtime_inspector: RuntimeInspector, logger: logging.Logger, ): self._device = _utils._get_device_from_module(flatten_module) @@ -284,6 +287,8 @@ def __init__( # Tracker for ORTModule model export. self._time_tracker = time_tracker + self._runtime_inspector = runtime_inspector + # A signal to indicate if the original model has changed and need a re-export. self._original_model_has_changed = False @@ -416,10 +421,12 @@ def get_post_processed_model( ortmodule_cache_dir=self._runtime_options.ortmodule_cache_dir, enable_custom_autograd_function=self._runtime_options.enable_custom_autograd_function, enable_zero_stage3_support=self._runtime_options.enable_zero_stage3_support, + enable_embedding_sparse_optimizer=self._runtime_options.enable_embedding_sparse_optimizer, onnx_opset_version=self._runtime_options.onnx_opset_version, stage3_param_handle=self, debug_options=self._debug_options, time_tracker=self._time_tracker, + runtime_inspector=self._runtime_inspector, logger=self._logger, ) @@ -717,16 +724,22 @@ def _export_model( ortmodule_cache_dir: str, enable_custom_autograd_function: bool, enable_zero_stage3_support: bool, + enable_embedding_sparse_optimizer: bool, onnx_opset_version: int, stage3_param_handle: type, debug_options: DebugOptions, time_tracker: TimeTracker, + runtime_inspector: RuntimeInspector, logger: logging.Logger, ) -> tuple[onnx.ModelProto, ORTModelInputOutputSchemaType, list[str], list[str]]: # Add hooks to check the sparsity of the embedding and label inputs during the export. - embedding_hook_handles = self._add_check_embedding_sparsity_hook() - label_hook_handles = self._add_check_label_sparsity_hook() + embedding_hook_handles = GraphTransitionManager._add_check_embedding_sparsity_hook( + enable_embedding_sparse_optimizer, device, logger, runtime_inspector, flattened_module + ) + label_hook_handles = GraphTransitionManager._add_check_label_sparsity_hook( + enable_embedding_sparse_optimizer, logger, runtime_inspector + ) # Record random states here and restore later in case any of them gets changed during the export, # e.g., some sympy functions in symbolic_shape_infer will change Python's random state. @@ -917,7 +930,7 @@ def _get_exported_model( " `export ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT=0`.\n" ) - self._logger.error( + logger.error( f"{LogColor.RED}\n" "******************************** IMPORTANT NOTE *******************************\n" f"{notes}" @@ -950,7 +963,14 @@ def signal_model_changed(self): """Signals the execution manager to re-export the model on the next forward call""" self._original_model_has_changed = True - def _add_check_embedding_sparsity_hook(self): + @staticmethod + def _add_check_embedding_sparsity_hook( + enable_embedding_sparse_optimizer: bool, + device: torch.device, + logger: logging.Logger, + runtime_inspector: RuntimeInspector, + flattened_module: torch.nn.Module, + ) -> list: """ Add hook to check embedding sparsity and enable padding elimination if applicable. 1. Iterate through all modules to find Embedding modules with padding_idx >= 0. @@ -961,13 +981,13 @@ def _add_check_embedding_sparsity_hook(self): 4. Return the hook handles for later removal. """ - if not self._runtime_options.enable_embedding_sparse_optimizer or self._device.type != "cuda": + if not enable_embedding_sparse_optimizer or device.type != "cuda": return [] def _embedding_hook(name, module, args): ebd_input = args[0] if ebd_input is None or not isinstance(ebd_input, torch.Tensor): - self._logger.warning("Embedding input is not a tensor.") + logger.warning("Embedding input is not a tensor.") return None valid_token = torch.count_nonzero(ebd_input - module.padding_idx) @@ -975,22 +995,28 @@ def _embedding_hook(name, module, args): embed_density = float(valid_token) / float(total_token) * 100 if embed_density < 90: - self._logger.info("Embedding sparsity-based optimization is ON for density: %.0f%%", embed_density) - self._runtime_inspector._embedding_module_to_padding_density_map[name] = embed_density + logger.info("Embedding sparsity-based optimization is ON for density: %.0f%%", embed_density) + runtime_inspector._embedding_module_to_padding_density_map[name] = embed_density return FlagAndPrintDensity.apply(args[0], module.padding_idx, "embedding") else: - self._logger.info("Embedding sparsity-based optimization is OFF for density: %.0f%%", embed_density) + logger.info("Embedding sparsity-based optimization is OFF for density: %.0f%%", embed_density) return None embedding_hook_handles = [] - for name, sub_module in self._flattened_module.named_modules(): + for name, sub_module in flattened_module.named_modules(): if isinstance(sub_module, torch.nn.modules.sparse.Embedding): if sub_module.padding_idx is not None and sub_module.padding_idx >= 0: embedding_hook_handles.append(sub_module.register_forward_pre_hook(partial(_embedding_hook, name))) return embedding_hook_handles - def _add_check_label_sparsity_hook(self): + @staticmethod + def _add_check_label_sparsity_hook( + enable_label_sparse_optimizer: bool, + logger: logging.Logger, + runtime_inspector: RuntimeInspector, + flattened_module: torch.nn.Module, + ) -> list: """ Add hook to check label sparsity and enable sceloss compute optimization if applicable. 1. Register forward pre hook to the sceloss module in the model and the hook will check sparsity of the label input. @@ -999,13 +1025,13 @@ def _add_check_label_sparsity_hook(self): sceloss compute optimization graph modification. """ - if not self._runtime_options.enable_label_sparse_optimizer: + if not enable_label_sparse_optimizer: return None def _label_hook(name, module, args): label_input = args[1] if label_input is None or not isinstance(label_input, torch.Tensor): - self._logger.warning("Label input is not a tensor.") + logger.warning("Label input is not a tensor.") return None valid_token = torch.count_nonzero(label_input - module.ignore_index) @@ -1013,15 +1039,15 @@ def _label_hook(name, module, args): label_density = float(valid_token) / float(total_token) * 100 if label_density < 90: - self._logger.info("Label sparsity-based optimization is ON for density: %.0f%%", label_density) - self._runtime_inspector._sceloss_module_to_ignore_density_map[name] = label_density + logger.info("Label sparsity-based optimization is ON for density: %.0f%%", label_density) + runtime_inspector._sceloss_module_to_ignore_density_map[name] = label_density return (args[0], FlagAndPrintDensity.apply(args[1], module.ignore_index, "label")) else: - self._logger.info("Label sparsity-based optimization is OFF for density: %.0f%%", label_density) + logger.info("Label sparsity-based optimization is OFF for density: %.0f%%", label_density) return None label_check_hook_handles = [] - for name, sub_module in self._flattened_module.named_modules(): + for name, sub_module in flattened_module.named_modules(): if isinstance(sub_module, torch.nn.modules.loss.CrossEntropyLoss): label_check_hook_handles.append(sub_module.register_forward_pre_hook(partial(_label_hook, name))) diff --git a/orttraining/orttraining/python/training/ortmodule/_mem_efficient_grad_mgmt.py b/orttraining/orttraining/python/training/ortmodule/_mem_efficient_grad_mgmt.py index 8f4f6204bcfff..93d151ea1217d 100644 --- a/orttraining/orttraining/python/training/ortmodule/_mem_efficient_grad_mgmt.py +++ b/orttraining/orttraining/python/training/ortmodule/_mem_efficient_grad_mgmt.py @@ -26,14 +26,6 @@ def get_params_connected_to_pull_param_trigger( return {k: v for k, v in named_params if v.requires_grad and k in onnx_initializer_names} -def get_params_not_connected_to_pull_param_trigger( - named_params: dict[str, torch.nn.parameter.Parameter], exported_model: ModelProto -): - # Be noted, some parameters might not in graph input because they are not used in forward, so we filtered them also. - onnx_initializer_names = {p.name for p in exported_model.graph.input} - return {k: v for k, v in named_params if not v.requires_grad and k in onnx_initializer_names} - - def post_processing_enable_mem_efficient_training( exported_model: ModelProto, named_params: dict[str, torch.nn.parameter.Parameter], From 523e63e79cf240590a1c1b09e37e6ff6068b10d0 Mon Sep 17 00:00:00 2001 From: "Peng Wang(AI FWK)" Date: Thu, 13 Jun 2024 04:29:43 +0000 Subject: [PATCH 29/32] minor --- .../python/training/ortmodule/_graph_transition_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py index a1aec1307c67e..5be0d38513029 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py @@ -738,7 +738,7 @@ def _export_model( enable_embedding_sparse_optimizer, device, logger, runtime_inspector, flattened_module ) label_hook_handles = GraphTransitionManager._add_check_label_sparsity_hook( - enable_embedding_sparse_optimizer, logger, runtime_inspector + enable_embedding_sparse_optimizer, logger, runtime_inspector, flattened_module ) # Record random states here and restore later in case any of them gets changed during the export, From 4737bd4d0c26d4617948867a97d41d0aaf778a49 Mon Sep 17 00:00:00 2001 From: "Peng Wang(AI FWK)" Date: Thu, 13 Jun 2024 04:38:51 +0000 Subject: [PATCH 30/32] fix --- .../training/ortmodule/_graph_execution_manager.py | 10 +++++----- .../training/ortmodule/_graph_transition_manager.py | 6 +++++- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index bc3981ff67010..91b6e663d010b 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -125,6 +125,10 @@ def __init__( self._initialize_graph_transition_manager() + # Will be reset everytime we re-initialize the graph builder. + # Be noted, we will never enable this feature for inference mode. + self._mem_efficient_grad_management_is_enabled = False + def _get_torch_gpu_allocator_function_addresses(self): if self._runtime_options.use_external_gpu_allocator and torch.cuda.is_available(): # CPP extension to get torch GPU allocator's alloc and free function addresses @@ -351,12 +355,8 @@ def _detect_from_inputs(self, inputs: Tuple, kwargs: Dict): enable sparsity-based optimization. """ - detected_device = _utils.get_device_from_module(self._original_module) or _utils.get_device_from_inputs( - inputs, kwargs - ) - if self._runtime_options.enable_zero_stage3_support or self._mem_efficient_grad_management_is_enabled: - self._append_pull_weight_trigger_as_input(kwargs, detected_device) + self._append_pull_weight_trigger_as_input(kwargs, self._device) if ( self._runtime_inspector.memory_ob.is_enabled() diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py index 5be0d38513029..37b8a441134d9 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py @@ -675,7 +675,11 @@ def _post_export_process( # Override the options if model is not modified. - post_processed_model = post_processing_enable_mem_efficient_training( + ( + stage3_param_handle._mem_efficient_grad_management_is_enabled, + post_processed_model, + stage3_param_handle._param_trigger_grad, + ) = post_processing_enable_mem_efficient_training( post_processed_model, flatten_module.named_parameters(), parameter_not_as_graph_input_names ) From fd2c95a211059a5fac8e69950a6b6fd0b7f61b5d Mon Sep 17 00:00:00 2001 From: "Peng Wang(AI FWK)" Date: Thu, 13 Jun 2024 06:58:50 +0000 Subject: [PATCH 31/32] fix ut --- .../python/training/ortmodule/_graph_transition_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py index 37b8a441134d9..248b89d8ccb18 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py @@ -27,7 +27,7 @@ unflatten_data_using_schema, ) -from . import _io, _utils +from . import _io, _utils, export_context from ._fallback import ORTModuleDeviceException, ORTModuleIOError, ORTModuleONNXModelException, wrap_exception from ._logger import LogColor, LogLevel, ORTModuleInitPhase, SuppressLogs, TimeTracker, TrackTimeForStaticFunction from ._onnx_models import _get_onnx_file_name, _save_model @@ -752,7 +752,7 @@ def _export_model( torch_exporter_verbose_log = debug_options.log_level < LogLevel.WARNING from onnxruntime.training.utils.hooks._subscriber_manager import no_increase_global_step - with no_increase_global_step(): + with export_context(), no_increase_global_step(): exported_model, module_output_schema = GraphTransitionManager._get_exported_model( flattened_module=flattened_module, model_info_for_export=model_info_for_export, From 870aa3057cef4afeece4f55dc69977bc254b0e33 Mon Sep 17 00:00:00 2001 From: "Peng Wang(AI FWK)" Date: Thu, 13 Jun 2024 11:11:20 +0000 Subject: [PATCH 32/32] fix --- .../training/ortmodule/_graph_execution_manager.py | 12 ++++++------ .../training/ortmodule/_graph_transition_manager.py | 8 ++++---- .../python/orttraining_test_ortmodule_autograd.py | 4 +++- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 91b6e663d010b..18999ce2fa1ab 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -125,10 +125,6 @@ def __init__( self._initialize_graph_transition_manager() - # Will be reset everytime we re-initialize the graph builder. - # Be noted, we will never enable this feature for inference mode. - self._mem_efficient_grad_management_is_enabled = False - def _get_torch_gpu_allocator_function_addresses(self): if self._runtime_options.use_external_gpu_allocator and torch.cuda.is_available(): # CPP extension to get torch GPU allocator's alloc and free function addresses @@ -235,7 +231,8 @@ def _get_session_config(self): # Enable memory efficient execution order for training if 1). memory efficient grad management is enabled # or 2). memory optimizer is enabled. use_memory_efficient_topo_sort = (self._export_mode == torch.onnx.TrainingMode.TRAINING) and ( - self._mem_efficient_grad_management_is_enabled or self._runtime_options.memory_optimizer_is_enabled() + self._graph_transition_manager._post_export_processed_model_info.is_mem_efficient_grad_management_enabled + or self._runtime_options.memory_optimizer_is_enabled() ) session_options.execution_order = ( onnxruntime.ExecutionOrder.MEMORY_EFFICIENT @@ -355,7 +352,10 @@ def _detect_from_inputs(self, inputs: Tuple, kwargs: Dict): enable sparsity-based optimization. """ - if self._runtime_options.enable_zero_stage3_support or self._mem_efficient_grad_management_is_enabled: + if ( + self._runtime_options.enable_zero_stage3_support + or self._graph_transition_manager._post_export_processed_model_info.is_mem_efficient_grad_management_enabled + ): self._append_pull_weight_trigger_as_input(kwargs, self._device) if ( diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py index 248b89d8ccb18..80bb00e0c3ac1 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py @@ -168,7 +168,7 @@ def __init__( self.onnx_graph_input_const_as_tensor: dict[str, torch.device] | None = onnx_graph_input_const_as_tensor - self._enable_mem_efficient_grad_management = enable_mem_efficient_grad_management + self.is_mem_efficient_grad_management_enabled = enable_mem_efficient_grad_management # Used for unflattening the outputs from the ORT forward run. self.module_forward_output_schema: ORTModelInputOutputSchemaType | None = module_forward_output_schema @@ -227,7 +227,7 @@ def construct_inputs( ) for name in self.onnx_graph_input_names_user_defined: - if self._enable_mem_efficient_grad_management and name == MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME: + if self.is_mem_efficient_grad_management_enabled and name == MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME: self._buffer_for_ort_runs[name] = torch.zeros( MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE, dtype=onnx_dtype_to_pytorch_dtype(MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE), @@ -676,7 +676,7 @@ def _post_export_process( # Override the options if model is not modified. ( - stage3_param_handle._mem_efficient_grad_management_is_enabled, + enable_mem_efficient_grad_management, # Update the flag to indicate the mem efficient grad management is enabled. post_processed_model, stage3_param_handle._param_trigger_grad, ) = post_processing_enable_mem_efficient_training( @@ -732,7 +732,7 @@ def _export_model( onnx_opset_version: int, stage3_param_handle: type, debug_options: DebugOptions, - time_tracker: TimeTracker, + time_tracker: TimeTracker, # time_tracker MUST be provided here to support TrackTimeForStaticFunction runtime_inspector: RuntimeInspector, logger: logging.Logger, ) -> tuple[onnx.ModelProto, ORTModelInputOutputSchemaType, list[str], list[str]]: diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py index 5cbb2aacbe245..95012aa0507a5 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py @@ -1839,7 +1839,9 @@ def forward(self, model_input): ortmodule = ORTModule(TestModel(output_size)).train() _ = ortmodule(torch.randn(output_size, dtype=torch.float)) - onnx_nodes = ortmodule._torch_module._execution_manager._training_manager._onnx_models.exported_model.graph.node + onnx_nodes = ( + ortmodule._torch_module._execution_manager._training_manager._graph_transition_manager._exported_model_info.exported_model.graph.node + ) found_pythonop = False for node in onnx_nodes: