Skip to content

Commit

Permalink
feature to turn off gc for imports, enabled by environment variable
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 557215795
  • Loading branch information
mlbileschi authored and t5-copybara committed Sep 13, 2023
1 parent ad9e0c9 commit f236384
Show file tree
Hide file tree
Showing 4 changed files with 165 additions and 1 deletion.
19 changes: 19 additions & 0 deletions t5x/assert_gc_disabled_during_import_test_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright 2023 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test utility for disable_gc_during_import_test.py."""
import gc

# Importing this module is the check itself: the import fails with ValueError
# when the garbage collector is enabled at the time the module body runs.
if gc.isenabled():
  raise ValueError("Expected gc to be disabled; was enabled.")
63 changes: 63 additions & 0 deletions t5x/disable_gc_during_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright 2023 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Disables gc during each top-level import.
Only takes effect when the environment variable
EXPERIMENTAL_DISABLE_GC_DURING_IMPORT
is set (e.g. to "true") and try_disable_gc_during_import() is called.
Some libraries like SeqIO have lots of side-effects during import time.
In some cases, disabling garbage collection for each top-level import can save
minutes of startup time.
This should be _relatively_ safe, because we expect it to be rare that both:
1. There's sufficient memory pressure during an import to cause an OOM, and
2. That memory pressure would have been sufficiently alleviated by garbage
collection.
"""
import builtins
import contextlib
import gc
import os


@contextlib.contextmanager
def disabled_gc():
  """Context manager that keeps the garbage collector off inside its scope.

  If the collector is already disabled on entry, its state is left untouched
  on both entry and exit. Otherwise it is disabled for the duration of the
  scope and re-enabled on exit, including when the body raises.

  Yields:
    None.
  """
  was_enabled = gc.isenabled()
  if was_enabled:
    gc.disable()
  try:
    yield
  finally:
    # Only undo what this context manager did itself: if gc was already off
    # when we entered, whoever turned it off is responsible for restoring it.
    if was_enabled:
      gc.enable()


# The value of `builtins.__import__` at the time this module was loaded; the
# wrapper below delegates every import to it.
_original_importlib_import = builtins.__import__


def gc_disabled_import(*args, **kwargs):
  """Replacement for `builtins.__import__` that imports with gc disabled.

  Args:
    *args: Positional arguments, forwarded unchanged to the original import.
    **kwargs: Keyword arguments, forwarded unchanged to the original import.

  Returns:
    Whatever the original `builtins.__import__` returns for these arguments.
  """
  with disabled_gc():
    imported = _original_importlib_import(*args, **kwargs)
  return imported


def try_disable_gc_during_import():
  """Installs the gc-disabling import hook if the opt-in env var asks for it.

  Reads EXPERIMENTAL_DISABLE_GC_DURING_IMPORT from the environment. When the
  variable is unset, empty, or set to an explicitly negative spelling ("0",
  "false", "no", "off"; case-insensitive, surrounding whitespace ignored), this
  function does nothing. Any other value replaces `builtins.__import__` with
  `gc_disabled_import`.
  """
  flag = os.environ.get('EXPERIMENTAL_DISABLE_GC_DURING_IMPORT', '')
  # A plain truthiness check would treat "false" or "0" as opting in, because
  # any non-empty string is truthy. Reject the negative spellings explicitly
  # and keep every other non-empty value enabling the feature, as before.
  if flag.strip().lower() in ('', '0', 'false', 'no', 'off'):
    return
  builtins.__import__ = gc_disabled_import
80 changes: 80 additions & 0 deletions t5x/disable_gc_during_import_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright 2023 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for disable_gc_during_import."""
# pylint: disable=g-import-not-at-top,unused-import

import builtins
import gc
import importlib
import os
import sys
from absl.testing import absltest
from absl.testing import parameterized
from t5x import disable_gc_during_import

_ORIGINAL_BUILTIN_IMPORT_FN = builtins.__import__


def assert_gc_disabled_during_import():
  """Imports the gc-checking helper module afresh.

  The helper module raises ValueError from its module body when gc is enabled,
  so a successful return means gc was disabled while the import ran.

  Raises:
    ValueError: Propagated from the helper module if gc is enabled during its
      import.
  """
  helper_name = "t5x.assert_gc_disabled_during_import_test_util"
  # Evict any cached copy so the module body (and its check) executes again.
  if sys.modules.get(helper_name):
    del sys.modules[helper_name]

  # Deliberately an `import` statement: it goes through `builtins.__import__`,
  # which is the function the code under test replaces.
  import t5x.assert_gc_disabled_during_import_test_util


class DisableGcDuringImportTest(parameterized.TestCase):
  """Tests that the import hook disables gc only while an import runs."""

  _ENV_FLAG = "EXPERIMENTAL_DISABLE_GC_DURING_IMPORT"

  def setUp(self):
    super().setUp()
    # Start every test from the unpatched import function.
    builtins.__import__ = _ORIGINAL_BUILTIN_IMPORT_FN
    # Remember any pre-existing value so tearDown can put it back instead of
    # unconditionally deleting the variable.
    self._saved_env_flag = os.environ.get(self._ENV_FLAG)
    os.environ[self._ENV_FLAG] = "true"

  def tearDown(self):
    builtins.__import__ = _ORIGINAL_BUILTIN_IMPORT_FN
    if self._saved_env_flag is None:
      # A default avoids a KeyError (which would mask the real test failure)
      # if the variable is already gone.
      os.environ.pop(self._ENV_FLAG, None)
    else:
      os.environ[self._ENV_FLAG] = self._saved_env_flag
    super().tearDown()

  def test_gc_enabled_after_one_import_import_builtin(self):
    disable_gc_during_import.try_disable_gc_during_import()

    self.assertTrue(gc.isenabled())
    # Some arbitrary import; not particularly important.
    import enum

    assert_gc_disabled_during_import()

    self.assertTrue(gc.isenabled())

  def test_gc_enabled_after_two_imports_import_builtin(self):
    disable_gc_during_import.try_disable_gc_during_import()

    self.assertTrue(gc.isenabled())
    # Some arbitrary imports; not particularly important which ones.
    import contextlib
    import enum

    assert_gc_disabled_during_import()

    self.assertTrue(gc.isenabled())

  def test_test_utils_appropriately_detect_when_gc_enabled(self):
    # The hook is never installed here, so the helper module is imported with
    # gc still enabled and must raise.
    with self.assertRaisesRegex(ValueError, "Expected gc to be disabled"):
      assert_gc_disabled_during_import()


# Run the tests through absl's test runner when executed as a script.
if __name__ == "__main__":
  absltest.main()
4 changes: 3 additions & 1 deletion t5x/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
"""

# pylint: disable=g-import-not-at-top

import functools
import gc
import math
Expand All @@ -25,7 +27,6 @@

# Set Linen to add profiling information when constructing Modules.
# Must be set before flax imports.
# pylint:disable=g-import-not-at-top
os.environ['FLAX_PROFILE'] = 'true'
# TODO(adarob): Re-enable once users are notified and tests are updated.
os.environ['FLAX_LAZY_RNG'] = 'no'
Expand All @@ -47,6 +48,7 @@
import tensorflow as tf
# pylint:enable=g-import-not-at-top

# pylint:enable=g-import-not-at-top

# Automatically search for gin files relative to the T5X package.
_DEFAULT_GIN_SEARCH_PATHS = [
Expand Down

0 comments on commit f236384

Please sign in to comment.