fix tests
pengwa committed Dec 28, 2023
1 parent 28b8417 commit e6a733f
Showing 7 changed files with 100 additions and 65 deletions.
3 changes: 2 additions & 1 deletion onnxruntime/core/framework/execution_frame.cc
@@ -223,7 +223,8 @@ void IExecutionFrame::Init(gsl::span<const int> feed_mlvalue_idxs, gsl::span<con
const std::unordered_map<int, OrtValue>& initializers,
const std::function<bool(const std::string& name)>& is_initializer_sparse_func,
gsl::span<const OrtValue> fetches) {
-  ORT_ENFORCE(feeds.size() == feed_mlvalue_idxs.size());
+  ORT_ENFORCE(feeds.size() == feed_mlvalue_idxs.size(), "Got feed size: ", feeds.size(), " but expected feed size: ",
+              feed_mlvalue_idxs.size());
ORT_ENFORCE(fetches.empty() || fetches.size() == fetch_mlvalue_idxs_.size());

// Need this for sparse conversions in host memory
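The only C++ change enriches the bare size assertion so that, when it fires, the error reports both the actual and expected feed counts, which makes mismatches far easier to diagnose. A rough Python analogue of the enriched check (the function and wording here are illustrative, not ONNX Runtime API):

```python
def check_feeds(feeds, feed_mlvalue_idxs):
    # Mirrors the ORT_ENFORCE above: fail loudly with both counts on mismatch.
    if len(feeds) != len(feed_mlvalue_idxs):
        raise RuntimeError(
            f"Got feed size: {len(feeds)} but expected feed size: {len(feed_mlvalue_idxs)}"
        )

check_feeds([1.0, 2.0], [0, 1])   # passes silently
# check_feeds([1.0], [0, 1])      # would raise with both sizes in the message
```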
@@ -438,13 +438,6 @@ def _get_exported_model(self, input_schema: ORTModelInputOutputSchemaType, *inpu

exported_model = post_process_enabling_autograd_function(exported_model)

-        if self._runtime_options.enable_mem_efficient_grad_management:
-            from ._mem_efficient_grad_mgmt import post_processing_enable_mem_efficient_training
-
-            exported_model = post_processing_enable_mem_efficient_training(
-                exported_model, self._flattened_module.named_parameters()
-            )
-
if self._runtime_options.enable_zero_stage3_support:
from ._zero_stage3_compatibility import post_processing_enable_zero_stage3_compat

@@ -504,9 +497,29 @@ def _get_graph_transformer_config(self) -> C.TrainingGraphTransformerConfigurati
def _initialize_graph_builder(self):
"""Creates a new OrtModuleGraphBuilder, initializes it and saves it to self._graph_builder"""

+        # We post-process the exported model because the trainable parameters might have changed,
+        # so this path is re-triggered by reinitialize_graph_builder.
+        exported_model = copy.deepcopy(self._onnx_models.exported_model)
+        self._onnx_models.processed_exported_model = exported_model
+        if self._runtime_options.enable_mem_efficient_grad_management:
+            from ._mem_efficient_grad_mgmt import post_processing_enable_mem_efficient_training
+
+            # Override the option if the model is not modified.
+            (
+                self._runtime_options.enable_mem_efficient_grad_management,
+                exported_model,
+            ) = post_processing_enable_mem_efficient_training(exported_model, self._flattened_module.named_parameters())
+
+        # if self._runtime_options.run_symbolic_shape_infer:
+        #     exported_model = SymbolicShapeInference.infer_shapes(
+        #         exported_model, auto_merge=True, guess_output_rank=True
+        #     )
+
-        # All initializer names along with user inputs are a part of the onnx graph inputs
-        # since the onnx model was exported with the flag keep_initializers_as_inputs=True
-        onnx_initializer_names = {p.name for p in self._onnx_models.exported_model.graph.input}
+        # We need to use the exported model here, since its graph inputs include both user inputs and
+        # parameters.
+        onnx_initializer_names = {p.name for p in exported_model.graph.input}

# TODO: PyTorch exporter bug: changes the initializer order in ONNX model
initializer_names = [
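The net effect of this hunk: the memory-efficient post-processing moves out of `_get_exported_model` (removed above) and into `_initialize_graph_builder`, so it re-runs whenever the graph builder is reinitialized and the set of trainable parameters may have changed. A condensed sketch of the new flow using names from this diff (not runnable standalone — it assumes the surrounding ORTModule internals, and the method name is made up):

```python
import copy

def _post_process_for_graph_builder(self):
    # Work on a deep copy so the raw export stays available for name bookkeeping.
    exported_model = copy.deepcopy(self._onnx_models.exported_model)
    self._onnx_models.processed_exported_model = exported_model
    if self._runtime_options.enable_mem_efficient_grad_management:
        from ._mem_efficient_grad_mgmt import post_processing_enable_mem_efficient_training

        # The helper now returns (modified, model); if nothing was trainable,
        # it returns False and the option is switched back off.
        (
            self._runtime_options.enable_mem_efficient_grad_management,
            exported_model,
        ) = post_processing_enable_mem_efficient_training(
            exported_model, self._flattened_module.named_parameters()
        )
    return exported_model
```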
@@ -535,6 +548,7 @@ def _initialize_graph_builder(self):

# Add mem efficient grad trigger name to require_grad_names, so that it will be included in the gradient graph.
input_names_require_grad.append(MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME)

grad_builder_config.input_names_require_grad = input_names_require_grad
grad_builder_config.build_gradient_graph = self._export_mode == torch.onnx.TrainingMode.TRAINING
grad_builder_config.enable_caching = self._runtime_options.enable_grad_acc_optimization
@@ -546,12 +560,23 @@

# It is assumed here that the order and names of the inputs and outputs are not modified by the backend in any way
# and are kept as they appear in the exported onnx model.
-        self._graph_builder.initialize(self._onnx_models.exported_model.SerializeToString(), grad_builder_config)
+        self._graph_builder.initialize(exported_model.SerializeToString(), grad_builder_config)

+        raw_onnx_initializer_names = {p.name for p in self._onnx_models.exported_model.graph.input}
+
+        raw_initializer_names = [
+            name for name, _ in self._flattened_module.named_parameters() if name in raw_onnx_initializer_names
+        ]
+        raw_initializer_names_to_train = [
+            name
+            for name, param in self._flattened_module.named_parameters()
+            if param.requires_grad and name in raw_onnx_initializer_names
+        ]

# TODO: Explore ways to make self._graph_info.initializer_names and self._graph_info.initializer_names_to_train
# a set (unordered_set in the backend) that does not require a copy on each reference.
-        self._graph_initializer_names = set(initializer_names)
-        self._graph_initializer_names_to_train = set(initializer_names_to_train)
+        self._graph_initializer_names = set(raw_initializer_names)
+        self._graph_initializer_names_to_train = set(raw_initializer_names_to_train)

# Initializers can be cached and used since they are expected not to be re-instantiated
# between forward calls.
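Why the bookkeeping switches to the `raw_*` names: after post-processing, trainable parameters are no longer graph inputs of the processed model (they arrive through the pull trigger instead), so only the raw export still lists them. A self-contained toy illustrating the difference (only `onnx` required; all names invented):

```python
from onnx import TensorProto, helper

def tiny_model(input_names):
    # One Identity node keeps the graph well-formed; shapes are arbitrary.
    inputs = [helper.make_tensor_value_info(n, TensorProto.FLOAT, [1]) for n in input_names]
    node = helper.make_node("Identity", [input_names[0]], ["y"])
    output = helper.make_tensor_value_info("y", TensorProto.FLOAT, [1])
    return helper.make_model(helper.make_graph([node], "g", inputs, [output]))

raw = tiny_model(["user_input", "fc.weight"])           # export keeps params as inputs
processed = tiny_model(["user_input", "pull_trigger"])  # param input replaced by a trigger
print({i.name for i in raw.graph.input} - {i.name for i in processed.graph.input})
# {'fc.weight'} -- only the raw export still knows the parameter name
```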
@@ -602,7 +627,7 @@ def _enable_conditional_optimizations(
# Enable data sparsity inspection if sparse optimizer is ON or user wants to print input density.
if self._runtime_options.enable_sparse_optimizer or self._runtime_options.print_input_density:
self._runtime_inspector.enable_input_inspector(
-                self._onnx_models.exported_model, self._graph_builder.get_graph_info().user_input_names
+                self._onnx_models.processed_exported_model, self._graph_builder.get_graph_info().user_input_names
)

if self._runtime_options.enable_sparse_optimizer:
@@ -621,7 +646,7 @@
from ._mem_efficient_grad_mgmt import get_params_not_connected_to_pull_param_trigger

param_to_append_as_onnx_graph_inputs = get_params_not_connected_to_pull_param_trigger(
-                    self._flattened_module.named_parameters()
+                    self._flattened_module.named_parameters(), self._onnx_models.exported_model
)
else:
param_to_append_as_onnx_graph_inputs = self._graph_initializers
@@ -159,21 +159,10 @@ def forward(self, *inputs, **kwargs):
# Assert that the input and model device match
_utils._check_same_device(self._device, "Input argument to forward", *inputs)

-        if (
-            self._runtime_options.enable_zero_stage3_support
-            or self._runtime_options.enable_mem_efficient_grad_management
-        ):
+        if self._runtime_options.enable_zero_stage3_support:
            self._append_pull_weight_trigger_as_input(kwargs, self._device)

-        param_to_append_as_onnx_graph_inputs = []
-        if self._runtime_options.enable_mem_efficient_grad_management:
-            from ._mem_efficient_grad_mgmt import get_params_not_connected_to_pull_param_trigger
-
-            param_to_append_as_onnx_graph_inputs = get_params_not_connected_to_pull_param_trigger(
-                self._flattened_module.named_parameters()
-            )
-        else:
-            param_to_append_as_onnx_graph_inputs = self._graph_initializers
+        param_to_append_as_onnx_graph_inputs = self._graph_initializers

prepared_input_list, _, _ = _io._combine_input_buffers_initializers(
param_to_append_as_onnx_graph_inputs,
@@ -3,9 +3,9 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------

+from __future__ import annotations
+
import ctypes
-from typing import Dict, List, Optional, Tuple, Union

import torch
from onnx import ModelProto, NodeProto, TensorProto, helper
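The import swap above is purely syntactic modernization: with `from __future__ import annotations`, annotations are never evaluated at runtime, so builtin generics (`list[...]`, `dict[...]`) and `|` unions work on every supported Python version, and the `typing` imports become unnecessary. A minimal self-contained check:

```python
from __future__ import annotations

def describe(shapes: list[list[int | str] | None]) -> tuple[int, int]:
    # Count concrete vs. unknown shapes; the annotation mirrors infer_shape below.
    concrete = sum(1 for s in shapes if s is not None)
    return concrete, len(shapes) - concrete

print(describe([[1, "dyn_axis"], None]))  # (1, 1)
```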
@@ -19,39 +19,45 @@
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE = [1]


-def get_params_connected_to_pull_param_trigger(named_params: Dict[str, torch.nn.parameter.Parameter]):
-    return {k: v for k, v in named_params if v.requires_grad}
+def get_params_connected_to_pull_param_trigger(
+    named_params: dict[str, torch.nn.parameter.Parameter], exported_model: ModelProto
+):
+    # Note that some parameters may be missing from the graph inputs because they are not used
+    # in forward, so we filter those out as well.
+    onnx_initializer_names = {p.name for p in exported_model.graph.input}
+    return {k: v for k, v in named_params if v.requires_grad and k in onnx_initializer_names}


-def get_params_not_connected_to_pull_param_trigger(named_params: Dict[str, torch.nn.parameter.Parameter]):
-    return [v for k, v in named_params if not v.requires_grad]
+def get_params_not_connected_to_pull_param_trigger(
+    named_params: dict[str, torch.nn.parameter.Parameter], exported_model: ModelProto
+):
+    # Note that some parameters may be missing from the graph inputs because they are not used
+    # in forward, so we filter those out as well.
+    onnx_initializer_names = {p.name for p in exported_model.graph.input}
+    return [v for k, v in named_params if not v.requires_grad and k in onnx_initializer_names]


def post_processing_enable_mem_efficient_training(
exported_model: ModelProto,
-    named_params: Dict[str, torch.nn.parameter.Parameter],
-) -> ModelProto:
+    named_params: dict[str, torch.nn.parameter.Parameter],
+) -> tuple[bool, ModelProto]:
"""This function is used to enable zero stage3 compatibility.
Args:
exported_model (ModelProto): The exported model.
named_params (Optional[Dict[str, torch.nn.parameter.Parameter]]): The full parameter map.
Returns:
tuple[bool, ModelProto]: A tuple of bool and ModelProto. The bool indicates whether the model is modified.
"""
-    trainable_named_params = get_params_connected_to_pull_param_trigger(named_params)
+    trainable_named_params = get_params_connected_to_pull_param_trigger(named_params, exported_model)
+    # print(exported_model.graph.input)
+    if len(trainable_named_params) == 0:
+        return False, exported_model

# Create weight retrieving function using trainable_named_params.
param_pull_trigger_func_class = _create_param_trigger_function(trainable_named_params)
param_retrieve_func_class = _create_param_retrieval_function(trainable_named_params)

-    consumer_map = {}
-    for node in exported_model.graph.node:
-        for inp in node.input:
-            if inp not in consumer_map:
-                consumer_map[inp] = []
-
-            if node not in consumer_map[inp]:
-                consumer_map[inp].append(node)

def _get_param_pull_trigger_name(param_name: str) -> str:
return f"pull_{param_name}"

@@ -90,9 +96,6 @@ def _get_param_pull_trigger_name(param_name: str) -> str:

graph_inputs_to_remove.append(graph_input)

-        if graph_input.name not in consumer_map:
-            continue
-
# Create the param retrieval function for this parameter.
node_inputs = [
helper.make_tensor_value_info(
@@ -123,6 +126,14 @@ def _get_param_pull_trigger_name(param_name: str) -> str:
input_offset += 1

# Delete exported_model.graph.input

+    names_to_remove = [input.name for input in graph_inputs_to_remove]
+    value_infos_to_remove = [
+        value_info for value_info in exported_model.graph.value_info if value_info.name in names_to_remove
+    ]
+    for value_info in value_infos_to_remove:
+        exported_model.graph.value_info.remove(value_info)

for input_to_remove in graph_inputs_to_remove:
exported_model.graph.input.remove(input_to_remove)

@@ -135,13 +146,13 @@ def _get_param_pull_trigger_name(param_name: str) -> str:
exported_model.graph.input.insert(offset, inputs[0])
exported_model.graph.node.insert(0, weight_pull_node)

-    return exported_model
+    return True, exported_model


_PARAM_FUNCTION_INDEX = [0]


-def _create_param_trigger_function(trainable_named_params: Optional[Dict[str, torch.nn.parameter.Parameter]]):
+def _create_param_trigger_function(trainable_named_params: dict[str, torch.nn.parameter.Parameter]):
"""This function is used to create a weight retrieving function using trainable_named_params."""

@staticmethod
@@ -160,9 +171,9 @@ def backward(ctx, *grad_outputs):
@staticmethod
def infer_shape(
node: NodeProto,
-        tensor_input_shapes: List[Optional[List[Union[int, str]]]],
-        tensor_input_dtypes: List[torch.onnx.TensorProtoDataType],
-    ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]:
+        tensor_input_shapes: list[list[int | str] | None],
+        tensor_input_dtypes: list[torch.onnx.TensorProtoDataType],
+    ) -> tuple[list[list[int | str] | None], list[torch.onnx.TensorProtoDataType]]:
param_count = len(trainable_named_params.values())
tensor_output_shapes = [
tensor_input_shapes[0],
@@ -186,7 +197,7 @@ def infer_shape(
)


-def _create_param_retrieval_function(trainable_named_params: Dict[str, torch.nn.parameter.Parameter]):
+def _create_param_retrieval_function(trainable_named_params: dict[str, torch.nn.parameter.Parameter]):
"""This function is used to create a weight retrieving function using trainable_named_params."""

@staticmethod
@@ -205,9 +216,9 @@ def backward(ctx, *grad_outputs):
@staticmethod
def infer_shape(
node: NodeProto,
-        tensor_input_shapes: List[Optional[List[Union[int, str]]]],
-        tensor_input_dtypes: List[torch.onnx.TensorProtoDataType],
-    ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]:
+        tensor_input_shapes: list[list[int | str] | None],
+        tensor_input_dtypes: list[torch.onnx.TensorProtoDataType],
+    ) -> tuple[list[list[int | str] | None], list[torch.onnx.TensorProtoDataType]]:
input_pointer_scalars_attr_name = "input_pointer_scalars"
found = [attr for attr in node.attribute if attr.name == input_pointer_scalars_attr_name]

@@ -33,6 +33,7 @@ class ONNXModels:
"""

exported_model: Optional[onnx.ModelProto] = None
+    processed_exported_model: Optional[onnx.ModelProto] = None
optimized_model: Optional[onnx.ModelProto] = None

def save_exported_model(self, path, name_prefix, export_mode):
@@ -321,8 +321,9 @@ def forward(self, *inputs, **kwargs):
from ._mem_efficient_grad_mgmt import get_params_not_connected_to_pull_param_trigger

param_to_append_as_onnx_graph_inputs = get_params_not_connected_to_pull_param_trigger(
-                    self._flattened_module.named_parameters()
+                    self._flattened_module.named_parameters(), self._onnx_models.exported_model
                )
+
else:
param_to_append_as_onnx_graph_inputs = self._graph_initializers

@@ -505,10 +506,20 @@ def _reinitialize_graph_builder(self, input_info: _InputInfo):
if param.requires_grad and name in self._graph_initializer_names
}

+        if self._runtime_options.enable_mem_efficient_grad_management:
+            from ._mem_efficient_grad_mgmt import MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME
+
+            # Remove the inputs we added during model post-processing.
+            existing_require_grad_names = [
+                n for n in self._input_info.require_grad_names if n != MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME
+            ]
+        else:
+            existing_require_grad_names = self._input_info.require_grad_names

# If inputs requiring gradient change from forward to the next, the module_gradient_graph_builder
# needs to be reinitialized so it can compute the backward output for the new inputs that require_grad
if (
-            input_info.require_grad_names != self._input_info.require_grad_names
+            input_info.require_grad_names != existing_require_grad_names
or initializer_names_to_train_set_user_model != self._graph_initializer_names_to_train
):
self._input_info = input_info
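The comparison fix matters because the stored `_input_info.require_grad_names` now carries the synthetic trigger input, while the freshly computed `input_info` never does; without stripping it, every forward call would look like a change and force a graph-builder rebuild. A self-contained sketch (the trigger-name value here is assumed, not taken from the source):

```python
MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME = "mem_efficient_pull_weight_trigger"  # assumed value

stored = ["input1", MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME]  # names recorded last step
incoming = ["input1"]                                        # names computed this step

# Naive comparison: spurious mismatch on every step.
print(incoming != stored)  # True

# Fixed comparison: strip the synthetic trigger before comparing.
existing = [n for n in stored if n != MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME]
print(incoming != existing)  # False -- no needless reinitialization
```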
11 changes: 4 additions & 7 deletions orttraining/orttraining/python/training/ortmodule/options.py
@@ -400,13 +400,10 @@ def _override_from_env_vars(self):
if "ORTMODULE_ENABLE_ZERO_STAGE3" in os.environ and int(os.getenv("ORTMODULE_ENABLE_ZERO_STAGE3")) == 1:
self.enable_zero_stage3_support = True

-        if (
-            "ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT" in os.environ
-            and int(os.getenv("ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT")) == 1
-        ):
-            if self.enable_custom_autograd_function:
-                self.enable_mem_efficient_grad_management = True
-            else:
+        if "ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT" in os.environ:
+            enable_grad_mgmt = int(os.getenv("ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT"))
+            self.enable_mem_efficient_grad_management = enable_grad_mgmt == 1 and self.enable_custom_autograd_function
+            if not self.enable_custom_autograd_function and enable_grad_mgmt == 1:
self._logger.warning(
"ORTModule optimization for memory efficient gradient management cannot be enabled "
"because PyTorch custom autograd function support is disabled."
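The option is now resolved in one pass: the environment variable only takes effect when custom autograd function support is also enabled, and an explicit `1` without that support produces the warning instead of silently leaving the flag in a half-set state. A runnable sketch of the gating (the resolver function is illustrative; only the variable name comes from the diff):

```python
import os

def resolve_mem_efficient_flag(enable_custom_autograd_function: bool) -> bool:
    # Mirrors the logic above: the env var only sticks when custom autograd
    # function support is on; otherwise a warning is emitted.
    if "ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT" not in os.environ:
        return False
    enable_grad_mgmt = int(os.getenv("ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT"))
    if enable_grad_mgmt == 1 and not enable_custom_autograd_function:
        print("warning: memory-efficient gradient management needs custom autograd support")
    return enable_grad_mgmt == 1 and enable_custom_autograd_function

os.environ["ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT"] = "1"
print(resolve_mem_efficient_flag(True), resolve_mem_efficient_flag(False))  # True False
```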
