diff --git a/tests/modeldiffs/criteo1tb_resnet/compare.py b/tests/modeldiffs/criteo1tb_resnet/compare.py index 0b7004568..b9dbbc80e 100644 --- a/tests/modeldiffs/criteo1tb_resnet/compare.py +++ b/tests/modeldiffs/criteo1tb_resnet/compare.py @@ -37,7 +37,7 @@ def key_transform(k): continue if 'Linear' in i: i = i.replace('Linear', 'Dense') - name, count = i.split('_') + name, _ = i.split('_') block_count = mlp_block_count if mlp_block_count else resnet_block_count i = name + '_' + str(mlp_count * 3 + block_count) elif 'weight' in i: diff --git a/tests/modeldiffs/diff.py b/tests/modeldiffs/diff.py index 81eff8301..f96fa672b 100644 --- a/tests/modeldiffs/diff.py +++ b/tests/modeldiffs/diff.py @@ -1,5 +1,3 @@ -import logging - from flax import jax_utils import jax import numpy as np