From 329c82822ca691238221ad378a783d92aad5d11c Mon Sep 17 00:00:00 2001 From: ChexDev Date: Fri, 21 Oct 2022 15:20:53 -0700 Subject: [PATCH] [Flaxformer] Additional asserts to help find confusing misconfigurations of models (added in zT5 testing). PiperOrigin-RevId: 482900882 --- chex/_src/asserts.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/chex/_src/asserts.py b/chex/_src/asserts.py index 24946341..9f6d7581 100644 --- a/chex/_src/asserts.py +++ b/chex/_src/asserts.py @@ -514,7 +514,8 @@ def _shape_matches(actual_shape: Sequence[int], def assert_shape( inputs: Union[Scalar, Union[Array, Sequence[Array]]], expected_shapes: Union[_ai.TShapeMatcher, - Sequence[_ai.TShapeMatcher]]) -> None: + Sequence[_ai.TShapeMatcher]], + explanation: Optional[Union[str, Callable[[], str]]] = None) -> None: """Checks that the shape of all inputs matches specified ``expected_shapes``. Valid usages include: @@ -535,6 +536,8 @@ def assert_shape( where the expected shape is a sequence of integer and `None` dimensions; if all inputs have same shape, a single shape may be passed as ``expected_shapes``. + explanation: Additional message to give context when this assertion fails + (or a function/closure that returns such a message). Raises: AssertionError: If the lengths of ``inputs`` and ``expected_shapes`` do not @@ -564,9 +567,19 @@ def assert_shape( errors.append((idx, shape, _ai.format_shape_matcher(expected))) if errors: + if callable(explanation): + try: + explanation: str = explanation() + except Exception as e: # pylint: disable=broad-except + explanation = ("[[`explanation` callback failed: " + + "\n".join(traceback.format_exception( + e.__class__, e, e.__traceback__, limit=4)) + "]]") + if not explanation: + explanation = "" msg = "; ".join( f"input {e[0]} has shape {e[1]} but expected {e[2]}" for e in errors) - raise AssertionError(f"Error in shape compatibility check: {msg}.") + raise AssertionError( + f"Error in shape compatibility check: {msg}. {explanation}") @_static_assertion