From 444a0eda309e0fadf51c63790b6da78258f96a10 Mon Sep 17 00:00:00 2001
From: pengwa
Date: Sat, 21 Oct 2023 19:45:45 +0800
Subject: [PATCH] Avoid one time clone to save memory peak (#17934)

### Avoid one more time clone to save memory peak

---
 .../_custom_autograd_function_runner.py | 55 +++++++++++--------
 1 file changed, 32 insertions(+), 23 deletions(-)

diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py
index b9318033a3d53..dd32e2aced561 100644
--- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py
+++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py
@@ -245,6 +245,8 @@ def _process_inplace_outputs(
 
             if not copied:
                 # Only need a copy once.
+                # Inplace copy only happens for non-leaf variables, so we have to set requires_grad to False.
+                raw_input_tensor.requires_grad = False
                 raw_input_tensor.copy_(all_outputs_of_kernel_run[output_index])
                 _log_warning(
                     f"{log_prefix}Copy output tensor {output_index} to raw input tensor {raw_tensor_input_index}. "
@@ -449,7 +451,8 @@ def call_python_forward_function(
     try:
         func_name = func_name.decode("utf-8") if isinstance(func_name, bytes) else func_name
         # If this is the first time run, collect runtime tensor reuse mapping.
-        if kernel_invoke_id not in _GlobalOpKernelInfoMap:
+        is_first_time_run = kernel_invoke_id not in _GlobalOpKernelInfoMap
+        if is_first_time_run:
             kernel_info = CustomFuncOpKernelInfo(kernel_invoke_id)
             _GlobalOpKernelInfoMap[kernel_invoke_id] = kernel_info
 
@@ -473,6 +476,11 @@ def call_python_forward_function(
                 if tensor_input_index in inplace_map:
                     raw_input_tensors_used_inplace[tensor_input_index] = wrapped_arg
 
+                # Only requires gradient when running under training mode
+                # and the associated tensor has grad_flag=True (i.e.,
+                # "requires_grad=True" in the original PyTorch script).
+                wrapped_arg.requires_grad = is_training_mode and grad_flag
+
                 # Note1:
                 # If it's first-time kernel invocation, tensor_input_indices_to_save_in_ctx is None, we do the
                 # copy for all tensors. Otherwise, we only copy the tensors whose indices are in
@@ -480,29 +488,30 @@ def call_python_forward_function(
                 # tensor_input_indices_to_save_in_ctx.
                 # Note2:
                 # For inference mode, we don't need to do the copy because ctx will be None,
                 # so nothing will be saved for ctx.
-                if is_training_mode and (
-                    tensor_input_indices_to_save_in_ctx is None
-                    or tensor_input_index in tensor_input_indices_to_save_in_ctx
-                ):
-                    wrapped_arg = wrapped_arg.detach().clone()
-
-                # Only requires gradient when running under training mode
-                # and the associated tensor has grad_flag=True (i.e.,
-                # "requires_grad=True" in the original PyTorch script).
-                wrapped_arg.requires_grad = is_training_mode and grad_flag
-                # Note3:
-                # If it's not first-time kernel invocation, tensor_input_indices_for_mark_dirty is None, we do the
-                # mul for all tensors. Otherwise, we only mul by one for the tensors whose indices are in
-                # tensor_input_indices_for_mark_dirty.
-                if is_training_mode and (
-                    tensor_input_indices_for_mark_dirty is None
-                    or tensor_input_index in tensor_input_indices_for_mark_dirty
-                ):
-                    # To fix this issue:
-                    # "a leaf Variable that requires grad has been used in an in-place operation."
-                    with torch.set_grad_enabled(True):
-                        wrapped_arg = wrapped_arg.clone()
+                # To fix this issue:
+                # "a leaf Variable that requires grad has been used in an in-place operation."
+                # If it's first-time kernel invocation, tensor_input_indices_for_mark_dirty is None, we do the
+                # copy for all tensors to generate grad for it. Otherwise, we only clone (to generate grad) for
+                # the tensors whose indices are in tensor_input_indices_for_mark_dirty.
+                if is_training_mode:
+                    if is_first_time_run:
+                        with torch.set_grad_enabled(True):
+                            wrapped_arg = wrapped_arg.clone()
+                    else:
+                        is_input_index_saved_in_ctx = (
+                            tensor_input_indices_to_save_in_ctx is None
+                            or tensor_input_index in tensor_input_indices_to_save_in_ctx
+                        )
+                        is_input_index_marked_dirty = (
+                            tensor_input_indices_for_mark_dirty is None
+                            or tensor_input_index in tensor_input_indices_for_mark_dirty
+                        )
+                        if is_input_index_saved_in_ctx or is_input_index_marked_dirty:
+                            # when with grad, the leaf tensor after clone will not be leaf.
+                            with torch.set_grad_enabled(is_input_index_marked_dirty):
+                                wrapped_arg = wrapped_arg.clone()
+                            wrapped_arg.requires_grad = is_training_mode and grad_flag
 
                 wrapped_args.append(wrapped_arg)
                 input_tensors_used_for_fw_run[tensor_input_index] = wrapped_arg
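Background for the clone-under-grad pattern in the last hunk: PyTorch refuses in-place mutation of a leaf tensor that requires grad, which is the exact error message quoted in the comments, and a clone made with grad disabled comes back as a detached leaf whose `requires_grad` flag must be restored by hand. Below is a minimal standalone sketch of that behavior, not part of the patch; it assumes only a stock PyTorch install, and the tensor names are illustrative.

```python
import torch

# In-place mutation of a leaf tensor that requires grad is rejected by autograd.
leaf = torch.ones(3, requires_grad=True)
try:
    leaf.add_(1.0)
except RuntimeError as err:
    print(err)  # "a leaf Variable that requires grad has been used in an in-place operation."

# Cloning with grad enabled yields a non-leaf tensor, so in-place updates are allowed on it.
with torch.set_grad_enabled(True):
    non_leaf = leaf.clone()
print(non_leaf.is_leaf, non_leaf.requires_grad)  # False True
non_leaf.add_(1.0)  # no error

# Cloning with grad disabled yields a detached leaf with requires_grad=False, so the flag
# has to be re-applied afterwards, which appears to be the role of the trailing
# `wrapped_arg.requires_grad = is_training_mode and grad_flag` assignment in the hunk above.
with torch.set_grad_enabled(False):
    detached = leaf.clone()
print(detached.is_leaf, detached.requires_grad)  # True False
detached.requires_grad = True  # allowed, because it is still a leaf
```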