Skip to content

Commit

Permalink
Store CompilationEvents in a buffer in torch._dynamo.utils (#115788)
Browse files Browse the repository at this point in the history
Summary:
Motivation: it would be nice to be able to test using the metrics in log_compilation_event; currently dumps logs (or logs to a database in fbcode) - these are hard to use in unit tests.

This change:
* always record the information in torch._dynamo.utils.record_compilation_metrics; here, log into a limited-size deque to prevent the list of metrics from getting too long
* if config.log_compilation_metrics, then call back into the original log_compilation_event function

X-link: pytorch/pytorch#115788
Approved by: https://github.com/yanboliang

Reviewed By: jeanschmidt

Differential Revision: D52298053

Pulled By: davidberard98

fbshipit-source-id: ef291255d6148c0479f3000b4fb21a4ed72cadcb
  • Loading branch information
davidberard98 authored and facebook-github-bot committed Dec 20, 2023
1 parent 539de79 commit 36daad7
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions userbenchmark/dynamo/dynamobench/_dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
ClassVar,
Counter,
DefaultDict,
Deque,
Dict,
Iterator,
List,
Expand Down Expand Up @@ -89,6 +90,7 @@
import torch.fx.experimental.symbolic_shapes
from torch import fx
from torch._dispatch.python import enable_python_dispatcher
from torch._utils_internal import log_compilation_event

from torch.nn.modules.lazy import LazyModuleMixin
from torch.utils._pytree import tree_map_only
Expand Down Expand Up @@ -600,6 +602,38 @@ class CompilationMetrics:
compliant_custom_ops: Set[str]


DEFAULT_COMPILATION_METRICS_LIMIT = 64


_compilation_metrics: Deque[CompilationMetrics] = collections.deque(
maxlen=DEFAULT_COMPILATION_METRICS_LIMIT
)


def record_compilation_metrics(compilation_metrics: CompilationMetrics):
global _compilation_metrics
_compilation_metrics.append(compilation_metrics)
if config.log_compilation_metrics:
log_compilation_event(compilation_metrics)


def set_compilation_metrics_limit(new_size: int) -> None:
global _compilation_metrics
while len(_compilation_metrics) > new_size:
_compilation_metrics.popleft()
new_deque = collections.deque(_compilation_metrics, maxlen=new_size)
_compilation_metrics = new_deque


def clear_compilation_metrics() -> None:
global _compilation_metrics
_compilation_metrics.clear()


def get_compilation_metrics() -> List[CompilationMetrics]:
return list(_compilation_metrics)


@dataclasses.dataclass
class CleanupHook:
"""Remove a global variable when hook is called"""
Expand Down

0 comments on commit 36daad7

Please sign in to comment.