diff --git a/chex/_src/dataclass_test.py b/chex/_src/dataclass_test.py index d82649c..5bf8266 100644 --- a/chex/_src/dataclass_test.py +++ b/chex/_src/dataclass_test.py @@ -13,6 +13,9 @@ # limitations under the License. # ============================================================================== """Tests for `dataclass.py`.""" + +# pytype: disable=wrong-keyword-args # dataclass_transform + import copy import dataclasses import pickle @@ -129,8 +132,8 @@ def __post_init__(self, k_init_only): self.k_non_init = self.k_int * k_init_only if test_type == 'chex': - cls = chex_dataclass(Class, mappable_dataclass=True) # pytype: disable=wrong-keyword-args # dataclass_transform - nested_cls = chex_dataclass(NestedClass, mappable_dataclass=True) # pytype: disable=wrong-keyword-args # dataclass_transform + cls = chex_dataclass(Class, mappable_dataclass=True) + nested_cls = chex_dataclass(NestedClass, mappable_dataclass=True) elif test_type == 'original': cls = mappable_dataclass(orig_dataclass(Class)) nested_cls = mappable_dataclass(orig_dataclass(NestedClass)) @@ -366,7 +369,9 @@ def f(path, x): self.assertEqual(out.b, ('b',)) def test_tree_map_with_keys_traversal_order(self): - obj = ReverseOrderNestedDataclass(d=1, c=2) # pytype: disable=wrong-arg-types # dataclass_transform + # pytype: disable=wrong-arg-types + obj = ReverseOrderNestedDataclass(d=1, c=2) + # pytype: enable=wrong-arg-types leaves = [] def f(_, x): leaves.append(x) @@ -378,11 +383,13 @@ def f(_, x): def test_dataclass_replace(self, frozen): factor = 5. obj = dummy_dataclass(frozen=frozen) - obj = obj.replace(a=obj.a.replace(c=factor * obj.a.c)) # pytype: disable=attribute-error # dataclass_transform + # pytype: disable=attribute-error # dataclass_transform + obj = obj.replace(a=obj.a.replace(c=factor * obj.a.c)) obj = obj.replace(a=obj.a.replace(d=factor * obj.a.d)) obj = obj.replace(b=factor * obj.b) target_obj = dummy_dataclass(factor=factor, frozen=frozen) asserts.assert_trees_all_close(obj, target_obj) + # pytype: enable=attribute-error def test_dataclass_requires_kwargs_by_default(self): factor = 1.0 @@ -401,7 +408,7 @@ def test_dataclass_requires_kwargs_by_default(self): def test_dataclass_mappable_dataclass_false(self): factor = 1.0 - @chex_dataclass(mappable_dataclass=False) # pytype: disable=wrong-keyword-args # dataclass_transform + @chex_dataclass(mappable_dataclass=False) class NonMappableDataclass: a: NestedDataclass b: pytypes.ArrayDevice @@ -456,9 +463,9 @@ class SimpleClass(): data: int = 1 obj_a = SimpleClass(data=1) - state = obj_a.__getstate__() # pytype: disable=attribute-error # dataclass_transform + state = getattr(obj_a, '__getstate__')() obj_b = SimpleClass(data=2) - obj_b.__setstate__(state) # pytype: disable=attribute-error # dataclass_transform + getattr(obj_b, '__setstate__')(state) self.assertEqual(obj_a, obj_b) def test_unexpected_kwargs(self): @@ -470,7 +477,7 @@ class SimpleDataclass: SimpleDataclass(a=1, b=3) with self.assertRaisesRegex(ValueError, 'init.*got unexpected kwargs'): - SimpleDataclass(a=1, b=3, c=4) # pytype: disable=wrong-keyword-args # dataclass_transform + SimpleDataclass(a=1, b=3, c=4) # pytype: disable=wrong-keyword-args def test_tuple_conversion(self): @@ -480,9 +487,9 @@ class SimpleDataclass: a: int obj = SimpleDataclass(a=2, b=1) - self.assertSequenceEqual(obj.to_tuple(), (1, 2)) # pytype: disable=attribute-error # dataclass_transform + self.assertSequenceEqual(getattr(obj, 'to_tuple')(), (1, 2)) - obj2 = SimpleDataclass.from_tuple((1, 2)) # pytype: disable=attribute-error # dataclass_transform + obj2 = getattr(SimpleDataclass, 'from_tuple')((1, 2)) self.assertEqual(obj.a, obj2.a) self.assertEqual(obj.b, obj2.b) @@ -492,7 +499,10 @@ class SimpleDataclass: ) def test_tuple_rev_conversion(self, frozen): obj = dummy_dataclass(frozen=frozen) - asserts.assert_trees_all_close(type(obj).from_tuple(obj.to_tuple()), obj) # pytype: disable=attribute-error # dataclass_transform + asserts.assert_trees_all_close( + type(obj).from_tuple(obj.to_tuple()), # pytype: disable=attribute-error + obj, + ) @parameterized.named_parameters( ('frozen', True), @@ -543,17 +553,17 @@ def test_disallowed_fields(self): # pylint:disable=unused-variable with self.assertRaisesRegex(ValueError, 'dataclass fields are disallowed'): - @chex_dataclass(mappable_dataclass=False) # pytype: disable=wrong-keyword-args # dataclass_transform + @chex_dataclass(mappable_dataclass=False) class InvalidNonMappable: from_tuple: int - @chex_dataclass(mappable_dataclass=False) # pytype: disable=wrong-keyword-args # dataclass_transform + @chex_dataclass(mappable_dataclass=False) class ValidMappable: get: int with self.assertRaisesRegex(ValueError, 'dataclass fields are disallowed'): - @chex_dataclass(mappable_dataclass=True) # pytype: disable=wrong-keyword-args # dataclass_transform + @chex_dataclass(mappable_dataclass=True) class InvalidMappable: get: int from_tuple: int @@ -563,12 +573,12 @@ class InvalidMappable: @parameterized.parameters(True, False) def test_flatten_is_leaf(self, is_mappable): - @chex_dataclass(mappable_dataclass=is_mappable) # pytype: disable=wrong-keyword-args # dataclass_transform + @chex_dataclass(mappable_dataclass=is_mappable) class _InnerDcls: v_1: int v_2: int - @chex_dataclass(mappable_dataclass=is_mappable) # pytype: disable=wrong-keyword-args # dataclass_transform + @chex_dataclass(mappable_dataclass=is_mappable) class _Dcls: str_val: str # pytype: disable=invalid-annotation # enable-bare-annotations @@ -622,7 +632,7 @@ class Bar: def test_generic_dataclass(self, mappable): T = TypeVar('T') - @chex_dataclass(mappable_dataclass=mappable) # pytype: disable=wrong-keyword-args # dataclass_transform + @chex_dataclass(mappable_dataclass=mappable) class GenericDataclass(Generic[T]): a: T # pytype: disable=invalid-annotation # enable-bare-annotations @@ -631,7 +641,7 @@ class GenericDataclass(Generic[T]): def test_mappable_eq_override(self): - @chex_dataclass(mappable_dataclass=True) # pytype: disable=wrong-keyword-args # dataclass_transform + @chex_dataclass(mappable_dataclass=True) class EqDataclass: a: pytypes.ArrayDevice @@ -681,7 +691,7 @@ def test_flatten_roundtrip_ordering(self, dcls): self.assertSequenceEqual(dataclasses.fields(obj2), dataclasses.fields(obj)) def test_flatten_respects_post_init(self): - obj = PostInitDataclass(a=1) # pytype: disable=wrong-arg-types # dataclass_transform + obj = PostInitDataclass(a=1) # pytype: disable=wrong-arg-types with self.assertRaises(ValueError): _ = jax.tree_util.tree_map(lambda x: 0, obj)