Skip to content

Commit

Permalink
debug
Browse files Browse the repository at this point in the history
  • Loading branch information
priyakasimbeg committed Dec 7, 2023
1 parent 174f0a4 commit 99da8ba
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions tests/modeldiffs/fastmri_layernorm/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def sort_key(k):
c = 0
for idx, k in enumerate(keys):
new_key = []
jax_weights_name = 'kernel'
layernorm = False
for idx2, i in enumerate(k):
if 'ModuleList' in i or 'Sequential' in i:
continue
Expand All @@ -43,10 +43,13 @@ def sort_key(k):
if 'ConvTranspose2d' in i:
i = i.replace('ConvTranspose2d', 'ConvTranspose')
if 'LayerNorm' in i:
jax_weights_name = 'scale'
layernorm = True
continue
if 'weight' in i:
i = i.replace('weight', jax_weights_name)
if layernorm:
i = i.replace('weight', 'scale')
else:
i = i.replace('weight', 'dense')
new_key.append(i)
new_key = tuple(new_key)
sd[new_key] = sd[k]
Expand Down

0 comments on commit 99da8ba

Please sign in to comment.