Updated the hk.mixed_precision.set_policy docstring
`hk.BatchNorm` no longer overflows in f16, because JAX does the accumulation
in f32 (see jax-ml/jax@b18ca05).

PiperOrigin-RevId: 575745419
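For context, here is a minimal sketch (not part of the commit) of the behaviour the updated docstring now documents: a single mixed-precision policy applied to the whole ResNet50, with no extra full-precision policy on `hk.BatchNorm`. The policy string and the `transform_with_state` harness are illustrative assumptions, not taken from the commit.

```python
# Hypothetical sketch: BatchNorm inside a float16-compute network no longer
# needs its own full-precision policy. Policy string and harness are assumed.
import jax
import jax.numpy as jnp
import jmp
import haiku as hk


def forward(x, is_training):
  # ResNet50 with 4 output classes, matching the docstring example.
  net = hk.nets.ResNet50(4)
  return net(x, is_training=is_training)


# One policy for the whole network: float32 params, float16 compute.
policy = jmp.get_policy('params=float32,compute=float16,output=float32')
hk.mixed_precision.set_policy(hk.nets.ResNet50, policy)

f = hk.transform_with_state(forward)  # ResNet50 carries BatchNorm state.
x = jnp.ones([4, 224, 224, 3])
rng = jax.random.PRNGKey(0)
params, state = f.init(rng, x, is_training=True)
out, _ = f.apply(params, state, rng, x, is_training=True)
print(out)  # Finite values; no separate hk.BatchNorm policy is needed.

hk.mixed_precision.clear_policy(hk.nets.ResNet50)
```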
superbobry authored and copybara-github committed Oct 23, 2023
1 parent 86a00ea commit eb38b36
Showing 1 changed file with 0 additions and 14 deletions.
haiku/_src/mixed_precision.py (14 changes: 0 additions & 14 deletions)
@@ -163,19 +163,6 @@ def set_policy(cls: type[hk.Module], policy: jmp.Policy):
   >>> net = hk.nets.ResNet50(4)
   >>> x = jnp.ones([4, 224, 224, 3])
   >>> print(net(x, is_training=True))
-  [[nan nan nan nan]
-   [nan nan nan nan]
-   [nan nan nan nan]
-   [nan nan nan nan]]
-
-  Oh no, nan! This is because modules like batch norm are not numerically stable
-  in ``float16``. To address this, we apply a second policy to our batch norm
-  modules to keep them in full precision. We are careful to return a ``float16``
-  output from the module such that subsequent modules receive ``float16`` input:
-
-  >>> policy = jmp.get_policy('params=float32,compute=float32,output=float16')
-  >>> hk.mixed_precision.set_policy(hk.BatchNorm, policy)
-  >>> print(net(x, is_training=True))
   [[0. 0. 0. 0.]
    [0. 0. 0. 0.]
    [0. 0. 0. 0.]
Expand All @@ -186,7 +173,6 @@ def set_policy(cls: type[hk.Module], policy: jmp.Policy):
   speedup in training time with only a small impact on final top-1 accuracy.
 
   >>> hk.mixed_precision.clear_policy(hk.nets.ResNet50)
-  >>> hk.mixed_precision.clear_policy(hk.BatchNorm)
 
   Args:
     cls: A Haiku module class.
