Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
pengwa committed Sep 18, 2023
1 parent a4a8558 commit f7941b2
Showing 1 changed file with 10 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,16 @@ def wrap_all_outputs(result):
result = backward_function(*wrapped_args)

# Extract results as DLPack tensor list.
if isinstance(result, torch.Tensor):
result = [result]
elif isinstance(result, (tuple, list)):
result = list(result)
else:
raise wrap_exception(
ORTModuleIOError,
TypeError(f"ORTModule does not support the following model output type {type(result)}."),
)

wrapped_returned_args = wrap_all_outputs(result)

torch_interop_utils.unregister_grad_fn(id(ctx))
Expand Down

0 comments on commit f7941b2

Please sign in to comment.