From ab687a0325ee2184b68c2d6f6d0e8645145a14ce Mon Sep 17 00:00:00 2001 From: "rzou (Meta Employee)" Date: Thu, 13 Jun 2024 09:54:03 -0700 Subject: [PATCH] Change Dynamo's custom ops warning message to be less spammy (#128456) Summary: This is a short-term fix (for 2.4). In the longer term we should fix https://github.com/pytorch/pytorch/issues/128430 The problem is that warnings.warn that are inside Dynamo print all the time. Python warnings are supposed to print once, unless their cache is reset: Dynamo ends up resetting that cache everytime it runs. As a workaround we provide our own warn_once cache that is keyed on the warning msg. I am not worried about this increasing memory usage because that's effectively what python's warnings.warn cache does. X-link: https://github.com/pytorch/pytorch/pull/128456 Approved by: https://github.com/anijain2305 Reviewed By: clee2000 Differential Revision: D58501328 Pulled By: zou3519 fbshipit-source-id: 99dcfddfae27de2f1de6e9685aa990a738531199 --- .../dynamo/dynamobench/_dynamo/utils.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py index fe2f096ec4..e171fc905e 100644 --- a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py +++ b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py @@ -23,6 +23,7 @@ import time import types import typing +import warnings import weakref from contextlib import contextmanager from functools import lru_cache, wraps @@ -2751,3 +2752,18 @@ def __init__(self, s): def __repr__(self): return self.s + + +warn_once_cache: Set[str] = set() + + +def warn_once(msg, stacklevel=1): + # Dynamo causes all warnings.warn (in user code and in Dynamo code) to print all the time. + # https://github.com/pytorch/pytorch/issues/128427. + # warn_once is a workaround: if the msg has been warned on before, then we will not + # warn again. + # NB: it's totally ok to store a cache of all the strings: this is what warnings.warn does as well. + if msg in warn_once_cache: + return + warn_once_cache.add(msg) + warnings.warn(msg, stacklevel=stacklevel + 1)