diff --git a/tests/modeldiffs/fastmri_layernorm/compare.py b/tests/modeldiffs/fastmri_layernorm/compare.py index 62fbfaceb..30939b7d4 100644 --- a/tests/modeldiffs/fastmri_layernorm/compare.py +++ b/tests/modeldiffs/fastmri_layernorm/compare.py @@ -28,6 +28,7 @@ def sort_key(k): keys = sorted(sd.keys(), key=sort_key) c = 0 + jax_weights_name = 'kernel' for idx, k in enumerate(keys): new_key = [] for idx2, i in enumerate(k): @@ -41,8 +42,10 @@ def sort_key(k): i = i.replace('Conv2d', 'Conv') if 'ConvTranspose2d' in i: i = i.replace('ConvTranspose2d', 'ConvTranspose') + if 'LayerNorm' in i: + jax_weights_name = 'scale' if 'weight' in i: - i = i.replace('weight', 'kernel') + i = i.replace('weight', jax_weights_name) new_key.append(i) new_key = tuple(new_key) sd[new_key] = sd[k]