
How to do Batch Normalization in JAX/Flax? #921

Answered by marcvanzee
marcvanzee asked this question in Q&A

The BatchNorm module is defined in normalization.py. The canonical example using it is the Imagenet example.

In a multi-device setting, every device updates its normalizing parameters (ra_mean and ra_var in the code) based on its own batch statistics (unless we specify axis_name, which makes them global). However, these parameters are stored only in the Module's batch_stats variable collection and not in params, so they are never synced across devices unless you do so explicitly.

If they aren't synced, they can theoretically diverge, but if your data is fairly uniform across shards they're likely to trend towards similar values. Syncing before eval is definitely a good idea though, since otherwise your eval results will depend on which device's statistics are used.
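A hedged sketch of such a sync, assuming `replicated_batch_stats` is a pytree whose leaves carry a leading device dimension (as produced by pmap-style training): average the per-device running statistics across that axis so every replica holds the same values before eval.

```python
import jax
import jax.numpy as jnp

def sync_batch_stats(replicated_batch_stats):
    # Cross-replica mean over the leading device axis; the result is
    # broadcast back so every device holds identical statistics.
    return jax.tree_util.tree_map(
        lambda x: jnp.broadcast_to(x.mean(axis=0), x.shape),
        replicated_batch_stats)

# Toy example: two "devices" whose running means have diverged.
stats = {'bn': {'mean': jnp.array([[0.0, 2.0], [2.0, 4.0]]),
                'var':  jnp.array([[1.0, 1.0], [1.0, 1.0]])}}
synced = sync_batch_stats(stats)
# Both rows of synced['bn']['mean'] become [1., 3.].
```

In a real pmap setup the same effect is often achieved inside a pmapped function with lax.pmean over the device axis; the tree_map version above is just the host-side equivalent.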

This discussion was converted from issue #855 on January 21, 2021 14:22.