Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
priyakasimbeg committed Dec 7, 2023
1 parent 4ba6736 commit 73b72b1
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion tests/modeldiffs/fastmri_layernorm/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def sort_key(k):

keys = sorted(sd.keys(), key=sort_key)
c = 0
jax_weights_name = 'kernel'
for idx, k in enumerate(keys):
new_key = []
for idx2, i in enumerate(k):
Expand All @@ -41,8 +42,10 @@ def sort_key(k):
i = i.replace('Conv2d', 'Conv')
if 'ConvTranspose2d' in i:
i = i.replace('ConvTranspose2d', 'ConvTranspose')
if 'LayerNorm' in i:
jax_weights_name = 'scale'
if 'weight' in i:
i = i.replace('weight', 'kernel')
i = i.replace('weight', jax_weights_name)
new_key.append(i)
new_key = tuple(new_key)
sd[new_key] = sd[k]
Expand Down

0 comments on commit 73b72b1

Please sign in to comment.