Avoid an extra clone to reduce peak memory (#17934)
### Avoid an extra clone to reduce peak memory
pengwa authored Oct 21, 2023
1 parent 009cd4e commit 444a0ed
Showing 1 changed file with 32 additions and 23 deletions.
@@ -245,6 +245,8 @@ def _process_inplace_outputs(

                 if not copied:
                     # Only need a copy once.
+                    # Inplace copy only happens for non-leaf variables, so we have to set requires_grad to False.
+                    raw_input_tensor.requires_grad = False
                     raw_input_tensor.copy_(all_outputs_of_kernel_run[output_index])
                     _log_warning(
                         f"{log_prefix}Copy output tensor {output_index} to raw input tensor {raw_tensor_input_index}. "
@@ -449,7 +451,8 @@ def call_python_forward_function(
     try:
         func_name = func_name.decode("utf-8") if isinstance(func_name, bytes) else func_name
         # If this is the first time run, collect runtime tensor reuse mapping.
-        if kernel_invoke_id not in _GlobalOpKernelInfoMap:
+        is_first_time_run = kernel_invoke_id not in _GlobalOpKernelInfoMap
+        if is_first_time_run:
             kernel_info = CustomFuncOpKernelInfo(kernel_invoke_id)
             _GlobalOpKernelInfoMap[kernel_invoke_id] = kernel_info
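
For context, `_GlobalOpKernelInfoMap` is a module-level cache keyed by `kernel_invoke_id`, so the membership test doubles as first-run detection; this hunk merely names that test so the flag can be reused further down. A minimal sketch of the pattern, with simplified stand-ins rather than the real classes:

    from typing import Dict, Tuple

    class KernelInfo:
        """Illustrative stand-in for CustomFuncOpKernelInfo."""
        def __init__(self, invoke_id: str):
            self.invoke_id = invoke_id

    _kernel_info_map: Dict[str, KernelInfo] = {}  # stand-in for _GlobalOpKernelInfoMap

    def get_kernel_info(invoke_id: str) -> Tuple[KernelInfo, bool]:
        # The membership test doubles as the "first time run" flag.
        is_first_time_run = invoke_id not in _kernel_info_map
        if is_first_time_run:
            _kernel_info_map[invoke_id] = KernelInfo(invoke_id)
        return _kernel_info_map[invoke_id], is_first_time_run

    info, first = get_kernel_info("kernel-0")
    print(first)  # True on the first call
    info, first = get_kernel_info("kernel-0")
    print(first)  # False afterwards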

@@ -473,36 +476,42 @@
                     if tensor_input_index in inplace_map:
                         raw_input_tensors_used_inplace[tensor_input_index] = wrapped_arg
 
+                    # Only requires gradient when running under training mode
+                    # and the associated tensor has grad_flag=True (i.e.,
+                    # "requires_grad=True" in the original PyTorch script).
+                    wrapped_arg.requires_grad = is_training_mode and grad_flag
+
                     # Note1:
                     #   If it's first-time kernel invocation, tensor_input_indices_to_save_in_ctx is None, we do the
                     #   copy for all tensors. Otherwise, we only copy the tensors whose indices are in
                     #   tensor_input_indices_to_save_in_ctx.
                     # Note2:
                     #   For inference mode, we don't need to do the copy because ctx will be None,
                     #   so nothing will be saved for ctx.
-                    if is_training_mode and (
-                        tensor_input_indices_to_save_in_ctx is None
-                        or tensor_input_index in tensor_input_indices_to_save_in_ctx
-                    ):
-                        wrapped_arg = wrapped_arg.detach().clone()
 
-                    # Only requires gradient when running under training mode
-                    # and the associated tensor has grad_flag=True (i.e.,
-                    # "requires_grad=True" in the original PyTorch script).
-                    wrapped_arg.requires_grad = is_training_mode and grad_flag
-
-                    # Note3:
-                    #   If it's not first-time kernel invocation, tensor_input_indices_for_mark_dirty is None, we do the
-                    #   mul for all tensors. Otherwise, we only mul by one for the tensors whose indices are in
-                    #   tensor_input_indices_for_mark_dirty.
-                    if is_training_mode and (
-                        tensor_input_indices_for_mark_dirty is None
-                        or tensor_input_index in tensor_input_indices_for_mark_dirty
-                    ):
-                        # To fix this issue:
-                        # "a leaf Variable that requires grad has been used in an in-place operation."
-                        with torch.set_grad_enabled(True):
-                            wrapped_arg = wrapped_arg.clone()
+                    # To fix this issue:
+                    # "a leaf Variable that requires grad has been used in an in-place operation."
+                    # If it's first-time kernel invocation, tensor_input_indices_for_mark_dirty is None, we do the
+                    # copy for all tensors to generate grad for it. Otherwise, we only clone (to generate grad) for
+                    # the tensors whose indices are in tensor_input_indices_for_mark_dirty.
+                    if is_training_mode:
+                        if is_first_time_run:
+                            with torch.set_grad_enabled(True):
+                                wrapped_arg = wrapped_arg.clone()
+                        else:
+                            is_input_index_saved_in_ctx = (
+                                tensor_input_indices_to_save_in_ctx is None
+                                or tensor_input_index in tensor_input_indices_to_save_in_ctx
+                            )
+                            is_input_index_marked_dirty = (
+                                tensor_input_indices_for_mark_dirty is None
+                                or tensor_input_index in tensor_input_indices_for_mark_dirty
+                            )
+                            if is_input_index_saved_in_ctx or is_input_index_marked_dirty:
+                                # when with grad, the leaf tensor after clone will not be leaf.
+                                with torch.set_grad_enabled(is_input_index_marked_dirty):
+                                    wrapped_arg = wrapped_arg.clone()
+                                wrapped_arg.requires_grad = is_training_mode and grad_flag
 
                     wrapped_args.append(wrapped_arg)
                     input_tensors_used_for_fw_run[tensor_input_index] = wrapped_arg
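
This hunk is where the memory saving comes from: the old code could clone a tensor twice (a `detach().clone()` to snapshot it for `ctx`, then a grad-enabled `clone()` for the mark-dirty path), while the new code performs at most one clone, choosing the grad mode per tensor. The collapse relies on how `clone()` behaves under each grad mode; a minimal standalone sketch (plain PyTorch, variable names are illustrative):

    import torch

    src = torch.ones(3, requires_grad=True)

    # Grad enabled: the clone is a non-leaf (it has a grad_fn), so later
    # in-place ops avoid the "leaf Variable that requires grad" error.
    with torch.set_grad_enabled(True):
        dirty_safe = src.clone()
    print(dirty_safe.is_leaf, dirty_safe.requires_grad)  # False True

    # Grad disabled: the clone is a detached leaf, equivalent in effect
    # to detach().clone() when saving a snapshot for ctx.
    with torch.set_grad_enabled(False):
        ctx_copy = src.clone()
    print(ctx_copy.is_leaf, ctx_copy.requires_grad)  # True False

So a single `clone()` under `torch.set_grad_enabled(is_input_index_marked_dirty)` covers both needs: with grad enabled the result is non-leaf and safe to mark dirty and mutate in place; with grad disabled the result is already a detached snapshot, making the separate `detach().clone()` unnecessary.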