Skip to content

Commit

Permalink
fix again
Browse files Browse the repository at this point in the history
  • Loading branch information
pengwa committed Dec 28, 2023
1 parent e6a733f commit 28f7c9e
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,10 @@ def __init__(

configure_ort_compatible_zero_stage3(debug=False, stats_output_dir="ort_output", stats_overwrite=True)

# Reset every time we re-initialize the graph builder.
# Note that this feature is never enabled in inference (EVAL export) mode.
self._mem_efficient_grad_management_is_enabled = False

def _get_torch_gpu_allocator_function_addresses(self):
if self._runtime_options.use_external_gpu_allocator and torch.cuda.is_available():
# CPP extension to get torch GPU allocator's alloc and free function addresses
Expand Down Expand Up @@ -497,16 +501,22 @@ def _get_graph_transformer_config(self) -> C.TrainingGraphTransformerConfigurati
def _initialize_graph_builder(self):
"""Creates a new OrtModuleGraphBuilder, initializes it and saves it to self._graph_builder"""

self._mem_efficient_grad_management_is_enabled = (
self._export_mode != torch.onnx.TrainingMode.EVAL
and self._runtime_options.enable_mem_efficient_grad_management
)

# We post-process the exported model because the trainable parameters might have changed, so this path is
# re-triggered by reinitialize_graph_builder.
exported_model = copy.deepcopy(self._onnx_models.exported_model)
self._onnx_models.processed_exported_model = exported_model
if self._runtime_options.enable_mem_efficient_grad_management:

if self._mem_efficient_grad_management_is_enabled:
from ._mem_efficient_grad_mgmt import post_processing_enable_mem_efficient_training

# Override the options if model is not modified.
(
self._runtime_options.enable_mem_efficient_grad_management,
self._mem_efficient_grad_management_is_enabled,
exported_model,
) = post_processing_enable_mem_efficient_training(exported_model, self._flattened_module.named_parameters())

Expand Down Expand Up @@ -543,7 +553,7 @@ def _initialize_graph_builder(self):
# Add stage3 pull weight trigger name to require_grad_names, so that it will be included in the gradient graph.
input_names_require_grad.append(STAGE3_PULL_WEIGHT_TRIGGER_NAME)

if self._runtime_options.enable_mem_efficient_grad_management:
if self._mem_efficient_grad_management_is_enabled:
from ._mem_efficient_grad_mgmt import MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME

# Add mem efficient grad trigger name to require_grad_names, so that it will be included in the gradient graph.
Expand Down Expand Up @@ -635,14 +645,11 @@ def _enable_conditional_optimizations(
inputs, kwargs
)

if (
self._runtime_options.enable_zero_stage3_support
or self._runtime_options.enable_mem_efficient_grad_management
):
if self._runtime_options.enable_zero_stage3_support or self._mem_efficient_grad_management_is_enabled:
self._append_pull_weight_trigger_as_input(kwargs, detected_device)

param_to_append_as_onnx_graph_inputs = []
if self._runtime_options.enable_mem_efficient_grad_management:
if self._mem_efficient_grad_management_is_enabled:
from ._mem_efficient_grad_mgmt import get_params_not_connected_to_pull_param_trigger

param_to_append_as_onnx_graph_inputs = get_params_not_connected_to_pull_param_trigger(
Expand Down Expand Up @@ -697,7 +704,7 @@ def _append_pull_weight_trigger_as_input(self, kwargs: Dict, device: torch.devic
device=device,
).requires_grad_()

if self._runtime_options.enable_mem_efficient_grad_management:
if self._mem_efficient_grad_management_is_enabled:
from ._mem_efficient_grad_mgmt import (
MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME,
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -310,14 +310,11 @@ def forward(self, *inputs, **kwargs):

self._gradient_accumulation_manager.maybe_update_cache_before_run()

if (
self._runtime_options.enable_zero_stage3_support
or self._runtime_options.enable_mem_efficient_grad_management
):
if self._runtime_options.enable_zero_stage3_support or self._mem_efficient_grad_management_is_enabled:
self._append_pull_weight_trigger_as_input(kwargs, self._device)

param_to_append_as_onnx_graph_inputs = []
if self._runtime_options.enable_mem_efficient_grad_management:
if self._mem_efficient_grad_management_is_enabled:
from ._mem_efficient_grad_mgmt import get_params_not_connected_to_pull_param_trigger

param_to_append_as_onnx_graph_inputs = get_params_not_connected_to_pull_param_trigger(
Expand Down Expand Up @@ -506,7 +503,7 @@ def _reinitialize_graph_builder(self, input_info: _InputInfo):
if param.requires_grad and name in self._graph_initializer_names
}

if self._runtime_options.enable_mem_efficient_grad_management:
if self._mem_efficient_grad_management_is_enabled:
from ._mem_efficient_grad_mgmt import MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME

# Remove the inputs we added during model post-processing.
Expand Down

0 comments on commit 28f7c9e

Please sign in to comment.