From 0987d9e7ae49808c9e957b6822fd24a40bbd5cb4 Mon Sep 17 00:00:00 2001 From: "Peng Wang (AI FWK)" Date: Sat, 21 Oct 2023 06:04:39 -0700 Subject: [PATCH] save --- .../_custom_autograd_function_runner.py | 646 +----------------- .../torch_interop_utils.cc | 229 +++++-- .../utils/hooks/_zero_offload_subscriber.py | 14 +- 3 files changed, 218 insertions(+), 671 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py index 5ea34f0103cc2..24b5ae62184bf 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py @@ -30,419 +30,6 @@ def _log_warning(message: str): warnings.warn(f"[rank-{get_rank()}] {message}") -class CustomFuncOpKernelInfo: - """Store the kernel-specific information retrieved with the first-time run.""" - - def __init__(self, kernel_invoke_id: str): - # kernel_invoke_id is a string contains session thread id, op kernel creation time stamp in ms, a random int, - # and address of op_kernel pointer. This can guarantee the uniqueness of the key in case of multiple - # instances of a same named PythonOp/PythonOpGrad in one session, or multiple sessions. - self.kernel_invoke_id = kernel_invoke_id - - self.position_to_tensor_index_map: Optional[Tuple[Tuple[int, ...], ...]] = None - - # For the tensors generated from ORT backend, there is special handling here: - # 1. For the first time run for the kernel (the uniqueness of the kernel is defined by kernel_invoke_id), - # all such tensors will be cloned in case they are saved in context (but ORT backend is not aware of the - # reference, may release the content of the tensor before it is needed in backward). Once - # `autograd.Function.apply` completes, by checking the existence of the tensor in the saved_tensors, - # `_GlobalOpKernelInfoMap` is updated to save the input indices that are saved in context. - # 2. For the subsequent runs, if the input index is in `tensor_input_indices_to_save_in_ctx`, the tensor - # will be cloned before fed into `autograd.Function.apply` as input. - self.tensor_input_indices_to_save_in_ctx: Optional[Tuple[int, ...]] = None - - # To align with PyTorch `ctx.set_materialize_grads(False|True)`` - # materialize_grads_config is a map from output index to (device, dtype, shape) of the output tensor, used - # for materializing the gradient of the output tensor in backward. - self.materialize_grads: bool = False - self.materialize_grads_config: Optional[Dict[int, Tuple[torch.device, torch.dtype, torch.shape]]] = None - - # For the tensors generated from ORT backend, there is special handling here: - # 1. For the first time run for the kernel (the uniqueness of the kernel is defined by kernel_invoke_id), - # all such tensors will be cloned (with gradient) in case they are marked as dirty (if not cloned, but marked - # as dirty, PyTorch will complain the tensor is a leaf, should not be used for inplace update). Once - # `autograd.Function.apply` completes, by checking the existence of the tensor in the dirty_tensors, - # `_GlobalOpKernelInfoMap` is updated to save the input indices that are marked as dirty. - # 2. For the subsequent runs, if the input index is in `tensor_input_indices_for_mark_dirty`, the tensor - # will be cloned (with gradient) before fed into `autograd.Function.apply` as input. 
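# A minimal sketch, in plain PyTorch with hypothetical names (_KernelInputCache,
# prepare, finalize), of the first-run bookkeeping described above: clone every
# ORT-provided input on the first run, then remember which indices actually ended up
# in ctx.saved_tensors so that later runs clone only those.
import torch

class _KernelInputCache:
    def __init__(self):
        self.indices_to_clone = None  # unknown until the first run has completed

    def prepare(self, tensors):
        if self.indices_to_clone is None:
            # First run: clone everything so ORT may release its buffers safely.
            return [t.clone() for t in tensors]
        return [t.clone() if i in self.indices_to_clone else t for i, t in enumerate(tensors)]

    def finalize(self, prepared, ctx):
        if self.indices_to_clone is None and ctx is not None:
            saved = ctx.saved_tensors
            self.indices_to_clone = {i for i, t in enumerate(prepared)
                                     if any(t is s for s in saved)}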
- self.tensor_input_indices_for_mark_dirty: Optional[Tuple[int, ...]] = None - - # A list of output indices that needs to be clone before returned, due to inplace update analysis. - self.output_indices_for_clone: Optional[List[int]] = None - - - self.tensor_input_states = OrderedDict() # key: tensor input index, value: TensorInputState. - - def check_with_input_index(self, tensor_input_index): - if tensor_input_index not in self.tensor_input_states: - is_input_index_saved_in_ctx = tensor_input_index in self.tensor_input_indices_to_save_in_ctx - is_input_index_marked_dirty = tensor_input_index in self.tensor_input_indices_for_mark_dirty - self.tensor_input_states[tensor_input_index] = [is_input_index_saved_in_ctx, is_input_index_marked_dirty] - return self.tensor_input_states[tensor_input_index] - - -# Store the kernel-specific information that cannot be retrieved and saved by PyTorch exporter. -# For the infos that can only be retrieved with real run, we try to collect them in the first time run. -# key: kernel_invoke_id, value: CustomFuncOpKernelInfo. -_GlobalOpKernelInfoMap: Dict[str, CustomFuncOpKernelInfo] = {} - - -@nvtx_function_decorator -def _process_inplace_outputs( - kernel_info: CustomFuncOpKernelInfo, - func_name: str, - input_tensors_of_kernel_run: Dict[int, Union[torch.Tensor, None]], - all_outputs_of_kernel_run: List[Union[torch.Tensor, any]], - all_outputs_to_tensor_inputs_reuse_map: List[int], - raw_input_tensors_used_inplace: Dict[int, Union[torch.Tensor, None]], - is_backward=False, -): - """Special handling for in-place reusing in forward or backward. - - Args: - kernel_info: kernel-specific information. - func_name: name of the autograd.Function. - input_tensors_of_kernel_run: all tensor input tensors used to run the autograd.Function forward/backward. - all_outputs_of_kernel_run: all outputs of the autograd.Function forward/backward. - all_outputs_to_tensor_inputs_reuse_map: a list of the same length of kernel outputs, each element representing - which input index it is reusing. If there is no reuse, the value is -1. - raw_input_tensors_used_inplace: a dict of raw input tensors marked as inplace in - `all_outputs_to_tensor_inputs_reuse_map`, the key is the tensor input index, value is the raw input tensor. - is_backward: indicates if this is backward or forward. - - Procedures: - 1. Detect all outputs to tensor inputs reuse mapping. - 2. Validate the detected inplace_map with the registered inplace_map in ORT. For the output tensor, - 2.0 If the reuse mapping value is the same in both inplace_map and detected inplace_map: - 2.0.1 Most likely, we don't need to do anything, except 2.0.2. - 2.0.2 Conditions: - > During forward run, - > The output tensor is reusing one of input tensors, - > The raw input tensor to be reused given from ORT is copied to run the forward kernels - (for two possible reasons: - a. the first time forward run, all inputs will be copied to detect - `tensor_input_indices_to_save_in_ctx`; - b. for every iteration, the input needs to be cloned because it is in - `tensor_input_indices_to_save_in_ctx`). - - In this case, need to copy the output tensor back to the raw input tensor, to make it compatible with - ORT statistically planned buffer reuse. - 2.1 If the reuse mapping value is NOT equal in both inplace_map and detected inplace_map: - 2.1.1 If the detected reuse input index is -1 (e.g. there is NO buffer reuse for this output), - while user specified reuse input index is NOT -1 (ORT planned the reuse), we raise an error. 
- 2.1.2 If the detected reuse input index is NOT -1 (e.g. there is buffer reuse for this output), - while user specified reuse input index is -1 (ORT did not plan the reuse). We will try to clone the - output tensor before returning to ORT, to align with ORT's NO Buffer reuse plan; otherwise, once the - input buffer is released by ORT memory planner, the output tensor read/write will be corrupted. - Raise a warning to notify users to update inplace_map explicitly for performance consideration. - 2.1.3 Other cases (for example user gives a wrong mapping index compared with detected ones), raise an - error. - 3. Do copies for 2.1.2 cases. - 4. Do copies for 2.0.2 cases. - """ - - log_prefix = f"{func_name}->{'Backward' if is_backward else 'Forward'}: " - input_tensor_address_list = [ - t.data_ptr() if isinstance(t, torch.Tensor) else -1 for t in input_tensors_of_kernel_run.values() - ] - if is_backward: - input_tensor_address_list = [-1, *input_tensor_address_list] # skip the context input - - is_first_time_init = kernel_info.output_indices_for_clone is None - # If this is the first time run, collect runtime tensor reuse mapping. - if is_first_time_init: - # Procedure 1: Detect all outputs to tensor inputs reuse mapping, according to `all_outputs_of_kernel_run` and - # `input_tensors_of_kernel_run`. - assert len(all_outputs_to_tensor_inputs_reuse_map) == len(all_outputs_of_kernel_run), ( - f"{log_prefix}all_outputs_to_tensor_inputs_reuse_map and kernel run outputs should have the same length." - f"all_outputs_to_tensor_inputs_reuse_map: {all_outputs_to_tensor_inputs_reuse_map}, " - f"kernel run outputs: {all_outputs_of_kernel_run}" - ) - - # Detect all outputs to tensor inputs reuse mapping. - detected_reuse_map = [-1] * (len(all_outputs_of_kernel_run)) - for output_index, arg in enumerate(all_outputs_of_kernel_run): - if not isinstance(arg, torch.Tensor): - continue - if arg.data_ptr() in input_tensor_address_list: - input_index = input_tensor_address_list.index(arg.data_ptr()) - detected_reuse_map[output_index] = input_index - - # Procedure 2: Validate the detected inplace_map with the registered inplace_map in ORT. - output_indices_for_clone = ( - [] - ) # collect the output indices that need to be cloned before returned in case 2.1.2. - for output_index, (detected_inplace_index, inplace_index) in enumerate( - zip(detected_reuse_map, all_outputs_to_tensor_inputs_reuse_map) - ): - if inplace_index == detected_inplace_index: - continue - - if ( - inplace_index in raw_input_tensors_used_inplace - and raw_input_tensors_used_inplace[inplace_index] is None - ): - # Use specified inplace input index, but the input tensor is None, which means the input is not - # a tensor, so we don't do further checks. - continue - - # If users register inplace_map (alloc planner will do buffer reuse), - # but detected inplace_map indicates it is NO inplace reusing, we raise an error. - if inplace_index != -1 and detected_inplace_index == -1: - raise RuntimeError( - f"{log_prefix}Fatal: " - f"ONNX Op attribute 'tensor_reuse_map' indicates {output_index}-th output is reusing input " - f"{inplace_index}, but detected inplace_map indicates it is NOT reusing any input. " - "Please update inplace_map explicitly to make it consistent " - f"to avoid undefined behavior due to ORT's memory reuse plan. 
" - f"inplace_map: {all_outputs_to_tensor_inputs_reuse_map}, " - f"detected inplace_map: {detected_reuse_map}" - ) - - if inplace_index == -1 and detected_inplace_index != -1: - output_indices_for_clone.append(output_index) - continue - - raise RuntimeError( - f"{log_prefix}Fatal: " - f"ONNX Op attribute 'inplace_map' indicates {inplace_index}-th output is reusing " - f"input index {detected_inplace_index}, but detected inplace_map indicates it is reusing " - f"input index {inplace_index}. Please update inplace_map explicitly to avoid undefined behavior " - f"due to memory reuse. inplace_map: {all_outputs_to_tensor_inputs_reuse_map}, " - f"detected inplace_map: {detected_reuse_map}" - ) - - kernel_info.output_indices_for_clone = output_indices_for_clone - - assert kernel_info.output_indices_for_clone is not None - - # Procedure 3: Do copies for 2.1.2 cases. - for output_index in kernel_info.output_indices_for_clone: - _log_warning( - f"{log_prefix}ONNX Op attribute " - f"'tensor_reuse_map' doesn't indicate {output_index}-th output is reusing any input, " - f"but detected inplace_map indicates it is reusing some input index. " - "A clone will be done before returning to ORT, to align with ORT's NO Buffer reuse plan. " - "Please update inplace_map explicitly to avoid such a copy." - ) - all_outputs_of_kernel_run[output_index] = all_outputs_of_kernel_run[output_index].detach().clone() - - # Procedure 4: Do copies for 2.0.2 cases. - if is_backward is False and ( - is_first_time_init - or kernel_info.tensor_input_indices_to_save_in_ctx - or kernel_info.tensor_input_indices_for_mark_dirty - ): - for raw_tensor_input_index, raw_input_tensor in raw_input_tensors_used_inplace.items(): - # raw_input_tensor can be None for backward run, but backward won't go here. - if not isinstance(raw_input_tensor, torch.Tensor): - continue - - # We did not do the check with tensor_input_indices_to_save_in_ctx/tensor_input_indices_for_mark_dirty - # because even for those tensor indices not in - # tensor_input_indices_to_save_in_ctx/tensor_input_indices_for_mark_dirty, we still need to do the - # copy for the first-time run. - if raw_input_tensor.data_ptr() == input_tensor_address_list[raw_tensor_input_index]: - # If the raw input tensor is not copied, we don't need this handling. - continue - - copied = False # for each tensor, we don't do the copy once. - output_indices_reusing_current_raw_input = [ - output_index - for output_index, input_index in enumerate(all_outputs_to_tensor_inputs_reuse_map) - if input_index == raw_tensor_input_index - ] - output_tensor_address = all_outputs_of_kernel_run[output_indices_reusing_current_raw_input[0]].data_ptr() - for output_index in output_indices_reusing_current_raw_input: - assert ( - output_tensor_address == all_outputs_of_kernel_run[output_index].data_ptr() - ), "Outputs reusing the same input tensor should have the same address." - - if not copied: - # Only need a copy once. - # Inplace copy only happens for non-leaf variables, so we have to set requires_grad to False. - raw_input_tensor.requires_grad = False - raw_input_tensor.copy_(all_outputs_of_kernel_run[output_index]) - _log_warning( - f"{log_prefix}Copy output tensor {output_index} to raw input tensor {raw_tensor_input_index}. " - f"{'Provide output to input reuse mapping to avoid the copy overhead.' 
if not is_first_time_init else ''}" - ) - copied = True - - all_outputs_of_kernel_run[output_index] = raw_input_tensor - - -@nvtx_function_decorator -def _get_context(forward_tensor_outputs: List[torch.Tensor]) -> Tuple[any, Optional[torch.Tensor]]: - """Search for context among all outputs. - - Note 1: All forward outputs of torch.autograd.Function shared the same gradient function pointer, - so here we just get the first tensor having grad_fn attribute. - (https://github.com/PyTorch/PyTorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/custom_function.cpp#L267) - - Note 2: Context can be None because NOT all torch.autograd.Function's are differentiable. The function - https://github.com/PyTorch/PyTorch/blob/d701357d921ef167d42c125e65b6f7da6be3ad0f/torch/csrc/autograd/custom_function.cpp#L209? - means if all output of the forward function is not differentiable, then grad_fn will be None (not be set). - - For example, - class Bar(torch.autograd.Function): - # A non-differentiable autograd Function whose forward output - # doesn't have grad_fn attribute. - @staticmethod - def forward(ctx, x): - y = torch.ones_like(x) - return y - - @staticmethod - def backward(ctx, dy): - dx = torch.zeros_like(dy) - return dx - - Returns: - ctx: context of the autograd.Function. - tensor: a tensor that owns the context. - - """ - ctx = None - first_tensor_output = None - - def _cond(t): - if not isinstance(t, torch.Tensor) or not hasattr(t, "grad_fn"): - return False - if t.grad_fn is None: - # For the following case, it is possible grad_fn exists, but its value is None, - # so we need to continue to search for the first tensor having a non-None grad_fn. - # - # >>> w = torch.randn(5, 6) - # >>> hasattr(w, "grad_fn") - # True - # >>> w.grad_fn is None - # True - # >>> w, ... = CustomFunc.apply(w) # where CustomFunc forward just return w and other tensors. - # - # Then hasattr(w, "grad_fn") is True, but w.grad_fn is None. - return False - return True - - ts = list(filter(_cond, forward_tensor_outputs)) - if ts: - first_tensor_output = ts[0] - # Use the first context we see because all of arg's share the same one. - ctx = first_tensor_output.grad_fn - return (ctx, first_tensor_output) - - -@nvtx_function_decorator -def _finalize_training_mode_forward( - kernel_invoke_id: str, - func_name: str, - input_tensors_used_for_fw_run: Dict[int, torch.Tensor], - forward_output_tensors: List[Union[torch.Tensor, None]], -): - """Complete the epilogue of forward runner for training mode. - - Args: - kernel_invoke_id: kernel_invoke_id of the PythonOp kernel unique id. - input_tensors_from_ort: input tensors generated from ORT backend. - forward_output_tensors: output tensors of the autograd.Function. - - Things to do: - 1. Try to get context from forward output tensors. - 2. Remove the gradient functions between the current autograd.Function and its input's gradient function, because - in ORT we don't depend on PyTorch's autograd engine. - 3. Register the current autograd.Function's gradient function into our PyNodeSharedPointerPool. - 4. Save kernel-specific information into _GlobalOpKernelInfoMap in the first-time kernel run. - """ - - ctx, tensor_owning_ctx = _get_context(forward_output_tensors) - - kernel_info = _GlobalOpKernelInfoMap[kernel_invoke_id] - - # ctx being None in training mode means the forward function is not differentiable, so backward is not needed. - if ctx is None: - # If this is the first time run, collect kernel-specific information. 
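# A standalone PyTorch check, with an explicit mark_non_differentiable call added for
# clarity, of the situation the context search handles: when every forward output is
# non-differentiable, no output carries a grad_fn, so there is no context to retrieve.
import torch

class Bar(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        y = torch.ones_like(x)
        ctx.mark_non_differentiable(y)  # all outputs are non-differentiable
        return y

    @staticmethod
    def backward(ctx, dy):
        return torch.zeros_like(dy)

x = torch.randn(3, requires_grad=True)
y = Bar.apply(x)
assert y.grad_fn is None  # nothing owns a context here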
- if kernel_info.tensor_input_indices_to_save_in_ctx is None: - kernel_info.tensor_input_indices_to_save_in_ctx = tuple([]) - - if kernel_info.tensor_input_indices_for_mark_dirty is None: - kernel_info.tensor_input_indices_for_mark_dirty = tuple([]) - - return None - - # Filter out the None in the saved_tensors. - saved_tensors = [t for t in ctx.saved_tensors if t is not None] - - ctx.fw_kernel_invoke_id = kernel_invoke_id - - # If this is the first time run, collect kernel-specific information. - if kernel_info.tensor_input_indices_to_save_in_ctx is None: - if len(saved_tensors): - # Check tensors generated by ORT are in the saved_tensors or not. - # If yes, save the input index of the tensor in the _GlobalOpKernelInfoMap. - kernel_info.tensor_input_indices_to_save_in_ctx = tuple([ - tensor_input_index - for tensor_input_index, tensor in input_tensors_used_for_fw_run.items() - if any(tensor is saved_tensor for saved_tensor in saved_tensors) - ]) - _log_warning( - f"{func_name}: Add input index to _GlobalOpKernelInfoMap, to avoid extra copy in every iteration." - ) - else: - kernel_info.tensor_input_indices_to_save_in_ctx = () - kernel_info.materialize_grads = torch_interop_utils.get_materialize_grads(tensor_owning_ctx) - kernel_info.materialize_grads_config = OrderedDict() - if kernel_info.materialize_grads: - for output_index, tensor in enumerate(forward_output_tensors): - if isinstance(tensor, torch.Tensor): - kernel_info.materialize_grads_config[output_index] = ( - tensor.device, - tensor.dtype, - tensor.shape, - ) - - if kernel_info.tensor_input_indices_for_mark_dirty is None: - # Check tensors generated by ORT are marked as dirty(for inplace update) or not. - # If yes, save the input index of the tensor in the _GlobalOpKernelInfoMap. - are_tensors_marked_as_dirty = torch_interop_utils.are_tensors_marked_as_dirty( - tensor_owning_ctx, [t for t in input_tensors_used_for_fw_run.values()] - ) - kernel_info.tensor_input_indices_for_mark_dirty = tuple([ - tensor_input_index - for is_dirty, (tensor_input_index, tensor) in zip( - are_tensors_marked_as_dirty, input_tensors_used_for_fw_run.items() - ) - if is_dirty is True - ]) - _log_warning(f"{func_name}: Add input index to _GlobalOpKernelInfoMap, to support leaf node do inplace update.") - - - torch_nvtx_range_push(f"{func_name}.clear_grad") - # FORWARD BACKWARD FUNCTION CONNECTIONS - # input_1 (leaf, constructed by from_dlpack) <----reference---- AccumulateGrad gradient function - # ↓ ↑ - # autograd.Function apply() ------------> autograd.Function backward() - # ↓ | ↑ - # output_1, output_2 --- shared_ptr --- ↑ - # ↓ previous gradient function - - # We remove the edges starting between current autograd.Function's gradient function and - # it's input's gradient function (e.g. AccumulateGrad gradient function), then - # AccumulateGrad gradient function will be destroyed, releasing the reference to input_1 - # (https://github.com/PyTorch/PyTorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/functions/accumulate_grad.cpp#L21). - # The next edges are stored in Node, with which we can get next gradient function. - # https://github.com/PyTorch/PyTorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/function.h#L527 - torch_interop_utils.clear_grad_fns_for_next_edges(tensor_owning_ctx, saved_tensors) - torch_nvtx_range_pop() - - # This is mainly to hold grad_fn references by registering it into our PyNodeSharedPointerPool. 
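# A hedged sketch of how the (device, dtype, shape) entries recorded above in
# materialize_grads_config can be consumed on the backward side: when
# ctx.set_materialize_grads(True) is in effect and a None gradient arrives for an
# output, a zero tensor with the recorded shape/dtype/device stands in for it.
# The helper name is illustrative only.
import torch

def _materialize_missing_grads(grad_outputs, materialize_grads_config):
    materialized = []
    for output_index, grad in enumerate(grad_outputs):
        if grad is None and output_index in materialize_grads_config:
            device, dtype, shape = materialize_grads_config[output_index]
            grad = torch.zeros(shape, device=device, dtype=dtype)
        materialized.append(grad)
    return materialized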
- torch_nvtx_range_push(f"{func_name}.rg_grad_fn") - torch_interop_utils.register_grad_fn_and_remove_from_autograd(id(ctx), tensor_owning_ctx) - torch_nvtx_range_pop() - - return ctx - - - @nvtx_function_decorator def call_python_forward_function( forward_function: Callable, @@ -476,100 +63,9 @@ def call_python_forward_function( try: func_name = func_name.decode("utf-8") if isinstance(func_name, bytes) else func_name kernel_invoke_id = kernel_invoke_id.decode("utf-8") if isinstance(kernel_invoke_id, bytes) else kernel_invoke_id - wrapped_args = torch_interop_utils.forward_runner(requires_grad_flags, tensor_type_flags, is_training_mode, inplace_map, - kernel_invoke_id, func_name, args) - # # If this is the first time run, collect runtime tensor reuse mapping. - # is_first_time_run = kernel_invoke_id not in _GlobalOpKernelInfoMap - # if is_first_time_run: - # kernel_info = CustomFuncOpKernelInfo(kernel_invoke_id) - # _GlobalOpKernelInfoMap[kernel_invoke_id] = kernel_info - - # kernel_info = _GlobalOpKernelInfoMap[kernel_invoke_id] - - # # tensor_input_indices_to_save_in_ctx = kernel_info.tensor_input_indices_to_save_in_ctx - # # tensor_input_indices_for_mark_dirty = kernel_info.tensor_input_indices_for_mark_dirty - - # if kernel_info.position_to_tensor_index_map is None: - # position_to_tensor_index_map: List[Tuple[int, int]] = [] - # tensor_index = 0 - # for i, flag in enumerate(tensor_type_flags): - # if flag == 1: - # position_to_tensor_index_map.append((i, tensor_index)) - # tensor_index += 1 - # continue - # # position_to_tensor_index_map[i] = -1 - # kernel_info.position_to_tensor_index_map = tuple(position_to_tensor_index_map) - - # position_to_tensor_index_map = kernel_info.position_to_tensor_index_map - - # # Collect the tensor address for all inputs used for run forward, used for reuse detection. - # # If the input is reused, we need to save the raw input tensor for special handling. - # raw_input_tensors_used_inplace = OrderedDict() # Orders matter here. - # input_tensors_used_for_fw_run = OrderedDict() # Orders matter here. - - - - # @nvtx_function_decorator - # def _tensor_handle(pos_and_tensor_index, origin_args): - # input_position, tensor_input_index = pos_and_tensor_index - # arg = origin_args[input_position] - # grad_flag = requires_grad_flags[input_position] - - # # Assume it's a DLPack tensor and convert it to PyTorch tensor. - # wrapped_arg = from_dlpack(arg) - - # if tensor_input_index in inplace_map: - # raw_input_tensors_used_inplace[tensor_input_index] = wrapped_arg - - # # Only requires gradient when running under training mode - # # and the associated tensor has grad_flag=True (i.e., - # # "requires_grad=True" in the original PyTorch script). - # wrapped_arg.requires_grad = is_training_mode and grad_flag - - # # Note1: - # # If it's first-time kernel invocation, tensor_input_indices_to_save_in_ctx is None, we do the - # # copy for all tensors. Otherwise, we only copy the tensors whose indices are in - # # tensor_input_indices_to_save_in_ctx. - # # Note2: - # # For inference mode, we don't need to do the copy because ctx will be None, - # # so nothing will be saved for ctx. - # # Note3: - # # To fix this issue: - # # "a leaf Variable that requires grad has been used in an in-place operation." - # # If it's first-time kernel invocation, tensor_input_indices_for_mark_dirty is None, we do the - # # copy for all tensors to generate grad for it. Otherwise, we only clone (to generate grad) for - # # the tensors whose indices are in tensor_input_indices_for_mark_dirty. 
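# A plain-PyTorch illustration of Note3 above, assuming nothing beyond stock autograd:
# an in-place update on a leaf tensor that requires grad raises
# "a leaf Variable that requires grad has been used in an in-place operation", so the
# tensor is cloned with grad enabled first, which yields a non-leaf that can be updated
# in place and later passed to ctx.mark_dirty.
import torch

leaf = torch.randn(4, requires_grad=True)
with torch.enable_grad():
    non_leaf = leaf.clone()  # the clone has a grad_fn, so it is no longer a leaf
non_leaf.add_(1.0)           # fine on the clone; the same call on `leaf` would raise
assert not non_leaf.is_leaf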
- # if is_training_mode: - # if is_first_time_run: - # with torch.set_grad_enabled(True): - # wrapped_arg = wrapped_arg.clone() - # else: - # is_input_index_saved_in_ctx, is_input_index_marked_dirty = kernel_info.check_with_input_index(tensor_input_index) - # if is_input_index_saved_in_ctx or is_input_index_marked_dirty: - # with torch.set_grad_enabled(is_input_index_marked_dirty): - # wrapped_arg = wrapped_arg.clone() - - # input_tensors_used_for_fw_run[tensor_input_index] = wrapped_arg - # wrapped_args[input_position] = wrapped_arg - - # torch_nvtx_range_push(f"{func_name}.pre") - # # wrapped_args = [] - # # if is_first_time_run or True: - # # a = 0 - # # for i, (arg, requires_grad_flag,) in enumerate(zip(args, requires_grad_flags)): - - # # wrapped_args.append(_tensor_handle(i, arg, requires_grad_flag, position_to_tensor_index_map, - # # raw_input_tensors_used_inplace, input_tensors_used_for_fw_run, - # # is_training_mode, inplace_map, - # # is_first_time_run, - # # tensor_input_indices_to_save_in_ctx, - # # tensor_input_indices_for_mark_dirty, - # # a)) - - # # else: - # wrapped_args = [a for a in args] - # _ = [_tensor_handle(p, args) for p in position_to_tensor_index_map] - # torch_nvtx_range_pop() + wrapped_args = torch_interop_utils.forward_runner(requires_grad_flags, tensor_type_flags, + is_training_mode, inplace_map, + kernel_invoke_id, func_name, args) with torch.set_grad_enabled(is_training_mode): # Run autograd.Function.apply(...). @@ -594,38 +90,7 @@ def call_python_forward_function( torch_nvtx_range_push(f"{func_name}.post") rets = torch_interop_utils.complete_forward_runner(is_training_mode, kernel_invoke_id, func_name, tuple(results)) torch_nvtx_range_pop() - return tuple(rets) - # ctx = None - # if is_training_mode: - # ctx = torch_interop_utils._finalize_training_mode_forward(kernel_invoke_id, func_name, results) - # # print(ctx, type(ctx)) - # if ctx is not None: - # ctx.fw_kernel_invoke_id = kernel_invoke_id - - # final_rets = [ctx] - # final_rets.extend(results) - - # # _process_inplace_outputs( - # # kernel_info, - # # func_name, - # # input_tensors_used_for_fw_run, - # # final_rets, - # # inplace_map, - # # raw_input_tensors_used_inplace, - # # ) - - # dlpacks = [final_rets[0]] - # torch_nvtx_range_push(f"{func_name}.post") - # def _wrap_dlpack(value): - # return to_dlpack(value) if value is not None else None - - # dlpacks.extend(list(map(_wrap_dlpack, final_rets[1:]))) - # torch_nvtx_range_pop() - - # Inside the returned list, the first element is context and the rest - # are DLPack tensors. - # return tuple(dlpacks) except Exception as e: # Flush buffers. Otherwise, calling this from C++ may lose them. print("Exception happens when running ", forward_function) @@ -660,82 +125,18 @@ def call_python_backward_function( it is reusing. If there is no reuse, the value is -1. args: inputs to "backward_function". """ - func_name = func_name.decode("utf-8") if isinstance(func_name, bytes) else func_name - with torch.no_grad(): - - def wrap_all_outputs(result): - if isinstance(result, torch.Tensor): - return [to_dlpack(result)] - elif isinstance(result, (tuple, list)): - return [to_dlpack(value) if value is not None else None for value in result] - else: - raise wrap_exception( - ORTModuleIOError, - TypeError(f"ORTModule does not support the following model output type {type(result)}."), - ) - - try: - # If this is the first time run, collect runtime tensor reuse mapping. 
- # if kernel_invoke_id not in _GlobalOpKernelInfoMap: - # kernel_info = CustomFuncOpKernelInfo(kernel_invoke_id) - # _GlobalOpKernelInfoMap[kernel_invoke_id] = kernel_info - - # kernel_info = _GlobalOpKernelInfoMap[kernel_invoke_id] - - # Backward inputs should not require gradients. - assert all(grad_flag == 0 for grad_flag in requires_grad_flags) - - # Prepare inputs for calling Python function. - ctx = args[0] - - # fw_kernel_invoke_id = ctx.fw_kernel_invoke_id - wrapped_args = [] - - # Collect the tensor address for all inputs used for run backward, used for reuse detection. - tensor_input_index = 1 # skip the context input - # If input is reused, we need to save the raw input tensor for special handling. - raw_input_tensors_used_inplace = OrderedDict() # Orders matter here. - input_tensors_used_for_bw_run = OrderedDict() # Orders matter here. - for grad_input_index, (grad_flag, tensor_flag, arg) in enumerate( - zip(requires_grad_flags, tensor_type_flags, args) - ): - # If an input is a tensor, it is possible we get a None also when it is optional as grad input. - if tensor_flag: - if arg is None: - # if _GlobalOpKernelInfoMap[fw_kernel_invoke_id].materialize_grads: - # config = _GlobalOpKernelInfoMap[fw_kernel_invoke_id].materialize_grads_config - # # ignore the first input, which is the ctx. - # device, dtype, shape = config[grad_input_index - 1] - # wrapped_arg = torch.zeros(shape, device=device, dtype=dtype) - # else: - # wrapped_arg = arg - wrapped_arg = arg - - if grad_input_index in inplace_map: - raw_input_tensors_used_inplace[tensor_input_index] = arg - else: - # Assume it's a DLPack tensor# and convert it to PyTorch tensor. - wrapped_arg = from_dlpack(arg) - - if grad_input_index in inplace_map: - raw_input_tensors_used_inplace[tensor_input_index] = wrapped_arg - - # This may include None values. - input_tensors_used_for_bw_run[tensor_input_index] = wrapped_arg + try: - if wrapped_arg is not None: - # Only requires gradient when running under training mode - # and the associated tensor has grad_flag=True (i.e., - # "requires_grad=True" in the original PyTorch script). - wrapped_arg.requires_grad = is_training_mode and grad_flag + func_name = func_name.decode("utf-8") if isinstance(func_name, bytes) else func_name + kernel_invoke_id = kernel_invoke_id.decode("utf-8") if isinstance(kernel_invoke_id, bytes) else kernel_invoke_id - wrapped_args.append(wrapped_arg) - tensor_input_index += 1 - else: - # Use non-tensor as is. It's a PyObject*. - wrapped_args.append(arg) + wrapped_args = torch_interop_utils.backward_runner(requires_grad_flags, tensor_type_flags, + is_training_mode, inplace_map, + kernel_invoke_id, func_name, args) + ctx = args[0] + with torch.no_grad(): # Call Python function. torch_nvtx_range_push(f"{func_name}.bw") result = backward_function(*wrapped_args) @@ -761,15 +162,22 @@ def wrap_all_outputs(result): # raw_input_tensors_used_inplace, # is_backward=True, # ) - + def wrap_all_outputs(result): + if isinstance(result, torch.Tensor): + return [to_dlpack(result)] + elif isinstance(result, (tuple, list)): + return [to_dlpack(value) if value is not None else None for value in result] + else: + raise wrap_exception( + ORTModuleIOError, + TypeError(f"ORTModule does not support the following model output type {type(result)}."), + ) wrapped_returned_args = wrap_all_outputs(result) - torch_interop_utils.unregister_grad_fn(ctx) - return tuple(wrapped_returned_args) - except Exception as e: - # Flush buffers. Otherwise, calling this from C++ may lose them. 
- print("Exception happens when running ", backward_function) - sys.stdout.flush() - sys.stderr.flush() - raise wrap_exception(ORTModuleFallbackException, e) # noqa: B904 + except Exception as e: + # Flush buffers. Otherwise, calling this from C++ may lose them. + print("Exception happens when running ", backward_function) + sys.stdout.flush() + sys.stderr.flush() + raise wrap_exception(ORTModuleFallbackException, e) # noqa: B904 diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/torch_interop_utils.cc b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/torch_interop_utils.cc index 5c42426392270..e38e945c23481 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/torch_interop_utils.cc +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/torch_interop_utils.cc @@ -40,7 +40,7 @@ class PyNodeSharedPointerPool { static PyNodeSharedPointerPool& GetInstance() { static PyNodeSharedPointerPool pool; return pool; - }; + } void RegisterGradFuncAndRemoveFromAutoGrad(const size_t& ctx_address, torch::autograd::AutogradMeta* autograd_meta) { @@ -52,14 +52,14 @@ class PyNodeSharedPointerPool { grad_fns_.emplace(ctx_address, std::move(autograd_meta->grad_fn_)); TORCH_CHECK(autograd_meta->grad_fn_ == nullptr, "fail to remove grad_fn_ from torch autograd for ctx ", ctx_address); - }; + } void UnRegisterGradFunc(const size_t& ctx_address) { auto it = grad_fns_.find(ctx_address); TORCH_CHECK(it != grad_fns_.end(), "fail to find grad_fn for ctx ", ctx_address); grad_fns_.erase(ctx_address); - }; + } void ClearAll() { grad_fns_.clear(); @@ -253,8 +253,8 @@ class CustomFuncOpKernelInfo { std::string kernel_invoke_id; std::unordered_map input_global_index_to_tensor_index_map; std::optional> tensor_input_indices_to_save_in_ctx; - bool materialize_grads; - // std::unordered_map> materialize_grads_config; + bool materialize_grads{true}; + std::unordered_map, c10::TensorOptions>> materialize_grads_config; std::optional> tensor_input_indices_for_mark_dirty; std::vector output_indices_for_clone; @@ -265,7 +265,6 @@ class CustomFuncOpKernelInfo { std::unordered_map _GlobalOpKernelInfoMap; py::list forward_runner( - // std::function forward_function, const std::vector& requires_grad_flags, const std::vector& tensor_type_flags, bool is_training_mode, @@ -275,7 +274,6 @@ py::list forward_runner( py::tuple args) { py::gil_scoped_release release; - // auto t0 = std::chrono::high_resolution_clock::now(); auto it = _GlobalOpKernelInfoMap.find(kernel_invoke_id); if (it == _GlobalOpKernelInfoMap.end()) { bool safe_run = false; @@ -283,16 +281,14 @@ py::list forward_runner( } CustomFuncOpKernelInfo& kernel_info = _GlobalOpKernelInfoMap.at(kernel_invoke_id); + // std::unordered_map raw_input_tensors_used_inplace; // std::unordered_map input_tensors_used_for_fw_run; + int tensor_input_index = 0; - std::vector> wrapped_args; + std::vector wrapped_args; wrapped_args.reserve(args.size()); - { - // auto t1 = std::chrono::high_resolution_clock::now(); - // std::chrono::duration fs = t1 - t0; - // std::cout << "ckpt 1 latency(ms): " << fs.count() * 1000 << ", kernel_info.is_first_run: " << kernel_info.is_first_run << std::endl; - } + for (size_t arg_index = 0; arg_index < args.size(); ++arg_index) { bool is_tensor = tensor_type_flags[arg_index] == 1; bool requires_grad = requires_grad_flags[arg_index] && is_training_mode; @@ -303,28 
+299,26 @@ py::list forward_runner( at::Tensor tensor; { - // auto t0 = std::chrono::high_resolution_clock::now(); pybind11::gil_scoped_acquire gil; // Assume it's a DLPack tensor and convert it to PyTorch tensor. TORCH_CHECK(PyCapsule_IsValid(args[arg_index].ptr(), "dltensor") != 0, "found invalid pycapsule"); tensor = torch::utils::tensor_fromDLPack(args[arg_index].ptr()); - // auto t1 = std::chrono::high_resolution_clock::now(); - // std::chrono::duration fs = t1 - t0; - // std::cout << "dlpack latency(ms): " << fs.count() * 1000 << ", kernel_info.is_first_run: " << kernel_info.is_first_run << std::endl; } - // bool is_input_used_inplace = std::find(inplace_map.begin(), inplace_map.end(), tensor_input_index) != inplace_map.end(); - // if (is_input_used_inplace) { - // // raw_input_tensors_used_inplace[tensor_input_index] = tensor; - // } - tensor.requires_grad_(requires_grad); if (is_training_mode && kernel_info.safe_run_enabled) { + // bool is_input_used_inplace = std::find(inplace_map.begin(), inplace_map.end(), tensor_input_index) != inplace_map.end(); + // if (is_input_used_inplace) { + // raw_input_tensors_used_inplace[tensor_input_index] = tensor; + // } + if (kernel_info.is_first_run) { at::AutoGradMode enable_grad(true); auto wrapped_arg = tensor.clone(); - wrapped_args.push_back(wrapped_arg); + wrapped_args.push_back(py::reinterpret_steal(THPVariable_Wrap(wrapped_arg))); + + // input_tensors_used_for_fw_run[tensor_input_index] = tensor; } else { bool is_input_index_saved_in_ctx = kernel_info.tensor_input_indices_to_save_in_ctx.value().find(tensor_input_index) != kernel_info.tensor_input_indices_to_save_in_ctx.value().end(); @@ -349,34 +343,22 @@ py::list forward_runner( // input_tensors_used_for_fw_run[tensor_input_index] = wrapped_arg // wrapped_args[input_position] = wrapped_arg wrapped_arg.requires_grad_(requires_grad); - wrapped_args.push_back(wrapped_arg); + wrapped_args.push_back(py::reinterpret_steal(THPVariable_Wrap(wrapped_arg))); + + // input_tensors_used_for_fw_run[tensor_input_index] = wrapped_arg; } else { - wrapped_args.push_back(tensor); + wrapped_args.push_back(py::reinterpret_steal(THPVariable_Wrap(tensor))); + // input_tensors_used_for_fw_run[tensor_input_index] = tensor; } } } else { - wrapped_args.push_back(tensor); + wrapped_args.push_back(py::reinterpret_steal(THPVariable_Wrap(tensor))); } // input_tensors_used_for_fw_run[tensor_input_index] = wrapped_args.back(); - tensor_input_index++; - { - // auto t1 = std::chrono::high_resolution_clock::now(); - // std::chrono::duration fs = t1 - t0; - // // std::chrono::milliseconds d = std::chrono::duration_cast(fs); - // std::cout << "ckpt 2 latency(ms): " << fs.count() * 1000 << ", kernel_info.is_first_run: " << kernel_info.is_first_run << std::endl; - } - } - { - // auto t1 = std::chrono::high_resolution_clock::now(); - // std::chrono::duration fs = t1 - t0; - // // std::chrono::milliseconds d = std::chrono::duration_cast(fs); - // std::cout << "runner e2e latency(ms): " << fs.count() * 1000 << ", kernel_info.is_first_run: " << kernel_info.is_first_run << std::endl; + // tensor_input_index++; } - if (kernel_info.is_first_run) { - kernel_info.is_first_run = false; - } return py::cast(wrapped_args); } @@ -437,11 +419,11 @@ py::object _finalize_training_mode_forward( if (tensor_owning_ctx.has_value()) { ret = py::reinterpret_steal(torch::autograd::functionToPyObject(tensor_owning_ctx.value().grad_fn())); } else { + // ctx being None in training mode means the forward function is not differentiable, so backward is 
not needed. ret = py::none(); } } - // #ctx being None in training mode means the forward function is not differentiable, so backward is not needed. if (!tensor_owning_ctx.has_value()) { // #If this is the first time run, collect kernel - specific information. if (kernel_info.is_first_run && kernel_info.safe_run_enabled) { @@ -456,6 +438,46 @@ py::object _finalize_training_mode_forward( return ret; } + if (kernel_info.is_first_run) { + py::gil_scoped_release release; + torch::autograd::AutogradMeta* autograd_meta = torch::autograd::impl::get_autograd_meta(tensor_owning_ctx.value()); + const auto& grad_fn = autograd_meta->grad_fn_; + auto py_node_fn = dynamic_cast(grad_fn.get()); + TORCH_CHECK(py_node_fn != nullptr, "grad_fn is not PyNode type."); + THPFunction* py_fn = (THPFunction*)py_node_fn->obj; + kernel_info.materialize_grads = py_fn->materialize_grads; + + // kernel_info.materialize_grads = get_materialize_grads(tensor_owning_ctx.value()); + if (kernel_info.materialize_grads) { + for (size_t i = 0; i < forward_output_tensors.size(); ++i) { + PyObject* obj = forward_output_tensors[i].ptr(); + if (!THPVariable_Check(obj)) { + continue; + } + at::Tensor t = THPVariable_Unpack(obj); + kernel_info.materialize_grads_config.insert({i, {t.sizes().vec(), t.options()}}); + } + } + + // Py_ssize_t num_saved_for_forward = + // PyTuple_GET_SIZE(py_fn->saved_for_forward); + // std::vector saved_tensors; + // saved_tensors.reserve(num_saved_for_forward); + // for (const auto i : c10::irange(num_saved_for_forward)) { + // PyObject* obj = PyTuple_GET_ITEM(py_fn->saved_for_forward, i); + // if (THPVariable_Check(obj)) { + // const auto& tensor = THPVariable_Unpack(obj); + // saved_tensors.push_back(tensor); + // } + // } + + // kernel_info.tensor_input_indices_to_save_in_ctx = tuple([ + // tensor_input_index + // for tensor_input_index, tensor in input_tensors_used_for_fw_run.items() + // if any(tensor is saved_tensor for saved_tensor in saved_tensors) + // ]) + } + // auto py_node_fn = dynamic_cast(tensor_owning_ctx.value().grad_fn().get()); // TORCH_CHECK(py_node_fn != nullptr, "grad_fn is not PyNode type."); // THPFunction* py_fn = (THPFunction*)py_node_fn->obj; @@ -727,8 +749,6 @@ py::list complete_forward_runner( py::object ctx; if (is_training_mode) { ctx = _finalize_training_mode_forward(kernel_invoke_id, func_name, forward_output_tensors); - // if ctx is not None: - // ctx.fw_kernel_invoke_id = kernel_invoke_id if (!ctx.is_none()) { PyObject_SetAttrString(ctx.ptr(), "fw_kernel_invoke_id", py::cast(kernel_invoke_id).ptr()); } @@ -767,10 +787,127 @@ py::list complete_forward_runner( // dlpacks.extend(list(map(_wrap_dlpack, final_rets[1:]))) // torch_nvtx_range_pop() + CustomFuncOpKernelInfo& kernel_info = _GlobalOpKernelInfoMap.at(kernel_invoke_id); + if (kernel_info.is_first_run) { + kernel_info.is_first_run = false; + } return py::cast(rets); } +py::list backward_runner( + // std::function forward_function, + const std::vector& requires_grad_flags, + const std::vector& tensor_type_flags, + bool is_training_mode, + const std::vector& inplace_map, + const std::string& kernel_invoke_id, + const std::string& func_name, + py::tuple args) { + py::gil_scoped_release release; + at::AutoGradMode enable_grad(false); + + // auto t0 = std::chrono::high_resolution_clock::now(); + auto it = _GlobalOpKernelInfoMap.find(kernel_invoke_id); + if (it == _GlobalOpKernelInfoMap.end()) { + bool safe_run = false; + _GlobalOpKernelInfoMap.emplace(kernel_invoke_id, CustomFuncOpKernelInfo(kernel_invoke_id, safe_run)); 
+ } + + CustomFuncOpKernelInfo& kernel_info = _GlobalOpKernelInfoMap.at(kernel_invoke_id); + + // std::unordered_map raw_input_tensors_used_inplace; + // std::unordered_map input_tensors_used_for_fw_run; + + // int tensor_input_index = 0; + std::vector wrapped_args; + wrapped_args.reserve(args.size()); + py::object ctx = args[0]; + pybind11::gil_scoped_acquire gil; + wrapped_args.push_back(ctx); + for (size_t arg_index = 1; arg_index < args.size(); ++arg_index) { + if (tensor_type_flags[arg_index] != 1) { + wrapped_args.push_back(args[arg_index]); + continue; + } + + at::Tensor tensor; + // Assume it's a DLPack tensor and convert it to PyTorch tensor. + bool is_dlpack = PyCapsule_IsValid(args[arg_index].ptr(), "dltensor") != 0; + if (is_dlpack) { + tensor = torch::utils::tensor_fromDLPack(args[arg_index].ptr()); + } else { + TORCH_CHECK(args[arg_index].is_none(), "Only None is supported for non-tensor input."); + PyObject* fw_kernel_invoke_id = PyObject_GetAttrString(ctx.ptr(), "fw_kernel_invoke_id"); + std::string fw_kernel_invoke_id_str = py::cast(py::reinterpret_borrow(fw_kernel_invoke_id)); + CustomFuncOpKernelInfo& fw_kernel_info = _GlobalOpKernelInfoMap.at(fw_kernel_invoke_id_str); + if (fw_kernel_info.materialize_grads) { + auto& config = fw_kernel_info.materialize_grads_config.at(arg_index - 1); + tensor = at::zeros(std::get<0>(config), std::get<1>(config)); // shift by 1 to skip context input. + } + } + + if (kernel_info.safe_run_enabled) { + // bool is_input_used_inplace = std::find(inplace_map.begin(), inplace_map.end(), tensor_input_index) != inplace_map.end(); + // if (is_input_used_inplace) { + // raw_input_tensors_used_inplace[tensor_input_index] = tensor; + // } + + // if (kernel_info.is_first_run) { + // at::AutoGradMode enable_grad(true); + // auto wrapped_arg = tensor.clone(); + // wrapped_args.push_back(py::reinterpret_steal(THPVariable_Wrap(wrapped_arg))); + // } else { + // bool is_input_index_saved_in_ctx = kernel_info.tensor_input_indices_to_save_in_ctx.value().find(tensor_input_index) != + // kernel_info.tensor_input_indices_to_save_in_ctx.value().end(); + // // std::find(kernel_info.tensor_input_indices_to_save_in_ctx.value().begin(), + // // kernel_info.tensor_input_indices_to_save_in_ctx.value().end(), + // // tensor_input_index) != + // // kernel_info.tensor_input_indices_to_save_in_ctx.value().end(); + + // bool is_input_index_marked_dirty = kernel_info.tensor_input_indices_for_mark_dirty.value().find(tensor_input_index) != + // kernel_info.tensor_input_indices_for_mark_dirty.value().end(); + // // std::find(kernel_info.tensor_input_indices_for_mark_dirty.value().begin(), + // // kernel_info.tensor_input_indices_for_mark_dirty.value().end(), + // // tensor_input_index) != + // // kernel_info.tensor_input_indices_for_mark_dirty.value().end(); + + // if (is_input_index_saved_in_ctx || is_input_index_marked_dirty) { + // at::AutoGradMode enable_grad(is_input_index_marked_dirty); + // auto wrapped_arg = tensor.clone(); + // // with torch.set_grad_enabled(is_input_index_marked_dirty): + // // wrapped_arg = wrapped_arg.clone() + + // // input_tensors_used_for_fw_run[tensor_input_index] = wrapped_arg + // // wrapped_args[input_position] = wrapped_arg + // wrapped_arg.requires_grad_(requires_grad); + // wrapped_args.push_back(py::reinterpret_steal(THPVariable_Wrap(wrapped_arg))); + // } else { + // wrapped_args.push_back(py::reinterpret_steal(THPVariable_Wrap(tensor))); + // } + // } + // input_tensors_used_for_fw_run[tensor_input_index] = tensor; + + } else { + if 
(tensor.defined()) { + wrapped_args.push_back(py::reinterpret_steal(THPVariable_Wrap(tensor))); + } else { + wrapped_args.push_back(py::none()); + } + } + + // input_tensors_used_for_fw_run[tensor_input_index] = wrapped_args.back(); + // tensor_input_index++; + } + + if (kernel_info.is_first_run) { + kernel_info.is_first_run = false; + } + return py::cast(wrapped_args); +} + +size_t get_custom_function_forward_runner() { return reinterpret_cast(&forward_runner); } + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("register_grad_fn_and_remove_from_autograd", ®ister_grad_fn_and_remove_from_autograd, "Increase grad_fn shared pointer reference."); @@ -783,4 +920,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("forward_runner", &forward_runner, "Forward runner."); m.def("_finalize_training_mode_forward", &_finalize_training_mode_forward, "Finalize training mode forward."); m.def("complete_forward_runner", &complete_forward_runner, "Complete forward runner."); + m.def("backward_runner", &backward_runner, "Backward runner."); + m.def("get_custom_function_forward_runner", &get_custom_function_forward_runner, "Get custom function forward runner."); } diff --git a/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py b/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py index fd8691a54fafb..5e79334b3c402 100644 --- a/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py +++ b/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py @@ -284,13 +284,13 @@ def forward( # # updated_kwargs_tensors, _, _, _ = extract_data_and_schema(updated_kwargs) # updated_kwargs_tensors = extract_data_with_access_func(updated_kwargs, kwargs_data_access_func) - rets = tuple( tensor_list[: args_tensor_count + kwargs_tensor_count]) + rets = tuple(tensor_list[: args_tensor_count + kwargs_tensor_count]) - def _do(p): - return p.detach().requires_grad_(p.requires_grad) - - rets += tuple(map(_do, partitioned_params)) + # def _do(p): + # return p.detach().requires_grad_(p.requires_grad) + # rets += tuple(map(_do, partitioned_params)) + rets += tuple([p.detach().requires_grad_(p.requires_grad) for p in partitioned_params]) # PyTorch exporter does not support an empty list of tensors, so we have this check. assert len(rets) != 0 @@ -425,9 +425,9 @@ def forward( ctx.module = module ctx.pre_backward_function = pre_backward_function - + rets = [o.detach().requires_grad_(o.requires_grad) for o in updated_output_tensors] torch_nvtx_range_pop() - return tuple(updated_output_tensors) + return tuple(rets) @staticmethod def backward(ctx, *grads):
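# A short sketch of the detach-and-restore pattern the last two hunks switch to,
# assuming only stock PyTorch: detach() breaks the tensor out of the local autograd
# graph, while requires_grad_(...) preserves the original requires_grad flag so the
# gradient path can be wired up outside of PyTorch's engine.
import torch

def _detach_keep_flag(tensors):
    return tuple(t.detach().requires_grad_(t.requires_grad) for t in tensors)

outputs = (torch.randn(2, 2, requires_grad=True) * 2, torch.ones(3))
detached = _detach_keep_flag(outputs)
assert detached[0].requires_grad and detached[0].grad_fn is None
assert not detached[1].requires_grad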