Skip to content

Commit

Permalink
Mark nn module tensor static for cudagraphs (#132736)
Browse files Browse the repository at this point in the history
Summary:
Fixes pytorch/pytorch#132714

X-link: pytorch/pytorch#132736
Approved by: https://github.com/mlazos
ghstack dependencies: #132538

Reviewed By: PaliC

Differential Revision: D60857243

Pulled By: anijain2305

fbshipit-source-id: a2c44a5772546340d1cd564c00f6b9d271620aff
  • Loading branch information
anijain2305 authored and facebook-github-bot committed Aug 7, 2024
1 parent a9b8a9a commit 0da28c0
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion userbenchmark/dynamo/dynamobench/_dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2144,7 +2144,9 @@ def tensor_always_has_static_shape(

if (
tensor_source.guard_source().is_specialized_nn_module()
or tensor_source.guard_source().is_unspecialized_builtin_nn_module()
# Marking the tensor attributes of nn modules static to keep the behavior same as before
# inline_inbuilt_nn_module flag was introduced.
or tensor_source.guard_source().is_unspecialized_nn_module()
) and config.force_nn_module_property_static_shapes:
return True, TensorStaticReason.NN_MODULE_PROPERTY

Expand Down

0 comments on commit 0da28c0

Please sign in to comment.