Skip to content

Commit

Permalink
Only add metaclass during type checking
Browse files Browse the repository at this point in the history
  • Loading branch information
jheek committed Aug 9, 2021
1 parent 40104c4 commit 2d05863
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 19 deletions.
9 changes: 6 additions & 3 deletions flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,9 +399,12 @@ def reimport(self, other):

# This metaclass + decorator is used by static analysis tools recognize that
# Module behaves as a dataclass (attributes are constructor args).
@__dataclass_transform__()
class ModuleMeta(type):
pass
if typing.TYPE_CHECKING:
@__dataclass_transform__()
class ModuleMeta(type):
pass
else:
ModuleMeta = type


class Module(metaclass=ModuleMeta):
Expand Down
10 changes: 7 additions & 3 deletions flax/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"""Utilities for defining custom classes that can be used with jax transformations.
"""

import typing
from typing import TypeVar, Callable, Tuple, Union, Any

from . import serialization
Expand Down Expand Up @@ -163,9 +164,12 @@ def field(pytree_node=True, **kwargs):
TNode = TypeVar('TNode', bound='PyTreeNode')


@__dataclass_transform__()
class PyTreeNodeMeta(type):
pass
if typing.TYPE_CHECKING:
@__dataclass_transform__()
class PyTreeNodeMeta(type):
pass
else:
PyTreeNodeMeta = type


class PyTreeNode(metaclass=PyTreeNodeMeta):
Expand Down
22 changes: 9 additions & 13 deletions tests/linen/module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import dataclasses
import functools
import operator
import sys

from absl.testing import absltest

Expand Down Expand Up @@ -1370,18 +1369,15 @@ class Foo(NamedTuple):
self.assertEqual(type(xs), Foo) # equality test for NamedTuple doesn't check class!

def test_generic_multiple_inheritance(self):
if sys.version_info.major == 3 and sys.version_info.minor >= 7:
# Python 3.6 typing.Generic causes metaclass conflicts.
# This was resolved in Python 3.7
T = TypeVar('T')
class MyComponent(nn.Module, Generic[T]):
pass
class MyModule(nn.Module):
submodule: MyComponent[jnp.ndarray]
class MyComponent2(Generic[T], nn.Module):
pass
class MyModule2(nn.Module):
submodule: MyComponent2[jnp.ndarray]
T = TypeVar('T')
class MyComponent(nn.Module, Generic[T]):
pass
class MyModule(nn.Module):
submodule: MyComponent[jnp.ndarray]
class MyComponent2(Generic[T], nn.Module):
pass
class MyModule2(nn.Module):
submodule: MyComponent2[jnp.ndarray]


if __name__ == '__main__':
Expand Down

0 comments on commit 2d05863

Please sign in to comment.