diff --git a/pylintrc b/pylintrc index a28602f..05bc46b 100644 --- a/pylintrc +++ b/pylintrc @@ -520,5 +520,5 @@ known-third-party=enchant # Exceptions that will emit a warning when being caught. Defaults to # "BaseException, Exception". -overgeneral-exceptions=BaseException, - Exception +overgeneral-exceptions=builtins.BaseException, + builtins.Exception diff --git a/src/zennit/canonizers.py b/src/zennit/canonizers.py index aa49bd3..e596584 100644 --- a/src/zennit/canonizers.py +++ b/src/zennit/canonizers.py @@ -124,7 +124,7 @@ def merge_batch_norm(modules, batch_norm): `eps` ''' denominator = (batch_norm.running_var + batch_norm.eps) ** .5 - scale = (batch_norm.weight / denominator) + scale = batch_norm.weight / denominator for module in modules: original_weight = module.weight.data @@ -140,7 +140,7 @@ def merge_batch_norm(modules, batch_norm): index = (slice(None), *((None,) * (original_weight.ndim - 1))) # merge batch_norm into linear layer - module.weight.data = (original_weight * scale[index]) + module.weight.data = original_weight * scale[index] module.bias.data = (original_bias - batch_norm.running_mean) * scale + batch_norm.bias # change batch_norm parameters to produce identity diff --git a/tests/conftest.py b/tests/conftest.py index 7238d5a..8c56b9f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -58,10 +58,10 @@ def rng(request): scope='session', params=[ (torch.nn.ReLU, {}), - (torch.nn.Softmax, dict(dim=1)), + (torch.nn.Softmax, {'dim': 1}), (torch.nn.Tanh, {}), (torch.nn.Sigmoid, {}), - (torch.nn.Softplus, dict(beta=1)), + (torch.nn.Softplus, {'beta': 1}), ], ids=lambda param: param[0].__name__ )