Skip to content

Commit

Permalink
Merge pull request #1478 from jheek:add-dataclass-autocomplete-support
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 389608780
  • Loading branch information
Flax Authors committed Aug 9, 2021
2 parents f558f49 + 2d05863 commit 9350b44
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 8 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ vNext
-
- Fix the serialization of named tuples. Tuple fields are no longer stored in the state dict and the named tuple class is no longer recreated ([bug](https://github.com/google/flax/issues/1429)).
-
-
- linen Modules and dataclasses made with `flax.struct.dataclass` or `flax.struct.PyTreeNode` are now correctly recognized as dataclasses by static analysis tools like PyLance. Autocomplete of constructors has been verified to work with VSCode.
-
-
- `flax.linen.Conv` no longer interprets an int past as kernel_size as a 1d convolution. Instead a type error is raised stating that
Expand Down
13 changes: 12 additions & 1 deletion flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from flax.core import Scope
from flax.core.scope import CollectionFilter, DenyList, Variable, VariableDict, FrozenVariableDict, union_filters
from flax.core.frozen_dict import FrozenDict, freeze
from flax.struct import __dataclass_transform__

# from .dotgetter import DotGetter

Expand Down Expand Up @@ -366,7 +367,17 @@ def reimport(self, other):
# -----------------------------------------------------------------------------


class Module:
# This metaclass + decorator is used by static analysis tools recognize that
# Module behaves as a dataclass (attributes are constructor args).
if typing.TYPE_CHECKING:
@__dataclass_transform__()
class ModuleMeta(type):
pass
else:
ModuleMeta = type


class Module(metaclass=ModuleMeta):
"""Base class for all neural network modules. Layers and models should subclass this class.
All Flax Modules are Python 3.7
Expand Down
37 changes: 33 additions & 4 deletions flax/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@
"""Utilities for defining custom classes that can be used with jax transformations.
"""

from typing import TypeVar
import typing
from typing import TypeVar, Callable, Tuple, Union, Any

from . import serialization

Expand All @@ -39,7 +40,26 @@
import jax


def dataclass(clz: type):

# This decorator is interpreted by static analysis tools as a hint
# that a decorator or metaclass causes dataclass-like behavior.
# See https://github.com/microsoft/pyright/blob/main/specs/dataclass_transforms.md
# for more information about the __dataclass_transform__ magic.
_T = TypeVar("_T")
def __dataclass_transform__(
*,
eq_default: bool = True,
order_default: bool = False,
kw_only_default: bool = False,
field_descriptors: Tuple[Union[type, Callable[..., Any]], ...] = (()),
) -> Callable[[_T], _T]:
# If used within a stub file, the following implementation can be
# replaced with "...".
return lambda a: a


@__dataclass_transform__()
def dataclass(clz: _T) -> _T:
"""Create a class which can be passed to functional transformations.
NOTE: Inherit from ``PyTreeNode`` instead to avoid type checking issues when
Expand Down Expand Up @@ -77,7 +97,8 @@ def __apply__(self, *args):
Returns:
The new class.
"""
data_clz = dataclasses.dataclass(frozen=True)(clz)
# workaround for pytype not recognizing __dataclass_fields__
data_clz: Any = dataclasses.dataclass(frozen=True)(clz)
meta_fields = []
data_fields = []
for name, field_info in data_clz.__dataclass_fields__.items():
Expand Down Expand Up @@ -143,7 +164,15 @@ def field(pytree_node=True, **kwargs):
TNode = TypeVar('TNode', bound='PyTreeNode')


class PyTreeNode():
if typing.TYPE_CHECKING:
@__dataclass_transform__()
class PyTreeNodeMeta(type):
pass
else:
PyTreeNodeMeta = type


class PyTreeNode(metaclass=PyTreeNodeMeta):
"""Base class for dataclasses that should act like a JAX pytree node.
See ``flax.struct.dataclass`` for the ``jax.tree_util`` behavior.
Expand Down
2 changes: 0 additions & 2 deletions tests/linen/module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
import functools
import operator



from absl.testing import absltest

import jax
Expand Down

0 comments on commit 9350b44

Please sign in to comment.