diff --git a/tests/modeldiffs/torch2jax_utils.py b/tests/modeldiffs/torch2jax_utils.py index d9264b400..e0df61b4a 100644 --- a/tests/modeldiffs/torch2jax_utils.py +++ b/tests/modeldiffs/torch2jax_utils.py @@ -77,6 +77,7 @@ def key_transform(self, k_transform_fn): } def value_transform(self, v_transform_fn): + print(self.flattened_jax_model.keys()) self.pytorch_sd = { k: v_transform_fn(k, self.pytorch_sd[k], self.flattened_jax_model[k]) for k in self.pytorch_sd