diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py index 845c7d83c2e7b..a5b96c4e37140 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py @@ -376,6 +376,16 @@ def wrap_all_outputs(result): result = backward_function(*wrapped_args) # Extract results as DLPack tensor list. + if isinstance(result, torch.Tensor): + result = [result] + elif isinstance(result, (tuple, list)): + result = list(result) + else: + raise wrap_exception( + ORTModuleIOError, + TypeError(f"ORTModule does not support the following model output type {type(result)}."), + ) + wrapped_returned_args = wrap_all_outputs(result) torch_interop_utils.unregister_grad_fn(id(ctx))