Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix computation in normalization layers #876

Merged
merged 4 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 34 additions & 8 deletions equinox/nn/_normalisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ def __init__(
)
self.use_weight = use_weight
self.use_bias = use_bias

with jax.numpy_dtype_promotion("standard"):
dtype = jnp.result_type(dtype, jnp.float32)
AakashKumarNain marked this conversation as resolved.
Show resolved Hide resolved

self.weight = jnp.ones(shape, dtype=dtype) if use_weight else None
self.bias = jnp.zeros(shape, dtype=dtype) if use_bias else None

Expand Down Expand Up @@ -140,6 +144,11 @@ def __call__(
"`x.shape` ended with `shape`. However, this turned out to be a "
"frequent source of bugs, so we made the check stricter!"
)
orig_dtype = x.dtype
with jax.numpy_dtype_promotion("standard"):
dtype = jnp.result_type(x.dtype, jnp.float32)

x = x.astype(dtype)
mean = jnp.mean(x, keepdims=True)
variance = jnp.var(x, keepdims=True)
variance = jnp.maximum(0.0, variance)
Expand All @@ -150,9 +159,9 @@ def __call__(
if self.use_bias:
out = out + self.bias
if state is sentinel:
return out
return out.astype(orig_dtype)
else:
return out, state
return out.astype(orig_dtype), state


class GroupNorm(Module, strict=True):
Expand Down Expand Up @@ -222,6 +231,10 @@ def __init__(
self.channels = channels
self.eps = eps
self.channelwise_affine = channelwise_affine

with jax.numpy_dtype_promotion("standard"):
dtype = jnp.result_type(dtype, jnp.float32)

self.weight = jnp.ones(channels, dtype=dtype) if channelwise_affine else None
self.bias = jnp.zeros(channels, dtype=dtype) if channelwise_affine else None

Expand Down Expand Up @@ -253,6 +266,12 @@ def __call__(
is passed through unchanged. If `state` is not passed, then just the output is
returned.
"""

orig_dtype = x.dtype
with jax.numpy_dtype_promotion("standard"):
dtype = jnp.result_type(x.dtype, jnp.float32)

x = x.astype(dtype)
channels = x.shape[0]
y = x.reshape(self.groups, channels // self.groups, *x.shape[1:])
mean = jax.vmap(ft.partial(jnp.mean, keepdims=True))(y)
Expand All @@ -266,9 +285,9 @@ def __call__(
bias = left_broadcast_to(self.bias, out.shape) # pyright: ignore
out = weight * out + bias
if state is sentinel:
return out
return out.astype(orig_dtype)
else:
return out, state
return out.astype(orig_dtype), state


class RMSNorm(Module, strict=True):
Expand Down Expand Up @@ -335,6 +354,10 @@ def __init__(
self.eps = eps
self.use_weight = use_weight
self.use_bias = use_bias

with jax.numpy_dtype_promotion("standard"):
dtype = jnp.result_type(dtype, jnp.float32)

self.weight = jnp.ones(shape, dtype=dtype) if use_weight else None
self.bias = jnp.zeros(shape, dtype=dtype) if use_bias else None

Expand Down Expand Up @@ -377,17 +400,20 @@ def __call__(
"to replace `rms_norm(x)` with `jax.vmap(rms_norm)(x)`.\n"
)

orig_dtype = x.dtype

with jax.numpy_dtype_promotion("standard"):
dtype = jnp.result_type(x.dtype, jnp.float32)

inv_rms = jax.lax.rsqrt(jnp.mean(x.astype(dtype) ** 2) + self.eps)
out = (inv_rms * x.astype(dtype)).astype(x.dtype)
x = x.astype(dtype)
inv_rms = jax.lax.rsqrt(jnp.mean(x**2) + self.eps)
out = inv_rms * x

if self.use_weight:
out = self.weight * out
if self.use_bias:
out = out + self.bias
if state is sentinel:
return out
return out.astype(orig_dtype)
else:
return out, state
return out.astype(orig_dtype), state
8 changes: 8 additions & 0 deletions tests/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,12 +886,20 @@ def test_layer_norm(getkey):
assert jnp.allclose(ln(x1), ln(x2), atol=1e-4)
assert jnp.allclose(ln(x1), x3, atol=1e-4)

ln = eqx.nn.LayerNorm(128, dtype=jnp.bfloat16)
x = jrandom.uniform(getkey(), (128,), dtype=jnp.bfloat16)
assert ln(x).dtype == jnp.bfloat16


def test_group_norm(getkey):
gn = eqx.nn.GroupNorm(groups=4, channels=128)
x = jrandom.uniform(getkey(), (128,))
assert gn(x).shape == (128,)

gn = eqx.nn.GroupNorm(groups=4, channels=128, dtype=jnp.bfloat16)
x = jrandom.uniform(getkey(), (128,), dtype=jnp.bfloat16)
assert gn(x).dtype == jnp.bfloat16

gn = eqx.nn.GroupNorm(groups=4, channels=128)
x = jrandom.uniform(getkey(), (128, 4, 5))
assert gn(x).shape == (128, 4, 5)
Expand Down
Loading