Skip to content

Commit

Permalink
Style: Update PyLint config, add suggestions
Browse files Browse the repository at this point in the history
- change overgeneral-exceptions in pylintrc to fully qualified names
- remove unnecessary parentheses in assignment in canonizers.py
- use {} notation to specify dicts instead of `dict` notation in
  conftest.py
  • Loading branch information
chr5tphr committed Feb 16, 2023
1 parent fe8dbce commit dd3b22b
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -520,5 +520,5 @@ known-third-party=enchant

# Exceptions that will emit a warning when being caught. Defaults to
# "BaseException, Exception".
overgeneral-exceptions=BaseException,
Exception
overgeneral-exceptions=builtins.BaseException,
builtins.Exception
4 changes: 2 additions & 2 deletions src/zennit/canonizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def merge_batch_norm(modules, batch_norm):
`eps`
'''
denominator = (batch_norm.running_var + batch_norm.eps) ** .5
scale = (batch_norm.weight / denominator)
scale = batch_norm.weight / denominator

for module in modules:
original_weight = module.weight.data
Expand All @@ -140,7 +140,7 @@ def merge_batch_norm(modules, batch_norm):
index = (slice(None), *((None,) * (original_weight.ndim - 1)))

# merge batch_norm into linear layer
module.weight.data = (original_weight * scale[index])
module.weight.data = original_weight * scale[index]
module.bias.data = (original_bias - batch_norm.running_mean) * scale + batch_norm.bias

# change batch_norm parameters to produce identity
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,10 @@ def rng(request):
scope='session',
params=[
(torch.nn.ReLU, {}),
(torch.nn.Softmax, dict(dim=1)),
(torch.nn.Softmax, {'dim': 1}),
(torch.nn.Tanh, {}),
(torch.nn.Sigmoid, {}),
(torch.nn.Softplus, dict(beta=1)),
(torch.nn.Softplus, {'beta': 1}),
],
ids=lambda param: param[0].__name__
)
Expand Down

0 comments on commit dd3b22b

Please sign in to comment.