Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
pengwa committed Oct 19, 2023
1 parent 666a761 commit 00f4598
Showing 1 changed file with 2 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -508,8 +508,10 @@ def call_python_forward_function(
or tensor_input_index in tensor_input_indices_for_mark_dirty
)
if is_input_index_saved_in_ctx or is_input_index_marked_dirty:
# when with grad, the leaf tensor after clone will not be leaf.
with torch.set_grad_enabled(is_input_index_marked_dirty):
wrapped_arg = wrapped_arg.clone()
wrapped_arg.requires_grad = is_training_mode and grad_flag

wrapped_args.append(wrapped_arg)
input_tensors_used_for_fw_run[tensor_input_index] = wrapped_arg
Expand Down

0 comments on commit 00f4598

Please sign in to comment.