
fix pytorch_default_init() #819

Merged 2 commits into mlcommons:dev on Dec 12, 2024
Conversation

@EIFY commented Nov 26, 2024

torch.nn.init.trunc_normal_() truncates at the absolute interval (a, b) (defaulting to (-2, 2)), not at (a * std, b * std). So to conform to JAX's variance_scaling(..., distribution="truncated_normal", ...), we need to multiply the truncation bounds by std ourselves. We can see this by initializing a test model. Here is the repo's JAX ViT-S/16:

>>> import jax.numpy
>>> import jax.random
>>> from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.models import ViT
>>> from algorithmic_efficiency.workloads.imagenet_vit.workload import decode_variant
>>> vit = ViT(**decode_variant('S/16'))
>>> x = jax.numpy.zeros((1, 224, 224, 3), jax.numpy.float32)
>>> params = vit.init(jax.random.key(0), x)
>>> for w in [params['params']['conv_patch_extract']['kernel'], params['params']['pre_logits']['kernel']]:
...   print(w.min(), w.max())
...
-0.08204417 0.08203908
-0.11602508 0.116011634
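As a sanity check (my own arithmetic, not part of the PR), these extremes match truncation at ±2 std, where std follows variance_scaling(1.0, "fan_in", "truncated_normal") with the constant JAX uses to compensate for the variance lost to truncation:

```python
import math

# Correction factor used by jax.nn.initializers.variance_scaling to
# compensate for the variance lost when truncating at +/-2 std.
TRUNC_CORRECTION = 0.87962566103423978

def trunc_bound(fan_in: int) -> float:
    """Absolute truncation bound (2 * std) of a fan_in-scaled truncated normal."""
    std = math.sqrt(1.0 / fan_in) / TRUNC_CORRECTION
    return 2.0 * std

print(trunc_bound(3 * 16 * 16))  # conv_patch_extract: 16x16 RGB patches, ~0.0820
print(trunc_bound(384))          # pre_logits: ViT-S width 384, ~0.1160
```

Both values agree with the JAX min/max printed above to four decimal places.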

Here is the repo's PyTorch ViT-S/16 before the fix:

>>> from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.models import ViT
>>> from algorithmic_efficiency.workloads.imagenet_vit.workload import decode_variant
>>> vit = ViT(**decode_variant('S/16'))
>>> for w in [vit.conv_patch_extract.weight, vit.pre_logits.weight]:
...   print(w.min(), w.max())
...
tensor(-0.2119, grad_fn=<MinBackward1>) tensor(0.1907, grad_fn=<MaxBackward1>)
tensor(-0.2749, grad_fn=<MinBackward1>) tensor(0.2512, grad_fn=<MaxBackward1>)

Here is the repo's PyTorch ViT-S/16 after the fix:

>>> from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.models import ViT
>>> from algorithmic_efficiency.workloads.imagenet_vit.workload import decode_variant
>>> vit = ViT(**decode_variant('S/16'))
>>> for w in [vit.conv_patch_extract.weight, vit.pre_logits.weight]:
...   print(w.min(), w.max())
... 
tensor(-0.0820, grad_fn=<MinBackward1>) tensor(0.0820, grad_fn=<MaxBackward1>)
tensor(-0.1160, grad_fn=<MinBackward1>) tensor(0.1160, grad_fn=<MaxBackward1>)
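The difference is easy to reproduce with a pure-Python mental model of trunc_normal_ (a hypothetical rejection sampler for illustration, not the repo's code):

```python
import math
import random

def trunc_normal(std: float, a: float, b: float) -> float:
    """Mental model of torch.nn.init.trunc_normal_: a and b are
    absolute bounds, NOT multiples of std."""
    while True:
        x = random.gauss(0.0, std)
        if a <= x <= b:
            return x

random.seed(0)
fan_in = 3 * 16 * 16  # the ViT-S/16 patch-embedding kernel
std = math.sqrt(1.0 / fan_in) / 0.87962566103423978

# Before the fix: bounds of +/-2 are ~49 std wide here, so truncation never
# bites and extremes land well beyond 2 std, like the +/-0.21 seen above.
before = [trunc_normal(std, -2.0, 2.0) for _ in range(100_000)]
# After the fix: bounds scaled by std, matching JAX's truncated_normal.
after = [trunc_normal(std, -2.0 * std, 2.0 * std) for _ in range(100_000)]

print(max(abs(x) for x in before))  # > 2 * std
print(max(abs(x) for x in after))   # <= 2 * std
```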

Affected current workloads include imagenet_vit, imagenet_resnet, fastmri, and ogbg, along with (retired? test?) workloads cifar and mnist.

I hope this bug doesn't drastically upend the results so far but I don't know 😬

@EIFY requested a review from a team as a code owner November 26, 2024 22:34

github-actions bot commented Nov 26, 2024

MLCommons CLA bot All contributors have signed the MLCommons CLA ✍️ ✅

@EIFY (Author) commented Nov 27, 2024

recheck

@priyakasimbeg (Contributor) commented Nov 27, 2024

Hi Jason, thanks for sending this PR in. This is a very interesting and good catch!
Did you sign the CLA with your GitHub email and username to unblock the CLA check?

It's hard to say at the moment what the final effect of this difference in initialization is. My guess is that it probably won't upend the results, but we can double-check this.

@EIFY (Author) commented Nov 27, 2024

> Hi Jason, thanks for sending this PR in. This is a very interesting and good catch! Did you sign the CLA with your GitHub email and username to unblock the CLA check?

I have signed and emailed the CLA. I think the system has identified me as a signee but just hasn't rerun pull_request_target automatically. I have triggered it again with a no-change amend.

torch.nn.init.trunc_normal_() defaults to truncation at (a, b),
not (a * std, b * std).
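The corrected initializer presumably looks like the sketch below (function name and the exact correction constant assumed to follow JAX's variance_scaling; this is an illustration, not the repo's diff):

```python
import math
import torch

def pytorch_default_init(module: torch.nn.Module) -> None:
    """Sketch of a truncated-normal init matching JAX's
    variance_scaling(1.0, "fan_in", "truncated_normal"):
    std corrected for the variance lost to truncation at +/-2 std."""
    fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(module.weight)
    std = math.sqrt(1.0 / fan_in) / 0.87962566103423978
    # Key fix: scale the absolute truncation bounds (a, b) by std.
    torch.nn.init.trunc_normal_(module.weight, std=std, a=-2 * std, b=2 * std)

layer = torch.nn.Linear(384, 1000)
pytorch_default_init(layer)
print(layer.weight.min().item(), layer.weight.max().item())
```

For the ViT-S pre_logits fan-in of 384, the printed extremes stay within ±0.1160, matching the "after the fix" output above.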
@priyakasimbeg priyakasimbeg changed the base branch from main to dev December 4, 2024 20:11
@priyakasimbeg priyakasimbeg self-requested a review December 12, 2024 18:18
@priyakasimbeg priyakasimbeg merged commit fe90379 into mlcommons:dev Dec 12, 2024
33 of 36 checks passed
@github-actions github-actions bot locked and limited conversation to collaborators Dec 12, 2024
@EIFY EIFY deleted the torch-init-fix branch December 12, 2024 20:12