diff --git a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py index fe2f096ec4..e171fc905e 100644 --- a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py +++ b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py @@ -23,6 +23,7 @@ import time import types import typing +import warnings import weakref from contextlib import contextmanager from functools import lru_cache, wraps @@ -2751,3 +2752,18 @@ def __init__(self, s): def __repr__(self): return self.s + + +warn_once_cache: Set[str] = set() + + +def warn_once(msg, stacklevel=1): + # Dynamo causes all warnings.warn (in user code and in Dynamo code) to print all the time. + # https://github.com/pytorch/pytorch/issues/128427. + # warn_once is a workaround: if the msg has been warned on before, then we will not + # warn again. + # NB: it's totally ok to store a cache of all the strings: this is what warnings.warn does as well. + if msg in warn_once_cache: + return + warn_once_cache.add(msg) + warnings.warn(msg, stacklevel=stacklevel + 1)