diff --git a/tests/modeldiffs/criteo1tb_layernorm/compare.py b/tests/modeldiffs/criteo1tb_layernorm/compare.py
index a8ee234aa..d786bb11c 100644
--- a/tests/modeldiffs/criteo1tb_layernorm/compare.py
+++ b/tests/modeldiffs/criteo1tb_layernorm/compare.py
@@ -18,6 +18,7 @@ def key_transform(k):
   new_key = []
   s_count = None
+  layer_norm = False
   print('key')
   print(k)
   for i in k:
@@ -31,6 +32,7 @@ def key_transform(k):
       name, count = i.split('_')
       i = name + '_' + str(s_count * 3 + int(count))
     if 'LayerNorm' in i:
+      layer_norm = True
       name, count = i.split('_')
       # There is a layernorm on embedding between bottom and top MLP
       if s_count is not None:
@@ -38,7 +40,10 @@ def key_transform(k):
       else:
         i = name + '_' + str(3)
     elif 'weight' in i:
-      i = i.replace('weight', 'kernel')
+      if layer_norm:
+        i = i.replace('weight', 'scale')
+      else:
+        i = i.replace('weight', 'kernel')
     new_key.append(i)
   print(new_key)
   return tuple(new_key)
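
For context, a small standalone sketch (not part of this diff) of why LayerNorm needs the extra branch: Flax's nn.LayerNorm registers its learnable gain as 'scale', while PyTorch's torch.nn.LayerNorm exposes the same parameter as 'weight', so the Dense-style 'weight' -> 'kernel' rename would never match the JAX parameter tree.

# Sketch only: compare parameter naming in the two frameworks.
import flax.linen as nn
import jax
import jax.numpy as jnp
import torch

# PyTorch: LayerNorm stores its learnable gain/offset as 'weight'/'bias'.
print([name for name, _ in torch.nn.LayerNorm(4).named_parameters()])
# ['weight', 'bias']

# Flax: the equivalent parameters are named 'scale'/'bias', so a LayerNorm
# 'weight' key must be mapped to 'scale', not 'kernel'.
params = nn.LayerNorm().init(jax.random.PRNGKey(0), jnp.ones((1, 4)))
print(jax.tree_util.tree_map(jnp.shape, params))
# {'params': {'bias': (4,), 'scale': (4,)}}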