[Fix] Implement better wd_ban_list handling #282

Merged
merged 3 commits into from
Oct 24, 2024

Conversation

Vectorrent
Contributor

Problem (Why?)

The wd_ban_list argument for get_optimizer_parameters() is somewhat misleading. Looking at its default value, you would expect every name format listed there to work correctly. However, that is not the case.

wd_ban_list: List[str] = ('bias', 'LayerNorm.bias', 'LayerNorm.weight')

From this list, only bias is "detected" and "banned" correctly. Neither LayerNorm.bias nor LayerNorm.weight is detected, so neither of these parameters has its weight_decay set to 0.

I even tested LayerNorm - and that doesn't work, either.
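
To illustrate the failure mode, here is a minimal sketch that assumes the ban check is a plain substring test against fully-qualified parameter names. The helper `banned_by_name` and the parameter names are hypothetical, chosen for a model whose LayerNorm is stored under the attribute `norm`; this is not the library's code.

```python
from typing import List, Sequence


def banned_by_name(param_names: Sequence[str], wd_ban_list: Sequence[str]) -> List[str]:
    """Return the names that contain at least one ban-list entry as a substring."""
    return [name for name in param_names if any(entry in name for entry in wd_ban_list)]


# Hypothetical fully-qualified names: the LayerNorm lives at `norm`, so the
# string "LayerNorm" never appears in any parameter name.
param_names = ["linear.weight", "linear.bias", "norm.weight", "norm.bias"]
default_ban_list = ("bias", "LayerNorm.bias", "LayerNorm.weight")

hits = banned_by_name(param_names, default_ban_list)
print(hits)  # the two bias parameters match; norm.weight slips through
```

Under that assumption, only the `bias` entry ever matches, and `norm.weight` keeps its weight decay.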

Solution (What/How?)

The reason this fails is that the wd_ban_list logic is only checking for the actual, fully-qualified parameter names; it is NOT checking for the class name of each nn.Module, as pytorch_optimizer's default arguments and tests would imply.

I implemented a more complete method for handling the wd_ban_list. Now, we check both the "true names" and the nn.Module class names.
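
The idea can be sketched roughly as follows. This is an illustration under stated assumptions, not the code merged in this PR: the function name `split_by_weight_decay`, the `TinyBlock` model, and the exact matching rule (substring match against either the fully-qualified name or a `ClassName.param` string) are all hypothetical.

```python
from typing import Dict, List, Sequence

from torch import nn


def split_by_weight_decay(
    model: nn.Module, weight_decay: float, wd_ban_list: Sequence[str]
) -> List[Dict]:
    """Build two parameter groups; banned parameters get weight_decay=0."""
    decay, no_decay, seen = [], [], set()
    for module_name, module in model.named_modules():
        class_name = type(module).__name__  # e.g. "LayerNorm"
        # recurse=False: only parameters owned directly by this module
        for param_name, param in module.named_parameters(recurse=False):
            if not param.requires_grad or id(param) in seen:
                continue
            seen.add(id(param))  # guard against tied/shared parameters
            full_name = f"{module_name}.{param_name}" if module_name else param_name
            by_class = f"{class_name}.{param_name}"  # e.g. "LayerNorm.weight"
            banned = any(b in full_name or b in by_class for b in wd_ban_list)
            (no_decay if banned else decay).append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


class TinyBlock(nn.Module):
    """Hypothetical model whose LayerNorm attribute is not called 'LayerNorm'."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.norm = nn.LayerNorm(4)


groups = split_by_weight_decay(
    TinyBlock(), 0.01, ("bias", "LayerNorm.bias", "LayerNorm.weight")
)
```

With this rule, `norm.weight` and `norm.bias` land in the zero-decay group even though their fully-qualified names never contain "LayerNorm". The returned list can be passed to any torch optimizer as parameter groups.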

Notes

I've been using this patch in my own code for several weeks now; it seems to work great! Let me know if there is anything you would change.

@kozistr kozistr added the enhancement New feature or request label Oct 24, 2024
kozistr
kozistr previously approved these changes Oct 24, 2024
Owner

@kozistr kozistr left a comment


Hi thanks for the contribution!

wd_ban_list logic is only checking for the actual, fully-qualified parameter names

Yes. It was originally intended to exclude parameters whose names include an entry on the ban list. So, as you mentioned above, if you have a layer norm layer called 'asdf' and don't put its exact parameter name into wd_ban_list, it won't be excluded. I added LayerNorm.bias and LayerNorm.weight to the default wd_ban_list to align with the usage in the Transformers library.

Your idea of adding module names (e.g. LayerNorm) to the exclusion criteria also sounds good to me, cuz we usually ban based on the type of module.

your code looks good to me! could you please run make format & make check by any chance? or I can handle it later then.

@Vectorrent
Contributor Author

I just pushed a new commit, with a few fixes. However, there is one error I was not able to fix:

pytorch_optimizer/optimizer/utils.py:201:5: D212 [*] Multi-line docstring summary should start at the first line
    |
199 |       wd_ban_list: List[str] = ('bias', 'LayerNorm.bias', 'LayerNorm.weight'),
200 |   ) -> PARAMETERS:
201 |       r"""
    |  _____^
202 | |     Get optimizer parameters while filtering specified modules.
203 | |     :param model_or_parameter: Union[nn.Module, List]. model or parameters.
204 | |     :param weight_decay: float. weight_decay.
205 | |     :param wd_ban_list: List[str]. ban list not to set weight decay.
206 | |     :returns: PARAMETERS. new parameter list.
207 | |     """
    | |_______^ D212
208 |   
209 |       fully_qualified_names = []
    |
    = help: Remove whitespace after opening quotes

Found 3 errors.
[*] 2 fixable with the `--fix` option.
make: *** [Makefile:16: check] Error 1

If you run make format, it fixes this issue. But then, if you run make check, it fails. So, if I manually fix it, then make check will work - but make format will fail, now!

I'm not super familiar with make, so I don't really know what to do here.
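
For reference, D212 is satisfied when the summary sits on the same line as the opening quotes. A minimal sketch, using a hypothetical stub rather than the real function from utils.py, and making no claim about why make format and make check disagree in this repository:

```python
def get_optimizer_parameters_stub(wd_ban_list=("bias",)):
    r"""Get optimizer parameters while filtering specified modules.

    :param wd_ban_list: List[str]. ban list not to set weight decay.
    :returns: tuple. the ban list, unchanged (this stub does no filtering).
    """
    return tuple(wd_ban_list)
```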

@kozistr
Owner

kozistr commented Oct 24, 2024

it's okay. I can handle lint stuff.

anyway, thanks for the contributions!

@kozistr kozistr merged commit 769e5fb into kozistr:main Oct 24, 2024
1 check passed
Labels
enhancement New feature or request size/S
2 participants