How to do Batch Normalization in JAX/Flax? #921
Replies: 1 comment
The BatchNorm module is in normalization.py. The canonical example using it is the ImageNet example.

In a multi-device setting, every device updates its normalizing parameters (`ra_mean` and `ra_var` in the code) based on its own batch statistics (unless we specify `axis_name`, which makes them global). However, these parameters are only stored in the Module variable collection `batch_stats` and not in `params`, so they are never synced across devices unless you do so explicitly.

If they aren't synced, they can theoretically diverge, but if your data is fairly uniform across shards they're likely to trend towards similar values. Syncing before eval is definitely a good idea though, since otherwise your eval results will depend on which devices process which examples.
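A minimal sketch of the usual single-device pattern with the Linen API (the toy model here is made up for illustration): because the running statistics live in the separate `batch_stats` collection, they must be declared `mutable` during training and threaded back in by hand.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class Model(nn.Module):  # hypothetical toy model, for illustration only
    @nn.compact
    def __call__(self, x, train: bool):
        x = nn.Dense(features=64)(x)
        # Training: normalize with the batch's own statistics and update
        # the running averages. Eval: use the stored running averages.
        x = nn.BatchNorm(use_running_average=not train)(x)
        x = nn.relu(x)
        return nn.Dense(features=10)(x)

model = Model()
x = jnp.ones((8, 32))
variables = model.init(jax.random.PRNGKey(0), x, train=True)
params, batch_stats = variables['params'], variables['batch_stats']

# Train step: 'batch_stats' must be marked mutable so the updated
# running statistics are returned alongside the model output.
y, updates = model.apply(
    {'params': params, 'batch_stats': batch_stats},
    x, train=True, mutable=['batch_stats'])
batch_stats = updates['batch_stats']

# Eval: read-only, uses the accumulated running averages.
y_eval = model.apply(
    {'params': params, 'batch_stats': batch_stats}, x, train=False)
```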
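And a sketch of syncing the running statistics across devices before eval, following the cross-replica-mean pattern from the ImageNet example (the `sync_batch_stats` helper name is just for illustration):

```python
import jax
from jax import lax

# Average each running statistic across all devices. Assumes
# `batch_stats` is already replicated with a leading device axis
# (e.g. via flax.jax_utils.replicate in a pmap training loop).
cross_replica_mean = jax.pmap(lambda x: lax.pmean(x, 'batch'),
                              axis_name='batch')

def sync_batch_stats(batch_stats):
    # Replace every replica's running mean/var with the cross-device
    # mean, so eval no longer depends on which device saw which data.
    return cross_replica_mean(batch_stats)

# batch_stats = sync_batch_stats(batch_stats)  # call once before eval
```

To instead keep the statistics globally consistent during training, pass `axis_name='batch'` to `nn.BatchNorm` itself, matching the `axis_name` given to `jax.pmap`.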