From a059b648a5059fd5b07257d1a831334bee4e8f93 Mon Sep 17 00:00:00 2001
From: Sergei Lebedev
Date: Mon, 23 Oct 2023 01:49:52 -0700
Subject: [PATCH] Updated the `hk.mixed_precision.set_policy` docstring

`hk.BatchNorm` no longer overflows in f16, because JAX does the
accumulation in f32 (see google/jax@b18ca05bc66a332a97f5b797e85b5c0ba316f006).

PiperOrigin-RevId: 575745419
---
 haiku/_src/mixed_precision.py | 14 --------------
 1 file changed, 14 deletions(-)

diff --git a/haiku/_src/mixed_precision.py b/haiku/_src/mixed_precision.py
index b5a1518a1..ee9198b71 100644
--- a/haiku/_src/mixed_precision.py
+++ b/haiku/_src/mixed_precision.py
@@ -163,19 +163,6 @@ def set_policy(cls: type[hk.Module], policy: jmp.Policy):
   >>> net = hk.nets.ResNet50(4)
   >>> x = jnp.ones([4, 224, 224, 3])
   >>> print(net(x, is_training=True))
-  [[nan nan nan nan]
-   [nan nan nan nan]
-   [nan nan nan nan]
-   [nan nan nan nan]]
-
-  Oh no, nan! This is because modules like batch norm are not numerically stable
-  in ``float16``. To address this, we apply a second policy to our batch norm
-  modules to keep them in full precision. We are careful to return a ``float16``
-  output from the module such that subsequent modules receive ``float16`` input:
-
-  >>> policy = jmp.get_policy('params=float32,compute=float32,output=float16')
-  >>> hk.mixed_precision.set_policy(hk.BatchNorm, policy)
-  >>> print(net(x, is_training=True))
   [[0. 0. 0. 0.]
    [0. 0. 0. 0.]
    [0. 0. 0. 0.]
@@ -186,7 +173,6 @@
   speedup in training time with only a small impact on final top-1 accuracy.
 
   >>> hk.mixed_precision.clear_policy(hk.nets.ResNet50)
-  >>> hk.mixed_precision.clear_policy(hk.BatchNorm)
 
   Args:
     cls: A Haiku module class.
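
Not part of the patch: a minimal, hedged sketch of the end-to-end usage the trimmed docstring now describes, i.e. a single jmp policy on `hk.nets.ResNet50` with no extra float32 override for `hk.BatchNorm`. The `forward`/`fwd` names and the `hk.transform_with_state` plumbing are illustrative; the policy string, the module classes and the `set_policy`/`clear_policy` calls come from the docstring itself.

# Sketch only: assumes recent haiku, jmp and jax installs where batch norm
# statistics are accumulated in f32 (per the commit message).
import haiku as hk
import jax
import jax.numpy as jnp
import jmp


def forward(x):
  # ResNet50 with 4 output classes, matching the docstring example.
  return hk.nets.ResNet50(4)(x, is_training=True)


# Params stay in float32, compute runs in float16, outputs come back float32.
policy = jmp.get_policy('params=float32,compute=float16,output=float32')
hk.mixed_precision.set_policy(hk.nets.ResNet50, policy)

fwd = hk.transform_with_state(forward)
x = jnp.ones([4, 224, 224, 3])
params, state = fwd.init(jax.random.PRNGKey(0), x)
out, _ = fwd.apply(params, state, jax.random.PRNGKey(0), x)
print(out)  # Expect finite values; no per-BatchNorm float32 policy required.

# Remove the policy when done experimenting.
hk.mixed_precision.clear_policy(hk.nets.ResNet50)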