Skip to content

Commit

Permalink
Mark attri…
Browse files Browse the repository at this point in the history
Summary:
Shuai wants to test this internally before pytorch/pytorch#133713 can go in. Creating a separate PR for ghimport.

cc rec

X-link: pytorch/pytorch#134136

Reviewed By: yanboliang

Differential Revision: D61612768

Pulled By: anijain2305

fbshipit-source-id: 9681ee2215967730446b3abe57fbf0f5ed209968
  • Loading branch information
anijain2305 authored and facebook-github-bot committed Aug 22, 2024
1 parent b38b6f3 commit 3cd3dd5
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions userbenchmark/dynamo/dynamobench/_dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
_push_on_torch_function_stack,
)
from torch._dispatch.python import enable_python_dispatcher
from torch._guards import TracingContext
from torch._guards import Source, TracingContext
from torch._subclasses.meta_utils import is_sparse_compressed
from torch._utils_internal import log_chromium_event_internal, log_compilation_event
from torch.fx._utils import _format_graph_code, lazy_format_graph_code
Expand Down Expand Up @@ -2344,7 +2344,7 @@ def tensor_static_reason_to_message(reason: TensorStaticReason):
def tensor_always_has_static_shape(
tensor: Union[torch.Tensor, Any],
is_tensor: bool,
guard_source: torch._guards.GuardSource,
tensor_source: Source,
) -> Tuple[bool, Optional[TensorStaticReason]]:
"""
Given a tensor, source, and is_tensor flag, determine if a shape should be static.
Expand All @@ -2357,12 +2357,20 @@ def tensor_always_has_static_shape(
Returns a tuple, where the first element is the bool of whether or not this tensor should have a static shape.
The second element is a TensorStaticReason, useful for passing to tensor_static_reason_to_message if needed.
"""
from .source import is_from_unspecialized_param_buffer_source

if (
guard_source.is_specialized_nn_module()
and config.force_nn_module_property_static_shapes
):
tensor_source.guard_source().is_specialized_nn_module()
# Marking the tensor attributes of nn modules static to keep the behavior same as before
# inline_inbuilt_nn_module flag was introduced.
or tensor_source.guard_source().is_unspecialized_nn_module()
) and config.force_nn_module_property_static_shapes:
return True, TensorStaticReason.NN_MODULE_PROPERTY
if type(tensor) is torch.nn.Parameter and config.force_parameter_static_shapes:

if (
type(tensor) is torch.nn.Parameter
or is_from_unspecialized_param_buffer_source(tensor_source)
) and config.force_parameter_static_shapes:
return True, TensorStaticReason.PARAMETER
if not is_tensor:
return True, TensorStaticReason.NOT_TENSOR
Expand Down

0 comments on commit 3cd3dd5

Please sign in to comment.