From f23638498c7ebd44267b2c4095b9fcc1f4bfed4c Mon Sep 17 00:00:00 2001
From: Maxwell Bileschi
Date: Tue, 15 Aug 2023 12:39:29 -0700
Subject: [PATCH] feature to turn off gc for imports, enabled by environment variable

PiperOrigin-RevId: 557215795
---
 ...ert_gc_disabled_during_import_test_util.py | 19 ++++
 t5x/disable_gc_during_import.py               | 75 +++++++++++++++
 t5x/disable_gc_during_import_test.py          | 94 +++++++++++++++++++
 t5x/train.py                                  |  4 +-
 4 files changed, 191 insertions(+), 1 deletion(-)
 create mode 100644 t5x/assert_gc_disabled_during_import_test_util.py
 create mode 100644 t5x/disable_gc_during_import.py
 create mode 100644 t5x/disable_gc_during_import_test.py

diff --git a/t5x/assert_gc_disabled_during_import_test_util.py b/t5x/assert_gc_disabled_during_import_test_util.py
new file mode 100644
index 000000000..b7220eb1a
--- /dev/null
+++ b/t5x/assert_gc_disabled_during_import_test_util.py
@@ -0,0 +1,19 @@
+# Copyright 2023 The T5X Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Test utility for disable_gc_during_import_test.py."""
+import gc
+
+if gc.isenabled():
+  raise ValueError("Expected gc to be disabled; was enabled.")
diff --git a/t5x/disable_gc_during_import.py b/t5x/disable_gc_during_import.py
new file mode 100644
index 000000000..e87772c28
--- /dev/null
+++ b/t5x/disable_gc_during_import.py
@@ -0,0 +1,75 @@
+# Copyright 2023 The T5X Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Disables gc during each top-level import.
+
+Only takes effect when environment variable
+EXPERIMENTAL_DISABLE_GC_DURING_IMPORT
+is set to anything other than an empty string, "0", "false", "no" or "off"
+(compared case-insensitively).
+
+Some libraries like SeqIO have lots of side-effects during import time.
+In some cases, disabling garbage collection for each top-level import can save
+minutes of startup time.
+
+This should be _relatively_ safe, because we don't expect that it's often that
+1. There's sufficient memory pressure during an import to cause an OOM, and
+2. That memory pressure would have been sufficiently alleviated by garbage
+   collection.
+"""
+import builtins
+import contextlib
+import gc
+import os
+
+_ENV_VAR = 'EXPERIMENTAL_DISABLE_GC_DURING_IMPORT'
+# Values of the environment variable that leave the feature switched off.
+_FALSE_VALUES = frozenset(('', '0', 'false', 'no', 'off'))
+
+
+@contextlib.contextmanager
+def disabled_gc():
+  """When used as context manager, prevents garbage collection in scope."""
+  if not gc.isenabled():
+    # GC is already disabled; don't make any changes.
+    yield
+    return
+
+  gc.disable()
+  try:
+    yield
+  finally:
+    # We know that the original state was enabled because
+    # we didn't return above.
+    gc.enable()
+
+
+_original_importlib_import = builtins.__import__
+
+
+def gc_disabled_import(*args, **kwargs):
+  """Calls the original `__import__` with garbage collection disabled."""
+  with disabled_gc():
+    return _original_importlib_import(*args, **kwargs)
+
+
+def try_disable_gc_during_import():
+  """Installs `gc_disabled_import` when the environment variable opts in.
+
+  False-like values ("", "0", "false", "no", "off", any case) leave imports
+  untouched, so EXPERIMENTAL_DISABLE_GC_DURING_IMPORT=false stays disabled.
+  """
+  value = os.environ.get(_ENV_VAR, '')
+  if value.strip().lower() not in _FALSE_VALUES:
+    builtins.__import__ = gc_disabled_import
diff --git a/t5x/disable_gc_during_import_test.py b/t5x/disable_gc_during_import_test.py
new file mode 100644
index 000000000..b3f4f972e
--- /dev/null
+++ b/t5x/disable_gc_during_import_test.py
@@ -0,0 +1,94 @@
+# Copyright 2023 The T5X Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for disable_gc_during_import."""
+# pylint: disable=g-import-not-at-top,unused-import
+
+import builtins
+import gc
+import importlib
+import os
+import sys
+from absl.testing import absltest
+from absl.testing import parameterized
+from t5x import disable_gc_during_import
+
+_ORIGINAL_BUILTIN_IMPORT_FN = builtins.__import__
+_ENV_VAR = "EXPERIMENTAL_DISABLE_GC_DURING_IMPORT"
+_TEST_UTIL_MODULE = "t5x.assert_gc_disabled_during_import_test_util"
+
+
+def assert_gc_disabled_during_import():
+  # Side effect of importing module is asserting gc is disabled.
+  # Drop any cached copy first so that the module body really runs again.
+  sys.modules.pop(_TEST_UTIL_MODULE, None)
+
+  import t5x.assert_gc_disabled_during_import_test_util
+
+
+class DisableGcDuringImportTest(parameterized.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    builtins.__import__ = _ORIGINAL_BUILTIN_IMPORT_FN
+    # Remember any pre-existing value so tearDown can put it back.
+    self._saved_env_value = os.environ.get(_ENV_VAR)
+    os.environ[_ENV_VAR] = "true"
+
+  def tearDown(self):
+    super().tearDown()
+    builtins.__import__ = _ORIGINAL_BUILTIN_IMPORT_FN
+    if self._saved_env_value is None:
+      os.environ.pop(_ENV_VAR, None)
+    else:
+      os.environ[_ENV_VAR] = self._saved_env_value
+
+  def test_gc_enabled_after_one_import_import_builtin(self):
+    disable_gc_during_import.try_disable_gc_during_import()
+
+    self.assertTrue(gc.isenabled())
+    # Some arbitrary import; not particularly important.
+    import enum
+
+    assert_gc_disabled_during_import()
+
+    self.assertTrue(gc.isenabled())
+
+  def test_gc_enabled_after_two_imports_import_builtin(self):
+    disable_gc_during_import.try_disable_gc_during_import()
+
+    self.assertTrue(gc.isenabled())
+    # Some arbitrary imports; not particularly important which ones.
+    import contextlib
+    import enum
+
+    assert_gc_disabled_during_import()
+
+    self.assertTrue(gc.isenabled())
+
+  def test_test_utils_appropriately_detect_when_gc_enabled(self):
+    with self.assertRaisesRegex(ValueError, "Expected gc to be disabled"):
+      assert_gc_disabled_during_import()
+
+  @parameterized.parameters("", "0", "false", "False", "no", "off")
+  def test_false_like_value_leaves_import_untouched(self, value):
+    os.environ[_ENV_VAR] = value
+
+    disable_gc_during_import.try_disable_gc_during_import()
+
+    self.assertIs(builtins.__import__, _ORIGINAL_BUILTIN_IMPORT_FN)
+
+
+if __name__ == "__main__":
+  absltest.main()
diff --git a/t5x/train.py b/t5x/train.py
index 61682edc8..d6bfd6ab0 100644
--- a/t5x/train.py
+++ b/t5x/train.py
@@ -16,6 +16,8 @@
 
 """
 
+# pylint: disable=g-import-not-at-top
+
 import functools
 import gc
 import math
@@ -25,7 +27,6 @@
 
 # Set Linen to add profiling information when constructing Modules.
 # Must be set before flax imports.
-# pylint:disable=g-import-not-at-top
 os.environ['FLAX_PROFILE'] = 'true'
 # TODO(adarob): Re-enable once users are notified and tests are updated.
 os.environ['FLAX_LAZY_RNG'] = 'no'
@@ -47,6 +48,7 @@
 import tensorflow as tf
 
 # pylint:enable=g-import-not-at-top
+# pylint:enable=g-import-not-at-top
 
 # Automatically search for gin files relative to the T5X package.
 _DEFAULT_GIN_SEARCH_PATHS = [