Skip to content

Commit

Permalink
Fix linter's warnings.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 558854198
  • Loading branch information
hbq1 authored and ChexDev committed Aug 21, 2023
1 parent 37e1061 commit fb08977
Showing 1 changed file with 29 additions and 19 deletions.
48 changes: 29 additions & 19 deletions chex/_src/dataclass_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
# limitations under the License.
# ==============================================================================
"""Tests for `dataclass.py`."""

# pytype: disable=wrong-keyword-args # dataclass_transform

import copy
import dataclasses
import pickle
Expand Down Expand Up @@ -129,8 +132,8 @@ def __post_init__(self, k_init_only):
self.k_non_init = self.k_int * k_init_only

if test_type == 'chex':
cls = chex_dataclass(Class, mappable_dataclass=True) # pytype: disable=wrong-keyword-args # dataclass_transform
nested_cls = chex_dataclass(NestedClass, mappable_dataclass=True) # pytype: disable=wrong-keyword-args # dataclass_transform
cls = chex_dataclass(Class, mappable_dataclass=True)
nested_cls = chex_dataclass(NestedClass, mappable_dataclass=True)
elif test_type == 'original':
cls = mappable_dataclass(orig_dataclass(Class))
nested_cls = mappable_dataclass(orig_dataclass(NestedClass))
Expand Down Expand Up @@ -366,7 +369,9 @@ def f(path, x):
self.assertEqual(out.b, ('b',))

def test_tree_map_with_keys_traversal_order(self):
obj = ReverseOrderNestedDataclass(d=1, c=2) # pytype: disable=wrong-arg-types # dataclass_transform
# pytype: disable=wrong-arg-types
obj = ReverseOrderNestedDataclass(d=1, c=2)
# pytype: enable=wrong-arg-types
leaves = []
def f(_, x):
leaves.append(x)
Expand All @@ -378,11 +383,13 @@ def f(_, x):
def test_dataclass_replace(self, frozen):
factor = 5.
obj = dummy_dataclass(frozen=frozen)
obj = obj.replace(a=obj.a.replace(c=factor * obj.a.c)) # pytype: disable=attribute-error # dataclass_transform
# pytype: disable=attribute-error # dataclass_transform
obj = obj.replace(a=obj.a.replace(c=factor * obj.a.c))
obj = obj.replace(a=obj.a.replace(d=factor * obj.a.d))
obj = obj.replace(b=factor * obj.b)
target_obj = dummy_dataclass(factor=factor, frozen=frozen)
asserts.assert_trees_all_close(obj, target_obj)
# pytype: enable=attribute-error

def test_dataclass_requires_kwargs_by_default(self):
factor = 1.0
Expand All @@ -401,7 +408,7 @@ def test_dataclass_requires_kwargs_by_default(self):
def test_dataclass_mappable_dataclass_false(self):
factor = 1.0

@chex_dataclass(mappable_dataclass=False) # pytype: disable=wrong-keyword-args # dataclass_transform
@chex_dataclass(mappable_dataclass=False)
class NonMappableDataclass:
a: NestedDataclass
b: pytypes.ArrayDevice
Expand Down Expand Up @@ -456,9 +463,9 @@ class SimpleClass():
data: int = 1

obj_a = SimpleClass(data=1)
state = obj_a.__getstate__() # pytype: disable=attribute-error # dataclass_transform
state = getattr(obj_a, '__getstate__')()
obj_b = SimpleClass(data=2)
obj_b.__setstate__(state) # pytype: disable=attribute-error # dataclass_transform
getattr(obj_b, '__setstate__')(state)
self.assertEqual(obj_a, obj_b)

def test_unexpected_kwargs(self):
Expand All @@ -470,7 +477,7 @@ class SimpleDataclass:

SimpleDataclass(a=1, b=3)
with self.assertRaisesRegex(ValueError, 'init.*got unexpected kwargs'):
SimpleDataclass(a=1, b=3, c=4) # pytype: disable=wrong-keyword-args # dataclass_transform
SimpleDataclass(a=1, b=3, c=4) # pytype: disable=wrong-keyword-args

def test_tuple_conversion(self):

Expand All @@ -480,9 +487,9 @@ class SimpleDataclass:
a: int

obj = SimpleDataclass(a=2, b=1)
self.assertSequenceEqual(obj.to_tuple(), (1, 2)) # pytype: disable=attribute-error # dataclass_transform
self.assertSequenceEqual(getattr(obj, 'to_tuple')(), (1, 2))

obj2 = SimpleDataclass.from_tuple((1, 2)) # pytype: disable=attribute-error # dataclass_transform
obj2 = getattr(SimpleDataclass, 'from_tuple')((1, 2))
self.assertEqual(obj.a, obj2.a)
self.assertEqual(obj.b, obj2.b)

Expand All @@ -492,7 +499,10 @@ class SimpleDataclass:
)
def test_tuple_rev_conversion(self, frozen):
obj = dummy_dataclass(frozen=frozen)
asserts.assert_trees_all_close(type(obj).from_tuple(obj.to_tuple()), obj) # pytype: disable=attribute-error # dataclass_transform
asserts.assert_trees_all_close(
type(obj).from_tuple(obj.to_tuple()), # pytype: disable=attribute-error
obj,
)

@parameterized.named_parameters(
('frozen', True),
Expand Down Expand Up @@ -543,17 +553,17 @@ def test_disallowed_fields(self):
# pylint:disable=unused-variable
with self.assertRaisesRegex(ValueError, 'dataclass fields are disallowed'):

@chex_dataclass(mappable_dataclass=False) # pytype: disable=wrong-keyword-args # dataclass_transform
@chex_dataclass(mappable_dataclass=False)
class InvalidNonMappable:
from_tuple: int

@chex_dataclass(mappable_dataclass=False) # pytype: disable=wrong-keyword-args # dataclass_transform
@chex_dataclass(mappable_dataclass=False)
class ValidMappable:
get: int

with self.assertRaisesRegex(ValueError, 'dataclass fields are disallowed'):

@chex_dataclass(mappable_dataclass=True) # pytype: disable=wrong-keyword-args # dataclass_transform
@chex_dataclass(mappable_dataclass=True)
class InvalidMappable:
get: int
from_tuple: int
Expand All @@ -563,12 +573,12 @@ class InvalidMappable:
@parameterized.parameters(True, False)
def test_flatten_is_leaf(self, is_mappable):

@chex_dataclass(mappable_dataclass=is_mappable) # pytype: disable=wrong-keyword-args # dataclass_transform
@chex_dataclass(mappable_dataclass=is_mappable)
class _InnerDcls:
v_1: int
v_2: int

@chex_dataclass(mappable_dataclass=is_mappable) # pytype: disable=wrong-keyword-args # dataclass_transform
@chex_dataclass(mappable_dataclass=is_mappable)
class _Dcls:
str_val: str
# pytype: disable=invalid-annotation # enable-bare-annotations
Expand Down Expand Up @@ -622,7 +632,7 @@ class Bar:
def test_generic_dataclass(self, mappable):
T = TypeVar('T')

@chex_dataclass(mappable_dataclass=mappable) # pytype: disable=wrong-keyword-args # dataclass_transform
@chex_dataclass(mappable_dataclass=mappable)
class GenericDataclass(Generic[T]):
a: T # pytype: disable=invalid-annotation # enable-bare-annotations

Expand All @@ -631,7 +641,7 @@ class GenericDataclass(Generic[T]):

def test_mappable_eq_override(self):

@chex_dataclass(mappable_dataclass=True) # pytype: disable=wrong-keyword-args # dataclass_transform
@chex_dataclass(mappable_dataclass=True)
class EqDataclass:
a: pytypes.ArrayDevice

Expand Down Expand Up @@ -681,7 +691,7 @@ def test_flatten_roundtrip_ordering(self, dcls):
self.assertSequenceEqual(dataclasses.fields(obj2), dataclasses.fields(obj))

def test_flatten_respects_post_init(self):
obj = PostInitDataclass(a=1) # pytype: disable=wrong-arg-types # dataclass_transform
obj = PostInitDataclass(a=1) # pytype: disable=wrong-arg-types
with self.assertRaises(ValueError):
_ = jax.tree_util.tree_map(lambda x: 0, obj)

Expand Down

0 comments on commit fb08977

Please sign in to comment.