Skip to content

Commit

Permalink
Fix or ignore some pytype errors.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 558848204
  • Loading branch information
ChexDev authored and ChexDev committed Aug 21, 2023
1 parent 3addcce commit 37e1061
Showing 1 changed file with 19 additions and 19 deletions.
38 changes: 19 additions & 19 deletions chex/_src/dataclass_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@ def __post_init__(self, k_init_only):
self.k_non_init = self.k_int * k_init_only

if test_type == 'chex':
cls = chex_dataclass(Class, mappable_dataclass=True)
nested_cls = chex_dataclass(NestedClass, mappable_dataclass=True)
cls = chex_dataclass(Class, mappable_dataclass=True) # pytype: disable=wrong-keyword-args # dataclass_transform
nested_cls = chex_dataclass(NestedClass, mappable_dataclass=True) # pytype: disable=wrong-keyword-args # dataclass_transform
elif test_type == 'original':
cls = mappable_dataclass(orig_dataclass(Class))
nested_cls = mappable_dataclass(orig_dataclass(NestedClass))
Expand Down Expand Up @@ -366,7 +366,7 @@ def f(path, x):
self.assertEqual(out.b, ('b',))

def test_tree_map_with_keys_traversal_order(self):
obj = ReverseOrderNestedDataclass(d=1, c=2)
obj = ReverseOrderNestedDataclass(d=1, c=2) # pytype: disable=wrong-arg-types # dataclass_transform
leaves = []
def f(_, x):
leaves.append(x)
Expand All @@ -378,7 +378,7 @@ def f(_, x):
def test_dataclass_replace(self, frozen):
factor = 5.
obj = dummy_dataclass(frozen=frozen)
obj = obj.replace(a=obj.a.replace(c=factor * obj.a.c))
obj = obj.replace(a=obj.a.replace(c=factor * obj.a.c)) # pytype: disable=attribute-error # dataclass_transform
obj = obj.replace(a=obj.a.replace(d=factor * obj.a.d))
obj = obj.replace(b=factor * obj.b)
target_obj = dummy_dataclass(factor=factor, frozen=frozen)
Expand All @@ -401,7 +401,7 @@ def test_dataclass_requires_kwargs_by_default(self):
def test_dataclass_mappable_dataclass_false(self):
factor = 1.0

@chex_dataclass(mappable_dataclass=False)
@chex_dataclass(mappable_dataclass=False) # pytype: disable=wrong-keyword-args # dataclass_transform
class NonMappableDataclass:
a: NestedDataclass
b: pytypes.ArrayDevice
Expand Down Expand Up @@ -456,9 +456,9 @@ class SimpleClass():
data: int = 1

obj_a = SimpleClass(data=1)
state = obj_a.__getstate__()
state = obj_a.__getstate__() # pytype: disable=attribute-error # dataclass_transform
obj_b = SimpleClass(data=2)
obj_b.__setstate__(state)
obj_b.__setstate__(state) # pytype: disable=attribute-error # dataclass_transform
self.assertEqual(obj_a, obj_b)

def test_unexpected_kwargs(self):
Expand All @@ -470,7 +470,7 @@ class SimpleDataclass:

SimpleDataclass(a=1, b=3)
with self.assertRaisesRegex(ValueError, 'init.*got unexpected kwargs'):
SimpleDataclass(a=1, b=3, c=4)
SimpleDataclass(a=1, b=3, c=4) # pytype: disable=wrong-keyword-args # dataclass_transform

def test_tuple_conversion(self):

Expand All @@ -480,9 +480,9 @@ class SimpleDataclass:
a: int

obj = SimpleDataclass(a=2, b=1)
self.assertSequenceEqual(obj.to_tuple(), (1, 2))
self.assertSequenceEqual(obj.to_tuple(), (1, 2)) # pytype: disable=attribute-error # dataclass_transform

obj2 = SimpleDataclass.from_tuple((1, 2))
obj2 = SimpleDataclass.from_tuple((1, 2)) # pytype: disable=attribute-error # dataclass_transform
self.assertEqual(obj.a, obj2.a)
self.assertEqual(obj.b, obj2.b)

Expand All @@ -492,7 +492,7 @@ class SimpleDataclass:
)
def test_tuple_rev_conversion(self, frozen):
obj = dummy_dataclass(frozen=frozen)
asserts.assert_trees_all_close(type(obj).from_tuple(obj.to_tuple()), obj)
asserts.assert_trees_all_close(type(obj).from_tuple(obj.to_tuple()), obj) # pytype: disable=attribute-error # dataclass_transform

@parameterized.named_parameters(
('frozen', True),
Expand Down Expand Up @@ -543,17 +543,17 @@ def test_disallowed_fields(self):
# pylint:disable=unused-variable
with self.assertRaisesRegex(ValueError, 'dataclass fields are disallowed'):

@chex_dataclass(mappable_dataclass=False)
@chex_dataclass(mappable_dataclass=False) # pytype: disable=wrong-keyword-args # dataclass_transform
class InvalidNonMappable:
from_tuple: int

@chex_dataclass(mappable_dataclass=False)
@chex_dataclass(mappable_dataclass=False) # pytype: disable=wrong-keyword-args # dataclass_transform
class ValidMappable:
get: int

with self.assertRaisesRegex(ValueError, 'dataclass fields are disallowed'):

@chex_dataclass(mappable_dataclass=True)
@chex_dataclass(mappable_dataclass=True) # pytype: disable=wrong-keyword-args # dataclass_transform
class InvalidMappable:
get: int
from_tuple: int
Expand All @@ -563,12 +563,12 @@ class InvalidMappable:
@parameterized.parameters(True, False)
def test_flatten_is_leaf(self, is_mappable):

@chex_dataclass(mappable_dataclass=is_mappable)
@chex_dataclass(mappable_dataclass=is_mappable) # pytype: disable=wrong-keyword-args # dataclass_transform
class _InnerDcls:
v_1: int
v_2: int

@chex_dataclass(mappable_dataclass=is_mappable)
@chex_dataclass(mappable_dataclass=is_mappable) # pytype: disable=wrong-keyword-args # dataclass_transform
class _Dcls:
str_val: str
# pytype: disable=invalid-annotation # enable-bare-annotations
Expand Down Expand Up @@ -622,7 +622,7 @@ class Bar:
def test_generic_dataclass(self, mappable):
T = TypeVar('T')

@chex_dataclass(mappable_dataclass=mappable)
@chex_dataclass(mappable_dataclass=mappable) # pytype: disable=wrong-keyword-args # dataclass_transform
class GenericDataclass(Generic[T]):
a: T # pytype: disable=invalid-annotation # enable-bare-annotations

Expand All @@ -631,7 +631,7 @@ class GenericDataclass(Generic[T]):

def test_mappable_eq_override(self):

@chex_dataclass(mappable_dataclass=True)
@chex_dataclass(mappable_dataclass=True) # pytype: disable=wrong-keyword-args # dataclass_transform
class EqDataclass:
a: pytypes.ArrayDevice

Expand Down Expand Up @@ -681,7 +681,7 @@ def test_flatten_roundtrip_ordering(self, dcls):
self.assertSequenceEqual(dataclasses.fields(obj2), dataclasses.fields(obj))

def test_flatten_respects_post_init(self):
obj = PostInitDataclass(a=1)
obj = PostInitDataclass(a=1) # pytype: disable=wrong-arg-types # dataclass_transform
with self.assertRaises(ValueError):
_ = jax.tree_util.tree_map(lambda x: 0, obj)

Expand Down

0 comments on commit 37e1061

Please sign in to comment.