diff --git a/tests/modeldiffs/torch2jax_utils.py b/tests/modeldiffs/torch2jax_utils.py index 07f7cc360..9600cd204 100644 --- a/tests/modeldiffs/torch2jax_utils.py +++ b/tests/modeldiffs/torch2jax_utils.py @@ -106,6 +106,7 @@ def diff(self): if s_p == s_j: count += 1 else: + print('Difference in pytorch and jax key:') print(k, s_p, s_j) print(f'Number of values with identical shapes: {count}')