diff --git a/equinox/nn/_normalisation.py b/equinox/nn/_normalisation.py index 9405da7d..54743bf4 100644 --- a/equinox/nn/_normalisation.py +++ b/equinox/nn/_normalisation.py @@ -120,7 +120,7 @@ def __call__( """**Arguments:** - `x`: A JAX array, with the same shape as the `shape` passed to `__init__`. - - `state`: Ignored; provided for interchangability with the + - `state`: Ignored; provided for interchangeability with the [`equinox.nn.BatchNorm`][] API. - `key`: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.) @@ -155,9 +155,9 @@ def __call__( inv = jax.lax.rsqrt(variance + self.eps) out = (x - mean) * inv if self.use_weight: - out = self.weight * out + out = self.weight.astype(dtype) * out # pyright: ignore if self.use_bias: - out = out + self.bias + out = out + self.bias.astype(dtype) # pyright: ignore if state is sentinel: return out.astype(orig_dtype) else: @@ -231,10 +231,6 @@ def __init__( self.channels = channels self.eps = eps self.channelwise_affine = channelwise_affine - - with jax.numpy_dtype_promotion("standard"): - dtype = jnp.result_type(dtype, jnp.float32) - self.weight = jnp.ones(channels, dtype=dtype) if channelwise_affine else None self.bias = jnp.zeros(channels, dtype=dtype) if channelwise_affine else None @@ -283,7 +279,7 @@ def __call__( if self.channelwise_affine: weight = left_broadcast_to(self.weight, out.shape) # pyright: ignore bias = left_broadcast_to(self.bias, out.shape) # pyright: ignore - out = weight * out + bias + out = weight.astype(dtype) * out + bias.astype(dtype) if state is sentinel: return out.astype(orig_dtype) else: @@ -354,10 +350,6 @@ def __init__( self.eps = eps self.use_weight = use_weight self.use_bias = use_bias - - with jax.numpy_dtype_promotion("standard"): - dtype = jnp.result_type(dtype, jnp.float32) - self.weight = jnp.ones(shape, dtype=dtype) if use_weight else None self.bias = jnp.zeros(shape, dtype=dtype) if use_bias else None @@ -410,9 +402,9 @@ def __call__( out = inv_rms * x if self.use_weight: - out = self.weight * out + out = self.weight.astype(dtype) * out # pyright: ignore if self.use_bias: - out = out + self.bias + out = out + self.bias.astype(dtype) # pyright: ignore if state is sentinel: return out.astype(orig_dtype) else: