From 2d05863510bab1fd1adc7180df4154e63bf0b3e9 Mon Sep 17 00:00:00 2001 From: jheek Date: Fri, 6 Aug 2021 12:29:02 +0000 Subject: [PATCH] Only add metaclass during type checking --- flax/linen/module.py | 9 ++++++--- flax/struct.py | 10 +++++++--- tests/linen/module_test.py | 22 +++++++++------------- 3 files changed, 22 insertions(+), 19 deletions(-) diff --git a/flax/linen/module.py b/flax/linen/module.py index 5a54bb37d1..0ef3a457bd 100644 --- a/flax/linen/module.py +++ b/flax/linen/module.py @@ -399,9 +399,12 @@ def reimport(self, other): # This metaclass + decorator is used by static analysis tools recognize that # Module behaves as a dataclass (attributes are constructor args). -@__dataclass_transform__() -class ModuleMeta(type): - pass +if typing.TYPE_CHECKING: + @__dataclass_transform__() + class ModuleMeta(type): + pass +else: + ModuleMeta = type class Module(metaclass=ModuleMeta): diff --git a/flax/struct.py b/flax/struct.py index 3bddde6490..67f0b05fb9 100644 --- a/flax/struct.py +++ b/flax/struct.py @@ -30,6 +30,7 @@ """Utilities for defining custom classes that can be used with jax transformations. """ +import typing from typing import TypeVar, Callable, Tuple, Union, Any from . import serialization @@ -163,9 +164,12 @@ def field(pytree_node=True, **kwargs): TNode = TypeVar('TNode', bound='PyTreeNode') -@__dataclass_transform__() -class PyTreeNodeMeta(type): - pass +if typing.TYPE_CHECKING: + @__dataclass_transform__() + class PyTreeNodeMeta(type): + pass +else: + PyTreeNodeMeta = type class PyTreeNode(metaclass=PyTreeNodeMeta): diff --git a/tests/linen/module_test.py b/tests/linen/module_test.py index 3b4af96063..3c0514a65b 100644 --- a/tests/linen/module_test.py +++ b/tests/linen/module_test.py @@ -17,7 +17,6 @@ import dataclasses import functools import operator -import sys from absl.testing import absltest @@ -1370,18 +1369,15 @@ class Foo(NamedTuple): self.assertEqual(type(xs), Foo) # equality test for NamedTuple doesn't check class! def test_generic_multiple_inheritance(self): - if sys.version_info.major == 3 and sys.version_info.minor >= 7: - # Python 3.6 typing.Generic causes metaclass conflicts. - # This was resolved in Python 3.7 - T = TypeVar('T') - class MyComponent(nn.Module, Generic[T]): - pass - class MyModule(nn.Module): - submodule: MyComponent[jnp.ndarray] - class MyComponent2(Generic[T], nn.Module): - pass - class MyModule2(nn.Module): - submodule: MyComponent2[jnp.ndarray] + T = TypeVar('T') + class MyComponent(nn.Module, Generic[T]): + pass + class MyModule(nn.Module): + submodule: MyComponent[jnp.ndarray] + class MyComponent2(Generic[T], nn.Module): + pass + class MyModule2(nn.Module): + submodule: MyComponent2[jnp.ndarray] if __name__ == '__main__':