remove inlined jax.nn.initializers definitions, resolving TODO of levskaya et al

fixes breakage from cl/655766534 aka jax-ml/jax#21069

PiperOrigin-RevId: 655806010
mattjj authored and t5-copybara committed Jul 25, 2024
1 parent dcdcda9 commit 31410cc
Showing 1 changed file with 2 additions and 59 deletions.
61 changes: 2 additions & 59 deletions t5x/examples/scalable_t5/layers.py
@@ -52,64 +52,7 @@
     1.0, 'fan_in', 'normal', out_axis=0
 )
 
-
-# ------------------------------------------------------------------------------
-# Temporary inlined JAX N-d initializer code
-# TODO(levskaya): remove once new JAX release is out.
-# ------------------------------------------------------------------------------
-def _compute_fans(shape: jax.core.NamedShape, in_axis=-2, out_axis=-1):
-  """Inlined JAX `nn.initializer._compute_fans`."""
-  if isinstance(in_axis, int):
-    in_size = shape[in_axis]
-  else:
-    in_size = int(np.prod([shape[i] for i in in_axis]))
-  if isinstance(out_axis, int):
-    out_size = shape[out_axis]
-  else:
-    out_size = int(np.prod([shape[i] for i in out_axis]))
-  receptive_field_size = shape.total / in_size / out_size
-  fan_in = in_size * receptive_field_size
-  fan_out = out_size * receptive_field_size
-  return fan_in, fan_out
-
-
-def variance_scaling(
-    scale, mode, distribution, in_axis=-2, out_axis=-1, dtype=jnp.float_
-):
-  """Inlined JAX `nn.initializer.variance_scaling`."""
-
-  def init(key, shape, dtype=dtype):
-    dtype = jax.dtypes.canonicalize_dtype(dtype)
-    shape = jax.core.as_named_shape(shape)
-    fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
-    if mode == 'fan_in':
-      denominator = fan_in
-    elif mode == 'fan_out':
-      denominator = fan_out
-    elif mode == 'fan_avg':
-      denominator = (fan_in + fan_out) / 2
-    else:
-      raise ValueError(
-          'invalid mode for variance scaling initializer: {}'.format(mode)
-      )
-    variance = jnp.array(scale / denominator, dtype=dtype)
-
-    if distribution == 'truncated_normal':
-      # constant is stddev of standard normal truncated to (-2, 2)
-      stddev = jnp.sqrt(variance) / jnp.array(0.87962566103423978, dtype)
-      return random.truncated_normal(key, -2, 2, shape, dtype) * stddev
-    elif distribution == 'normal':
-      return random.normal(key, shape, dtype) * jnp.sqrt(variance)
-    elif distribution == 'uniform':
-      return random.uniform(key, shape, dtype, -1) * jnp.sqrt(3 * variance)
-    else:
-      raise ValueError(
-          'invalid distribution for variance scaling initializer: {}'.format(
-              distribution
-          )
-      )
-
-  return init
+variance_scaling = nn.initializers.variance_scaling
 
 
 # ------------------------------------------------------------------------------
@@ -420,7 +363,7 @@ def __call__(
     return out
 
 
-def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]:
+def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int, ...]:
   # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
   return tuple([ax if ax >= 0 else ndim + ax for ax in axes])
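The one-line alias relies on the upstream initializer matching the deleted copy, which holds by construction: the removed block was JAX code inlined while waiting for a release, as its docstrings note. A minimal usage sketch of the restored function follows; the key, shape, and dtype are illustrative, not taken from this commit.

import jax
import jax.numpy as jnp

# Build a variance-scaling initializer the same way layers.py builds
# default_embed_init, then call it with an RNG key, a shape, and a dtype.
init_fn = jax.nn.initializers.variance_scaling(
    1.0, 'fan_in', 'normal', out_axis=0
)
key = jax.random.PRNGKey(0)
table = init_fn(key, (32128, 512), jnp.float32)  # hypothetical vocab x embed shape

With mode='fan_in' and distribution='normal', samples are drawn with variance scale / fan_in, so wider input fans yield proportionally smaller weights.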

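The second hunk only corrects a type annotation: Tuple[int] denotes a tuple of exactly one int, while _normalize_axes returns one entry per input axis, hence Tuple[int, ...]. A self-contained sketch of the fixed helper; the assertion values are made up for illustration.

from typing import Iterable, Tuple

def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int, ...]:
  # Map negative axis indices (e.g. -1 for the last axis) onto [0, ndim).
  return tuple([ax if ax >= 0 else ndim + ax for ax in axes])

assert _normalize_axes((-2, -1), ndim=3) == (1, 2)  # two entries, not one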