Skip to content

Commit

Permalink
[LSC] change uses of jax.random.KeyArray and jax.random.PRNGKeyArray …
Browse files Browse the repository at this point in the history
…to jax.Array

This change replaces uses of jax.random.KeyArray and jax.random.PRNGKeyArray in the context of type annotations with jax.Array, which is the correct annotation for JAX PRNG keys moving forward.

The purpose of this change is to remove references to KeyArray and PRNGKeyArray, which are deprecated (jax-ml/jax#17594) and will soon be removed from JAX. The design and thought process behind this is described in https://jax.readthedocs.io/en/latest/jep/9263-typed-keys.html.

Note that KeyArray and PRNGKeyArray have always been aliased to Any, so the new type annotation is far more specific than the old one.

PiperOrigin-RevId: 574248274
  • Loading branch information
Jake VanderPlas authored and The swirl_dynamics Authors committed Oct 17, 2023
1 parent 05dab19 commit bb99a5b
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions swirl_dynamics/lib/diffusion/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,13 +293,13 @@ def exponential_noise_schedule(

class NoiseLevelSampling(Protocol):

def __call__(self, rng: jax.random.KeyArray, shape: tuple[int, ...]) -> Array:
def __call__(self, rng: jax.Array, shape: tuple[int, ...]) -> Array:
"""Samples noise levels for training."""
...


def _uniform_samples(
rng: jax.random.KeyArray,
rng: jax.Array,
shape: tuple[int, ...],
uniform_grid: bool,
) -> Array:
Expand All @@ -319,7 +319,7 @@ def log_uniform_sampling(
"""Samples noise whose natural log follows a uniform distribution."""

def _noise_sampling(
rng: jax.random.KeyArray, shape: tuple[int, ...]
rng: jax.Array, shape: tuple[int, ...]
) -> Array:
samples = _uniform_samples(rng, shape, uniform_grid)
log_min, log_max = jnp.log(clip_min), jnp.log(scheme.sigma_max)
Expand All @@ -335,7 +335,7 @@ def time_uniform_sampling(
"""Samples noise from a uniform distribution in t."""

def _noise_sampling(
rng: jax.random.KeyArray, shape: tuple[int, ...]
rng: jax.Array, shape: tuple[int, ...]
) -> Array:
samples = _uniform_samples(rng, shape, uniform_grid)
min_t = scheme.sigma.inverse(clip_min)
Expand Down Expand Up @@ -367,7 +367,7 @@ def normal_sampling(
A normal sampling function.
"""

def _noise_sampler(rng: jax.random.KeyArray, shape: tuple[int, ...]) -> Array:
def _noise_sampler(rng: jax.Array, shape: tuple[int, ...]) -> Array:
log_sigma = jax.random.normal(rng, shape, dtype=jnp.float32)
log_sigma = p_mean + p_std * log_sigma
return jnp.clip(jnp.exp(log_sigma), clip_min, scheme.sigma_max)
Expand Down

0 comments on commit bb99a5b

Please sign in to comment.