Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
priyakasimbeg committed Nov 22, 2023
1 parent 9550c15 commit bcf68ed
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion tests/modeldiffs/criteo1tb_layernorm/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
def key_transform(k):
new_key = []
s_count = None
layer_norm = False
print('key')
print(k)
for i in k:
Expand All @@ -31,14 +32,18 @@ def key_transform(k):
name, count = i.split('_')
i = name + '_' + str(s_count * 3 + int(count))
if 'LayerNorm' in i:
layer_norm = True
name, count = i.split('_')
# There is a layernorm on embedding between bottom and top MLP
if s_count is not None:
i = name + '_' + str(s_count * 4 + int(count))
else:
i = name + '_' + str(3)
elif 'weight' in i:
i = i.replace('weight', 'kernel')
if layer_norm:
i = i.replace('weight', 'scale')
else:
i = i.replace('weight', 'kernel')
new_key.append(i)
print(new_key)
return tuple(new_key)
Expand Down

0 comments on commit bcf68ed

Please sign in to comment.