Fix biasScaleShape of GroupNormalizationV21 to support ranks > 4 #3030

Open
wants to merge 1 commit into
base: main

jorickert
Collaborator

Before this PR, the oneDimShape assumed a spatial rank of two, which is only correct for rank == 4.

@hamptonm1
Collaborator

@jorickert Hello!! Can you direct me to the link or document which validates your findings please? I just want to have a better understanding. Thank you!

@jorickert
Collaborator Author

jorickert commented Dec 17, 2024

@hamptonm1
According to the ONNX spec, the input of GroupNormV21 has the dimensions (N, C, D1, D2, ..., Dn), and scale and bias have the shape (C).

The formula for GroupNorm is:
y = scale * (x - mean) / sqrt(variance + epsilon) + bias, where the mean and variance are computed per instance per group of channels. The formula for LayerNorm is generally the same; the main difference (in ONNX) is that LayerNorm allows the selection of the axes over which the mean and variance are computed.

Internally, GroupNorm reshapes the input to (N, G, C // G, D1, D2, ..., Dn) and then performs a LayerNorm-like operation on it, with the reduction axes running from C // G to Dn. This can be seen in Figure 3 of the GroupNorm paper: https://arxiv.org/pdf/1803.08494
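
The reshape-then-normalize view can be sketched in plain Python. This is only an illustrative reference, not the onnx-mlir implementation: the function name `group_norm_flat`, the row-major flat-list representation, and the default `epsilon` are assumptions made for this example. It relies on the fact that, in row-major order, all elements belonging to one (n, g) pair of the reshaped (N, G, C // G, D1, ..., Dn) tensor are contiguous.

```python
import math


def group_norm_flat(x, shape, num_groups, scale, bias, epsilon=1e-5):
    """Hypothetical reference GroupNorm.

    x      : row-major flat list holding a tensor of shape (N, C, D1, ..., Dn)
    scale  : list of length C (the V21 convention described above)
    bias   : list of length C
    """
    n, c = shape[0], shape[1]
    spatial = math.prod(shape[2:])              # D1 * D2 * ... * Dn (1 if none)
    channels_per_group = c // num_groups
    group_size = channels_per_group * spatial   # elements reduced together

    out = [0.0] * len(x)
    for block in range(n * num_groups):         # one block per (n, g) pair
        start = block * group_size
        chunk = x[start:start + group_size]
        mean = sum(chunk) / group_size
        variance = sum((v - mean) ** 2 for v in chunk) / group_size
        inv_std = 1.0 / math.sqrt(variance + epsilon)
        g = block % num_groups
        for i, v in enumerate(chunk):
            # Map the position inside the group back to the original channel.
            ch = g * channels_per_group + i // spatial
            out[start + i] = scale[ch] * (v - mean) * inv_std + bias[ch]
    return out


# Rank-3 input (N=1, C=4, D1=2) with G=2 groups.
x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
y = group_norm_flat(x, (1, 4, 2), 2, scale=[1.0] * 4, bias=[0.0] * 4)
```

Because the statistics are computed over C // G and all spatial dimensions together, the same loop works for any number of spatial dimensions; only `spatial` changes.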

The decomposition/conversion mimics this:

  1. Manually reshape the input to (N, G, C // G, D1, D2, ..., Dn)
  2. Perform a LayerNorm with axes = C // G to Dn
  3. Manually reshape to (N, C, D1, D2, ..., Dn)

Additionally, the scale and bias must be reshaped. The input scale and bias have shape (C).
To make them compatible with the LayerNorm and broadcasting, the following reshape is needed:
C -> (G, C // G, 1, 1, ..., 1), where the number of ones equals the number of spatial dimensions.
Without my PR, the reshape is always to (G, C // G, 1, 1), no matter how many spatial dimensions exist.
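
To make the shape argument concrete, here is a small Python sketch of the scale/bias reshape target and a broadcast check. The helper names `bias_scale_shape` and `broadcasts_to` are hypothetical and chosen for this example only; they do not correspond to functions in the code base. The check uses the usual right-aligned broadcasting rule (a dimension is compatible if it is 1 or equal to the target dimension).

```python
def bias_scale_shape(num_groups, channels, input_rank):
    """Target shape for scale/bias: (G, C // G) followed by one 1 per spatial dim."""
    spatial_rank = input_rank - 2  # everything after N and C
    return [num_groups, channels // num_groups] + [1] * spatial_rank


def broadcasts_to(small, big):
    """True if `small` broadcasts to `big` under right-aligned broadcasting rules."""
    if len(small) > len(big):
        return False
    return all(s == 1 or s == b for s, b in zip(reversed(small), reversed(big)))


# Rank-5 input (N=1, C=4, D1=D2=D3=3) with G=2, reshaped to (N, G, C // G, D1, D2, D3).
reshaped_input = [1, 2, 2, 3, 3, 3]

fixed = bias_scale_shape(2, 4, input_rank=5)   # [2, 2, 1, 1, 1]
hardcoded = [2, 2, 1, 1]                       # always two ones, as before the PR

print(broadcasts_to(fixed, reshaped_input))      # True
print(broadcasts_to(hardcoded, reshaped_input))  # False
```

For a rank-4 input both shapes coincide, which is why the problem only shows up for other ranks.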

While writing this down, I realized that the lit tests are still wrong for GroupNormV21. The input scale and bias have the size C // G = 2 (which is correct for GroupNormV18), but should be C = 4. I will fix this.

Collaborator

@AlexandreEichenberger AlexandreEichenberger left a comment

LGTM, let me know if you need me to merge it (and if so, when you are ready to do so).

Collaborator Author

Reminder to update the bias and scale size in the lit test inputs

Collaborator

@hamptonm1 hamptonm1 left a comment

Okay it works for me then!
