Skip to content

Commit

Permalink
ignore false pywright type warning
Browse files Browse the repository at this point in the history
  • Loading branch information
AakashKumarNain committed Oct 13, 2024
1 parent 8c538a0 commit abb5703
Showing 1 changed file with 6 additions and 14 deletions.
20 changes: 6 additions & 14 deletions equinox/nn/_normalisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def __call__(
"""**Arguments:**
- `x`: A JAX array, with the same shape as the `shape` passed to `__init__`.
- `state`: Ignored; provided for interchangability with the
- `state`: Ignored; provided for interchangeability with the
[`equinox.nn.BatchNorm`][] API.
- `key`: Ignored; provided for compatibility with the rest of the Equinox API.
(Keyword only argument.)
Expand Down Expand Up @@ -155,9 +155,9 @@ def __call__(
inv = jax.lax.rsqrt(variance + self.eps)
out = (x - mean) * inv
if self.use_weight:
out = self.weight * out
out = self.weight.astype(dtype) * out # pyright: ignore
if self.use_bias:
out = out + self.bias
out = out + self.bias.astype(dtype) # pyright: ignore
if state is sentinel:
return out.astype(orig_dtype)
else:
Expand Down Expand Up @@ -231,10 +231,6 @@ def __init__(
self.channels = channels
self.eps = eps
self.channelwise_affine = channelwise_affine

with jax.numpy_dtype_promotion("standard"):
dtype = jnp.result_type(dtype, jnp.float32)

self.weight = jnp.ones(channels, dtype=dtype) if channelwise_affine else None
self.bias = jnp.zeros(channels, dtype=dtype) if channelwise_affine else None

Expand Down Expand Up @@ -283,7 +279,7 @@ def __call__(
if self.channelwise_affine:
weight = left_broadcast_to(self.weight, out.shape) # pyright: ignore
bias = left_broadcast_to(self.bias, out.shape) # pyright: ignore
out = weight * out + bias
out = weight.astype(dtype) * out + bias.astype(dtype)
if state is sentinel:
return out.astype(orig_dtype)
else:
Expand Down Expand Up @@ -354,10 +350,6 @@ def __init__(
self.eps = eps
self.use_weight = use_weight
self.use_bias = use_bias

with jax.numpy_dtype_promotion("standard"):
dtype = jnp.result_type(dtype, jnp.float32)

self.weight = jnp.ones(shape, dtype=dtype) if use_weight else None
self.bias = jnp.zeros(shape, dtype=dtype) if use_bias else None

Expand Down Expand Up @@ -410,9 +402,9 @@ def __call__(
out = inv_rms * x

if self.use_weight:
out = self.weight * out
out = self.weight.astype(dtype) * out # pyright: ignore
if self.use_bias:
out = out + self.bias
out = out + self.bias.astype(dtype) # pyright: ignore
if state is sentinel:
return out.astype(orig_dtype)
else:
Expand Down

0 comments on commit abb5703

Please sign in to comment.