diff --git a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py index 6b4ec04ec..7caf8a350 100644 --- a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py +++ b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py @@ -26,6 +26,7 @@ import time import types import typing +import uuid import warnings import weakref from contextlib import contextmanager @@ -64,7 +65,7 @@ from torch._dispatch.python import enable_python_dispatcher from torch._guards import TracingContext from torch._subclasses.meta_utils import is_sparse_compressed -from torch._utils_internal import log_compilation_event +from torch._utils_internal import log_chromium_event_internal, log_compilation_event from torch.fx._utils import _format_graph_code, lazy_format_graph_code from torch.nn.modules.lazy import LazyModuleMixin from torch.utils._triton import has_triton, has_triton_package @@ -212,6 +213,16 @@ def _add_time_spent(key: str, phase_name: str, time_spent: float) -> None: frame_phase_timing[key][phase_name] += time_spent +def get_cache_stats() -> Dict[str, Any]: + """Get a bunch of metadata about cache hits and misses to use in chromium events""" + cache_stats = { + "fxgraph_cache_hit": counters["inductor"]["fxgraph_cache_hit"], + "fxgraph_cache_miss": counters["inductor"]["fxgraph_cache_miss"], + "fxgraph_cache_bypass": counters["inductor"]["fxgraph_cache_bypass"], + } + return cache_stats + + # dynamo_timed is a context manager # By wrapping a function in dynamo_timed, we can store a record in compilation_time_metrics # where the key is the functions name. @@ -245,22 +256,34 @@ def dynamo_timed( phase_name: Optional[str] = None, fwd_only: bool = True, ): + chromium_log: ChromiumEventLogger = get_chromium_event_logger() if key not in compilation_time_metrics: compilation_time_metrics[key] = [] fail_type: Optional[str] = None fail_reason: Optional[str] = None time_spent = float("-inf") + if phase_name == "entire_frame_compile": + chromium_log.reset() try: with torch.profiler.record_function(f"{key} (dynamo_timed)"): t0 = time.time() - ChromiumEventLogger.log_event_start(key, time.time_ns()) + start = time.time_ns() + chromium_log.log_event_start(key, start, None) if phase_name: - ChromiumEventLogger.log_event_start(phase_name, time.time_ns()) + chromium_log.log_event_start(phase_name, start) yield + if phase_name: - ChromiumEventLogger.log_event_end(phase_name, time.time_ns()) - ChromiumEventLogger.log_event_end(key, time.time_ns()) + chromium_log.log_event_end( + phase_name, + time.time_ns(), + {"cache_stats": get_cache_stats()}, + start, + ) + chromium_log.log_event_end( + key, time.time_ns(), {"cache_stats": get_cache_stats()}, start + ) time_spent = time.time() - t0 compilation_time_metrics[key].append(time_spent) except Exception as e: @@ -814,8 +837,17 @@ class ChromiumEventLogger: a specification of the Chromium Event JSON format. """ - @staticmethod + def __init__(self): + self.stack = ["__start__"] + # Generate a unique id for this logger, which we can use in scuba to filter down + # to a single python run. + self.id_ = str(uuid.uuid4()) + + # TODO: log to init/id tlparse after I add support for it + log.info("ChromiumEventLogger initialized with id %s", self.id_) + def log_event_start( + self, event_name: str, time_ns: int, metadata: Optional[Dict[str, Any]] = None, @@ -826,18 +858,24 @@ def log_event_start( :param time_ns Timestamp in nanoseconds :param metadata: Any extra metadata associated with this event """ - ChromiumEventLogger._log_timed_event( + event = self._log_timed_event( event_name, time_ns, "B", metadata, ) + log_chromium_event_internal(event, self.stack, self.id_) + self.stack.append(event_name) + + def reset(self) -> None: + self.stack = ["__start__"] - @staticmethod def log_event_end( + self, event_name: str, time_ns: int, metadata: Optional[Dict[str, Any]] = None, + start_time_ns: Optional[int] = None, ) -> None: """ Logs the end of a single event. This function should only be @@ -846,28 +884,53 @@ def log_event_end( :param time_ns: Timestamp in nanoseconds :param metadata: Any extra metadata associated with this event """ - ChromiumEventLogger._log_timed_event( + # These stack health checks currently never happen, + # but they're written this way to future proof any weird event + # overlaps in the future. + if event_name not in self.stack: + # Something went wrong, we never called start on this event, + # or it was skipped due to overlapping events below + log.warning("ChromiumEventLogger: Start event not in stack, ignoring") + return + + event = self._log_timed_event( event_name, time_ns, "E", metadata, ) - @staticmethod + while event_name != self.stack[-1]: + # If the event isn't the most recent one to end, pop + # off the stack until it is. + # Since event_name in self.stack, this pop is always safe + log.warning( + "ChromiumEventLogger: Detected overlapping events, fixing stack" + ) + self.stack.pop() + + log_chromium_event_internal(event, self.stack, self.id_, start_time_ns) + # Finally pop the actual event off the stack + self.stack.pop() + def _log_timed_event( + self, event_name: str, time_ns: int, phase: str, metadata: Optional[Dict[str, Any]] = None, - ) -> None: + ) -> Dict[str, Any]: """ Logs a timed event in chromium format. See log_event_start, log_event_end, etc. """ event = { "name": event_name, - "ts": time_ns / 1000, # Chromium events are in ms + "ts": time_ns / 1000, # Chromium events are in micro seconds "args": metadata, "ph": phase, + # These categories are needed in all chromium traces + "cat": "dynamo_timed", + "tid": 0, "pid": 0, # pid should be specified on all logs, we don't personally care about the actual process id } torch._logging.trace_structured( @@ -876,9 +939,10 @@ def _log_timed_event( suppress_context=False, expect_trace_id=False, # Not every chromium event will have a trace_id ) + return event - @staticmethod def log_instant_event( + self, event_name: str, time_ns: int, metadata: Optional[Dict[str, Any]] = None, @@ -895,7 +959,10 @@ def log_instant_event( "ts": time_ns / 1000, "args": metadata, "ph": "i", - "pid": 0, # pid should be specified on all logs, we don't personally care about the actual process id + # These categories are needed in all chromium traces + "cat": "dynamo_timed", + "tid": 0, + "pid": 0, "s": "p", # We use "process" level instant events so they all appear on the same row in the trace. } torch._logging.trace_structured( @@ -904,6 +971,18 @@ def log_instant_event( suppress_context=False, expect_trace_id=True, ) + # Log an instant event with the same start and end time + log_chromium_event_internal(event, self.stack, self.id_) + + +chromium_event_log = None + + +def get_chromium_event_logger() -> ChromiumEventLogger: + global chromium_event_log + if chromium_event_log is None: + chromium_event_log = ChromiumEventLogger() + return chromium_event_log @dataclasses.dataclass