diff --git a/chex/_src/asserts.py b/chex/_src/asserts.py index 9c6dffd0..600bbfe5 100644 --- a/chex/_src/asserts.py +++ b/chex/_src/asserts.py @@ -23,7 +23,6 @@ import unittest from unittest import mock -from absl import logging from chex._src import asserts_internal as _ai from chex._src import pytypes import jax @@ -41,13 +40,6 @@ _ai.chex_assertion, jittable_assert_fn=None) -def _ignore_nones_deprecation() -> None: - logging.warning( - "`ignore_nones` argument is non-functional since v0.1.82 and will be" - " removed in the next version, please update your usages." - ) - - def disable_asserts() -> None: """Disables all Chex assertions. @@ -1030,21 +1022,15 @@ def _is_leaf(value): @_static_assertion -def assert_tree_has_only_ndarrays(tree: ArrayTree, - *, - ignore_nones: bool = False) -> None: +def assert_tree_has_only_ndarrays(tree: ArrayTree) -> None: """Checks that all `tree`'s leaves are n-dimensional arrays (tensors). Args: tree: A tree to assert. - ignore_nones: Deprecated. Raises: AssertionError: If the tree contains an object which is not an ndarray. """ - if ignore_nones: - _ignore_nones_deprecation() - errors = [] def _assert_fn(path, leaf): @@ -1085,7 +1071,6 @@ def assert_tree_is_on_host( *, allow_cpu_device: bool = True, allow_sharded_arrays: bool = False, - ignore_nones: bool = False, ) -> None: """Checks that all leaves are ndarrays residing in the host memory (on CPU). @@ -1097,15 +1082,11 @@ def assert_tree_is_on_host( allow_sharded_arrays: Whether to allow sharded JAX arrays. Sharded arrays are considered "on host" only if they are sharded across CPU devices and `allow_cpu_device` is `True`. - ignore_nones: Deprecated. Raises: AssertionError: If the tree contains a leaf that is not an ndarray or does not reside on host. """ - if ignore_nones: - _ignore_nones_deprecation() - assert_tree_has_only_ndarrays(tree) errors = [] @@ -1160,8 +1141,7 @@ def assert_tree_is_on_device(tree: ArrayTree, *, platform: Union[Sequence[str], str] = ("gpu", "tpu"), - device: Optional[pytypes.Device] = None, - ignore_nones: bool = False) -> None: + device: Optional[pytypes.Device] = None) -> None: """Checks that all leaves are ndarrays residing in device memory (in HBM). Sharded DeviceArrays are disallowed. @@ -1172,15 +1152,11 @@ def assert_tree_is_on_device(tree: ArrayTree, reside. Ignored if `device` is specified. device: An optional device where the tree's arrays are expected to reside. Any device (except CPU) is accepted if not specified. - ignore_nones: Deprecated. Raises: AssertionError: If the tree contains a leaf that is not an ndarray or does not reside on the specified device or platform. """ - if ignore_nones: - _ignore_nones_deprecation() - assert_tree_has_only_ndarrays(tree) # If device is specified, require its platform. @@ -1226,23 +1202,18 @@ def _assert_fn(path, leaf): @_static_assertion def assert_tree_is_sharded(tree: ArrayTree, *, - devices: Sequence[pytypes.Device], - ignore_nones: bool = False) -> None: + devices: Sequence[pytypes.Device]) -> None: """Checks that all leaves are ndarrays sharded across the specified devices. Args: tree: A tree to assert. devices: A list of devices which the tree's leaves are expected to be sharded across. This list is order-sensitive. - ignore_nones: Deprecated. Raises: AssertionError: If the tree contains a leaf that is not a device array sharded across the specified devices. """ - if ignore_nones: - _ignore_nones_deprecation() - assert_tree_has_only_ndarrays(tree) errors = [] @@ -1280,22 +1251,13 @@ def _assert_fn(path, leaf): @_static_assertion def assert_tree_shape_prefix(tree: ArrayTree, - shape_prefix: Sequence[int], - *, - ignore_nones: bool = False) -> None: + shape_prefix: Sequence[int]) -> None: """Checks that all ``tree`` leaves' shapes have the same prefix. Args: tree: A tree to check. shape_prefix: An expected shape prefix. - ignore_nones: Deprecated. - - Raises: - AssertionError: If some leaf's shape doesn't start with ``shape_prefix``. """ - if ignore_nones: - _ignore_nones_deprecation() - # To compare with the leaf's `shape`, convert int sequence to tuple. shape_prefix = tuple(shape_prefix) @@ -1327,23 +1289,18 @@ def _assert_fn(path, leaf): @_static_assertion -def assert_tree_shape_suffix(tree: ArrayTree, - shape_suffix: Sequence[int], - *, - ignore_nones: bool = False) -> None: +def assert_tree_shape_suffix( + tree: ArrayTree, shape_suffix: Sequence[int] +) -> None: """Checks that all ``tree`` leaves' shapes have the same suffix. Args: tree: A tree to check. shape_suffix: An expected shape suffix. - ignore_nones: Deprecated. Raises: AssertionError: If some leaf's shape doesn't start with ``shape_suffix``. """ - if ignore_nones: - _ignore_nones_deprecation() - # To compare with the leaf's `shape`, convert int sequence to tuple. shape_suffix = tuple(shape_suffix) @@ -1410,8 +1367,7 @@ def assert_trees_all_equal_structs(*trees: ArrayTree) -> None: @_static_assertion def assert_trees_all_equal_comparator(equality_comparator: _ai.TLeavesEqCmpFn, error_msg_fn: _ai.TLeavesEqCmpErrorFn, - *trees: ArrayTree, - ignore_nones: bool = False) -> None: + *trees: ArrayTree) -> None: """Checks that all trees are equal as per the custom comparator for leaves. Args: @@ -1421,16 +1377,12 @@ def assert_trees_all_equal_comparator(equality_comparator: _ai.TLeavesEqCmpFn, ``equality_comparator`` leaves and returning an error message. *trees: A sequence of (at least 2) trees to check on equality as per ``equality_comparator``. - ignore_nones: Deprecated. Raises: ValueError: If ``trees`` does not contain at least 2 elements. AssertionError: if ``equality_comparator`` returns `False` for any pair of trees from ``trees``. """ - if ignore_nones: - _ignore_nones_deprecation() - if len(trees) < 2: raise ValueError( "Assertions over only one tree does not make sense. Maybe you wrote " @@ -1465,20 +1417,15 @@ def tree_error_msg_fn(l_1: _ai.TLeaf, l_2: _ai.TLeaf, path: str, i_1: int, @_static_assertion -def assert_trees_all_equal_dtypes(*trees: ArrayTree, - ignore_nones: bool = False) -> None: +def assert_trees_all_equal_dtypes(*trees: ArrayTree) -> None: """Checks that trees' leaves have the same dtype. Args: *trees: A sequence of (at least 2) trees to check. - ignore_nones: Deprecated. Raises: AssertionError: If leaves' dtypes for any two trees differ. """ - if ignore_nones: - _ignore_nones_deprecation() - def cmp_fn(arr_1, arr_2): return (hasattr(arr_1, "dtype") and hasattr(arr_2, "dtype") and arr_1.dtype == arr_2.dtype) @@ -1494,38 +1441,30 @@ def err_msg_fn(arr_1, arr_2): @_static_assertion -def assert_trees_all_equal_sizes(*trees: ArrayTree, - ignore_nones: bool = False) -> None: +def assert_trees_all_equal_sizes(*trees: ArrayTree) -> None: """Checks that trees have the same structure and leaves' sizes. Args: *trees: A sequence of (at least 2) trees with array leaves. - ignore_nones: Deprecated. Raises: AssertionError: If trees' structures or leaves' sizes are different. """ - if ignore_nones: - _ignore_nones_deprecation() cmp_fn = lambda arr_1, arr_2: arr_1.size == arr_2.size err_msg_fn = lambda arr_1, arr_2: f"sizes: {arr_1.size} != {arr_2.size}" assert_trees_all_equal_comparator(cmp_fn, err_msg_fn, *trees) @_static_assertion -def assert_trees_all_equal_shapes(*trees: ArrayTree, - ignore_nones: bool = False) -> None: +def assert_trees_all_equal_shapes(*trees: ArrayTree) -> None: """Checks that trees have the same structure and leaves' shapes. Args: *trees: A sequence of (at least 2) trees with array leaves. - ignore_nones: Deprecated Raises: AssertionError: If trees' structures or leaves' shapes are different. """ - if ignore_nones: - _ignore_nones_deprecation() cmp_fn = lambda arr_1, arr_2: arr_1.shape == arr_2.shape err_msg_fn = lambda arr_1, arr_2: f"shapes: {arr_1.shape} != {arr_2.shape}" assert_trees_all_equal_comparator(cmp_fn, err_msg_fn, *trees) @@ -1538,19 +1477,15 @@ def assert_trees_all_equal_shapes(*trees: ArrayTree, @_static_assertion -def assert_trees_all_equal_shapes_and_dtypes( - *trees: ArrayTree, ignore_nones: bool = False) -> None: +def assert_trees_all_equal_shapes_and_dtypes(*trees: ArrayTree) -> None: """Checks that trees' leaves have the same shape and dtype. Args: *trees: A sequence of (at least 2) trees to check. - ignore_nones: Deprecated. Raises: AssertionError: If leaves' shapes or dtypes for any two trees differ. """ - if ignore_nones: - _ignore_nones_deprecation() assert_trees_all_equal_shapes(*trees) assert_trees_all_equal_dtypes(*trees) @@ -1575,9 +1510,7 @@ def _assert_tree_all_finite_static(tree_like: ArrayTree) -> None: raise AssertionError(f"Tree contains non-finite value: {error_msg}.") -def _assert_tree_all_finite_jittable( - tree_like: ArrayTree -) -> Array: +def _assert_tree_all_finite_jittable(tree_like: ArrayTree) -> Array: """A jittable version of `_assert_tree_all_finite_static`.""" labeled_tree = jax.tree_map( lambda x: jax.lax.select(jnp.isfinite(x).all(), .0, jnp.nan), tree_like @@ -1601,7 +1534,7 @@ def _assert_tree_all_finite_jittable( @_static_assertion def _assert_trees_all_equal_static( - *trees: ArrayTree, ignore_nones: bool = False, strict: bool = False + *trees: ArrayTree, strict: bool = False ) -> None: """Checks that all trees have leaves with *exactly* equal values. @@ -1610,16 +1543,12 @@ def _assert_trees_all_equal_static( Args: *trees: A sequence of (at least 2) trees with array leaves. - ignore_nones: Deprecated. strict: If True, disable special scalar handling as described in `np.testing.assert_array_equals` notes section. Raises: AssertionError: If the leaf values actual and desired are not exactly equal. """ - if ignore_nones: - _ignore_nones_deprecation() - def assert_fn(arr_1, arr_2): np.testing.assert_array_equal( _ai.jnp_to_np_array(arr_1), @@ -1647,7 +1576,7 @@ def err_msg_fn(arr_1, arr_2) -> str: def _assert_trees_all_equal_jittable( - *trees: ArrayTree, ignore_nones: bool = False, strict: bool = True, + *trees: ArrayTree, strict: bool = True, ) -> Array: """A jittable version of `_assert_trees_all_equal_static`.""" if not strict: @@ -1659,9 +1588,6 @@ def _assert_trees_all_equal_jittable( " version." ) - if ignore_nones: - _ignore_nones_deprecation() - err_msg_template = "Values not exactly equal: {arr_1} != {arr_2}." cmp_fn = lambda x, y: jnp.array_equal(x, y, equal_nan=True) return _ai.assert_trees_all_eq_comparator_jittable( @@ -1678,8 +1604,7 @@ def _assert_trees_all_equal_jittable( def _assert_trees_all_close_static(*trees: ArrayTree, rtol: float = 1e-06, - atol: float = .0, - ignore_nones: bool = False) -> None: + atol: float = .0) -> None: """Checks that all trees have leaves with approximately equal values. This compares the difference between values of actual and desired up to @@ -1689,15 +1614,11 @@ def _assert_trees_all_close_static(*trees: ArrayTree, *trees: A sequence of (at least 2) trees with array leaves. rtol: A relative tolerance. atol: An absolute tolerance. - ignore_nones: Deprecated. Raises: AssertionError: If actual and desired values are not equal up to specified tolerance. """ - if ignore_nones: - _ignore_nones_deprecation() - def assert_fn(arr_1, arr_2): np.testing.assert_allclose( _ai.jnp_to_np_array(arr_1), @@ -1727,12 +1648,8 @@ def err_msg_fn(arr_1, arr_2) -> str: def _assert_trees_all_close_jittable(*trees: ArrayTree, rtol: float = 1e-06, - atol: float = .0, - ignore_nones: bool = False) -> Array: + atol: float = .0) -> Array: """A jittable version of `_assert_trees_all_close_static`.""" - if ignore_nones: - _ignore_nones_deprecation() - err_msg_template = ( f"Values not approximately equal ({rtol=}, {atol=}): " + "{arr_1} != {arr_2}." @@ -1757,7 +1674,6 @@ def _assert_trees_all_close_jittable(*trees: ArrayTree, def _assert_trees_all_close_ulp_static( *trees: ArrayTree, maxulp: int = 1, - ignore_nones: bool = False, ) -> None: """Checks that tree leaves differ by at most `maxulp` Units in the Last Place. @@ -1795,15 +1711,11 @@ def _assert_trees_all_close_ulp_static( Args: *trees: A sequence of (at least 2) trees with array leaves. maxulp: The maximum number of ULPs by which leaves may differ. - ignore_nones: Deprecated. Raises: AssertionError: If actual and desired values are not equal up to specified tolerance. """ - if ignore_nones: - _ignore_nones_deprecation() - def assert_fn(arr_1, arr_2): if ( getattr(arr_1, "dtype", None) == jnp.bfloat16 @@ -1849,7 +1761,6 @@ def err_msg_fn(arr_1, arr_2) -> str: def _assert_trees_all_close_ulp_jittable( *trees: ArrayTree, maxulp: int = 1, - ignore_nones: bool = False, ) -> jax.Array: """A dummy jittable version of `_assert_trees_all_close_ulp_static`. @@ -1860,7 +1771,6 @@ def _assert_trees_all_close_ulp_jittable( Args: *trees: Ignored. maxulp: Ignored. - ignore_nones: Deprecated. Raises: NotImplementedError: unconditionally. @@ -1869,9 +1779,6 @@ def _assert_trees_all_close_ulp_jittable( Never returns. (We pretend jax.Array to satisfy the type checker.) """ del trees, maxulp - if ignore_nones: - _ignore_nones_deprecation() - raise NotImplementedError( f"{_ai.ERR_PREFIX}assert_trees_all_close_ulp is not supported within JIT " "contexts."