remove stage3 related change
pengwa committed Dec 25, 2023
1 parent 3c3b4bf commit c585bc8
Showing 5 changed files with 92 additions and 32 deletions.
@@ -529,11 +529,13 @@ def _initialize_graph_builder(self):

# Add stage3 pull weight trigger name to require_grad_names, so that it will be included in the gradient graph.
input_names_require_grad.append(STAGE3_PULL_WEIGHT_TRIGGER_NAME)

if self._runtime_options.enable_mem_efficient_grad_management:
from ._mem_efficient_grad_mgmt import MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME

# Add stage3 mem efficient grad trigger name to require_grad_names, so that it will be included in the gradient graph.
input_names_require_grad.append(MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME)
# Add mem efficient grad trigger name to require_grad_names, so that it will be included in the gradient graph.

input_names_require_grad.insert(0, MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME)
grad_builder_config.input_names_require_grad = input_names_require_grad
grad_builder_config.build_gradient_graph = self._export_mode == torch.onnx.TrainingMode.TRAINING
grad_builder_config.enable_caching = self._runtime_options.enable_grad_acc_optimization
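
A minimal sketch of the ordering change in this hunk: the mem-efficient trigger name now goes to the front of input_names_require_grad instead of the end. The literal trigger string below is a hypothetical stand-in for the value of MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME.

# Toy illustration, assuming two user inputs that require grad.
input_names_require_grad = ["input_ids", "attention_mask"]

# Before this commit the trigger was appended last:
#   input_names_require_grad.append("mem_efficient_pull_weight_trigger")
#   -> ["input_ids", "attention_mask", "mem_efficient_pull_weight_trigger"]

# After this commit it is inserted first:
input_names_require_grad.insert(0, "mem_efficient_pull_weight_trigger")
# -> ["mem_efficient_pull_weight_trigger", "input_ids", "attention_mask"]
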
@@ -613,10 +615,20 @@ def _enable_conditional_optimizations(
self._runtime_options.enable_zero_stage3_support
or self._runtime_options.enable_mem_efficient_grad_management
):
self._append_pull_weight_trigger_as_input(kwargs, detected_device)
kwargs = self._append_pull_weight_trigger_as_input(kwargs, detected_device)

param_to_append_as_onnx_graph_inputs = []
if self._runtime_options.enable_mem_efficient_grad_management:
from ._mem_efficient_grad_mgmt import get_params_not_connected_to_pull_param_trigger

param_to_append_as_onnx_graph_inputs = get_params_not_connected_to_pull_param_trigger(
self._flattened_module.named_parameters()
)
else:
param_to_append_as_onnx_graph_inputs = self._graph_initializers

_, embed_sparsity_results, label_sparsity_results = _io._combine_input_buffers_initializers(
self._graph_initializers,
param_to_append_as_onnx_graph_inputs,
self._graph_builder.get_graph_info().user_input_names,
self._input_info,
self._flattened_module.named_buffers(),
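
A self-contained sketch of the parameter-selection branch added in this hunk, using a toy module. The list comprehension mirrors what get_params_not_connected_to_pull_param_trigger is expected to do (select parameters with requires_grad=False); that reading is an assumption based on the helper shown later in this commit.

import torch

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.trainable = torch.nn.Linear(4, 4)
        self.frozen = torch.nn.Linear(4, 4)
        self.frozen.requires_grad_(False)

model = ToyModel()
enable_mem_efficient_grad_management = True

if enable_mem_efficient_grad_management:
    # Only frozen parameters stay as ONNX graph inputs; trainable ones are
    # pulled in at runtime through the weight trigger instead.
    param_to_append_as_onnx_graph_inputs = [
        p for _, p in model.named_parameters() if not p.requires_grad
    ]
else:
    # Without the feature, every initializer is passed through unchanged.
    param_to_append_as_onnx_graph_inputs = list(model.parameters())
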
@@ -648,25 +660,43 @@ def _enable_conditional_optimizations(
self._runtime_inspector.disable_input_inspector()

def _append_pull_weight_trigger_as_input(self, kwargs: Dict, device: torch.device):
from ._mem_efficient_grad_mgmt import (
MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME,
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE,
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE,
)
if self._runtime_options.enable_zero_stage3_support:
from ._zero_stage3_compatibility import (
STAGE3_PULL_WEIGHT_TRIGGER_NAME,
STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE,
STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE,
)

new_kwargs = {
MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME: torch.zeros(
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE,
dtype=onnx_dtype_to_pytorch_dtype(MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE),
kwargs[STAGE3_PULL_WEIGHT_TRIGGER_NAME] = torch.zeros(
STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE,
dtype=onnx_dtype_to_pytorch_dtype(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE),
device=device,
).requires_grad_()
}

# Then the trigger input will be the first user input.
return {
**new_kwargs,
**kwargs,
}
return kwargs

if self._runtime_options.enable_mem_efficient_grad_management:
from ._mem_efficient_grad_mgmt import (
MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME,
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE,
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE,
)

new_kwargs = {
MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME: torch.zeros(
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE,
dtype=onnx_dtype_to_pytorch_dtype(MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE),
device=device,
).requires_grad_()
}

# Then the trigger input will be the first user input.
return {
**new_kwargs,
**kwargs,
}

return kwargs

def _log_feature_stats(self):
if get_rank() != 0:
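
A standalone sketch of the new contract of _append_pull_weight_trigger_as_input: instead of mutating kwargs in place, it now returns a new dict with the trigger tensor as the first entry, which is why callers in this commit rebind kwargs to the return value. The constant name and the trigger shape/dtype below are illustrative stand-ins, not the real constants.

import torch

PULL_WEIGHT_TRIGGER_NAME = "pull_weight_trigger"  # hypothetical stand-in

def append_pull_weight_trigger_as_input(kwargs, device):
    new_kwargs = {
        PULL_WEIGHT_TRIGGER_NAME: torch.zeros(
            [1], dtype=torch.float32, device=device
        ).requires_grad_()
    }
    # Splatting new_kwargs first makes the trigger the first user input.
    return {**new_kwargs, **kwargs}

kwargs = {"input_ids": torch.tensor([[1, 2, 3]])}
kwargs = append_pull_weight_trigger_as_input(kwargs, torch.device("cpu"))
assert next(iter(kwargs)) == PULL_WEIGHT_TRIGGER_NAME  # trigger comes first
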
@@ -163,10 +163,20 @@ def forward(self, *inputs, **kwargs):
self._runtime_options.enable_zero_stage3_support
or self._runtime_options.enable_mem_efficient_grad_management
):
self._append_pull_weight_trigger_as_input(kwargs, self._device)
kwargs = self._append_pull_weight_trigger_as_input(kwargs, self._device)

param_to_append_as_onnx_graph_inputs = []
if self._runtime_options.enable_mem_efficient_grad_management:
from ._mem_efficient_grad_mgmt import get_params_not_connected_to_pull_param_trigger

param_to_append_as_onnx_graph_inputs = get_params_not_connected_to_pull_param_trigger(
self._flattened_module.named_parameters()
)
else:
param_to_append_as_onnx_graph_inputs = self._graph_initializers

prepared_input_list, _, _ = _io._combine_input_buffers_initializers(
self._graph_initializers,
param_to_append_as_onnx_graph_inputs,
self._graph_info.user_input_names,
self._input_info,
self._flattened_module.named_buffers(),
12 changes: 6 additions & 6 deletions orttraining/orttraining/python/training/ortmodule/_io.py
@@ -260,12 +260,12 @@ def _expand_inputs(current_input, non_none_inputs, name=""):
)

# params is a list of all initializers known to the onnx graph
# if zero_stage3_offload_param_map:
# for p in params:
# if p not in zero_stage3_offload_param_map.values():
# result.append(p)
# else:
# result.extend(params)
if zero_stage3_offload_param_map:
for p in params:
if p not in zero_stage3_offload_param_map.values():
result.append(p)
else:
result.extend(params)

if rt_inspector.memory_ob.is_enabled() and not rt_inspector.memory_ob.symbolic_dim_collecting_completed:
rt_inspector.memory_ob.collect_symbolic_dim_values(input_info.dynamic_axes, onnx_input_to_value_map)
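
A toy illustration of the stage-3 filtering this hunk re-enables (it had been commented out): parameters owned by the offload map are skipped, everything else is appended as before. Names are illustrative, not the actual ORTModule objects.

params = ["w1", "w2", "w3"]
zero_stage3_offload_param_map = {"trigger_out_0": "w2"}  # w2 handled by stage-3 offload

result = []
if zero_stage3_offload_param_map:
    # Offloaded params are materialized on demand, so they are not fed as graph inputs.
    result.extend(p for p in params if p not in zero_stage3_offload_param_map.values())
else:
    result.extend(params)

assert result == ["w1", "w3"]
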
@@ -19,6 +19,14 @@
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE = [1]


def get_params_connected_to_pull_param_trigger(named_params: Dict[str, torch.nn.parameter.Parameter]):
return {k: v for k, v in named_params if v.requires_grad}


def get_params_not_connected_to_pull_param_trigger(named_params: Dict[str, torch.nn.parameter.Parameter]):
return [v for k, v in named_params if not v.requires_grad]


def post_processing_enable_mem_efficient_training(
exported_model: ModelProto,
named_params: Dict[str, torch.nn.parameter.Parameter],
@@ -29,7 +37,7 @@ def post_processing_enable_mem_efficient_training(
exported_model (ModelProto): The exported model.
named_params (Optional[Dict[str, torch.nn.parameter.Parameter]]): The full parameter map.
"""
trainable_named_params = {k: v for k, v in named_params if v.requires_grad}
trainable_named_params = get_params_connected_to_pull_param_trigger(named_params)

# Create weight retrieving function using trainable_named_params.
param_pull_trigger_func_class = _create_param_trigger_function(trainable_named_params)
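
A small usage sketch of the two helpers added above. The absolute import path is an assumption based on the _io.py path shown in this commit; the diff itself only uses the relative import ._mem_efficient_grad_mgmt.

import torch
# Assumed absolute import path (hedged); the commit uses a relative import.
from onnxruntime.training.ortmodule._mem_efficient_grad_mgmt import (
    get_params_connected_to_pull_param_trigger,
    get_params_not_connected_to_pull_param_trigger,
)

model = torch.nn.Linear(4, 2)
model.bias.requires_grad_(False)  # freeze the bias

connected = get_params_connected_to_pull_param_trigger(model.named_parameters())
not_connected = get_params_not_connected_to_pull_param_trigger(model.named_parameters())

# connected     -> {"weight": <trainable Parameter>}: rewritten to come from the pull trigger
# not_connected -> [<frozen bias Parameter>]: kept as an ordinary ONNX graph input
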
@@ -75,7 +83,8 @@ def _get_param_pull_trigger_name(param_name: str) -> str:
)

graph_inputs_to_remove = []
for graph_input in reversed(exported_model.graph.input):
input_offset = 0
for graph_input in exported_model.graph.input:
if graph_input.name not in trainable_named_params:
continue

@@ -110,7 +119,8 @@ def _get_param_pull_trigger_name(param_name: str) -> str:
training_mode=1,
safe_run_mode=0,
)
exported_model.graph.node.insert(0, new_node)
exported_model.graph.node.insert(input_offset, new_node)
input_offset += 1

# Delete exported_model.graph.input
for input_to_remove in graph_inputs_to_remove:
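
A plain-Python illustration of the ordering fix above, with lists standing in for exported_model.graph.node: always inserting at index 0 reverses the new trigger nodes, while inserting at an increasing offset, together with this commit's switch from reversed to forward iteration over the graph inputs, keeps them in visiting order.

nodes = ["original_node"]
new_nodes = ["trigger_for_w1", "trigger_for_w2", "trigger_for_w3"]

# Old behavior: always insert at 0 -> triggers end up reversed.
reversed_front = list(nodes)
for n in new_nodes:
    reversed_front.insert(0, n)
# ["trigger_for_w3", "trigger_for_w2", "trigger_for_w1", "original_node"]

# New behavior: insert at an increasing offset -> original order preserved.
ordered_front = list(nodes)
offset = 0
for n in new_nodes:
    ordered_front.insert(offset, n)
    offset += 1
# ["trigger_for_w1", "trigger_for_w2", "trigger_for_w3", "original_node"]
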
@@ -314,10 +314,20 @@ def forward(self, *inputs, **kwargs):
self._runtime_options.enable_zero_stage3_support
or self._runtime_options.enable_mem_efficient_grad_management
):
self._append_pull_weight_trigger_as_input(kwargs, self._device)
kwargs = self._append_pull_weight_trigger_as_input(kwargs, self._device)

param_to_append_as_onnx_graph_inputs = []
if self._runtime_options.enable_mem_efficient_grad_management:
from ._mem_efficient_grad_mgmt import get_params_not_connected_to_pull_param_trigger

param_to_append_as_onnx_graph_inputs = get_params_not_connected_to_pull_param_trigger(
self._flattened_module.named_parameters()
)
else:
param_to_append_as_onnx_graph_inputs = self._graph_initializers

prepared_input_list, _, _ = _io._combine_input_buffers_initializers(
self._graph_initializers,
param_to_append_as_onnx_graph_inputs,
self._graph_info.user_input_names,
self._input_info,
self._flattened_module.named_buffers(),
