Skip to content

Commit

Permalink
t5x: remove references to deprecated jax.random.KeyArray
Browse files Browse the repository at this point in the history
The correct annotation for PRNG keys is `jax.Array` (see [JEP 9263](https://jax.readthedocs.io/en/latest/jep/9263-typed-keys.html)). Note that `jax.random.KeyArray` has always been aliased to `Any` for type checking, so changing this to `jax.Array` makes the annotations more meaningful.

PiperOrigin-RevId: 570443873
  • Loading branch information
Jake VanderPlas authored and t5-copybara committed Oct 3, 2023
1 parent b051e46 commit 5107638
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 22 deletions.
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
with open('README.md') as fp:
_LONG_DESCRIPTION = fp.read()

_jax_version = '0.4.11'
_jaxlib_version = '0.4.11'
_jax_version = '0.4.16'
_jaxlib_version = '0.4.16'

setuptools.setup(
name='t5x',
Expand Down
16 changes: 11 additions & 5 deletions t5x/contrib/calm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,16 @@ class DecodeFnCallable(typing_extensions.Protocol):
"""Decoding function call signature."""

def __call__(
self, *, inputs: jnp.ndarray, cache: Mapping[str, jnp.ndarray],
tokens_to_logits: TokensIdsToLogitsCallable, eos_id: int,
num_decodes: int, decode_rng: Optional[jax.random.KeyArray],
cache_offset: int, **kwargs
self,
*,
inputs: jnp.ndarray,
cache: Mapping[str, jnp.ndarray],
tokens_to_logits: TokensIdsToLogitsCallable,
eos_id: int,
num_decodes: int,
decode_rng: Optional[jax.Array],
cache_offset: int,
**kwargs,
) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]]:
"""Decoding function interface.
Expand Down Expand Up @@ -790,7 +796,7 @@ def predict_batch_with_aux(
self,
params: PyTree,
batch: Mapping[str, jnp.ndarray],
rng: Optional[jax.random.KeyArray] = None,
rng: Optional[jax.Array] = None,
decoder_params: Optional[MutableMapping[str, Any]] = None,
return_all_decodes: bool = False,
num_decodes: int = 1,
Expand Down
4 changes: 2 additions & 2 deletions t5x/contrib/moe/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def predict_batch_with_aux( # pylint: disable=useless-super-delegation
self,
params: PyTree,
batch: Mapping[str, jnp.ndarray],
rng: Optional[jax.random.KeyArray] = None,
rng: Optional[jax.Array] = None,
decoder_params: Optional[MutableMapping[str, Any]] = None,
return_all_decodes: bool = False,
num_decodes: int = 1,
Expand Down Expand Up @@ -221,7 +221,7 @@ def predict_batch_with_aux( # pylint: disable=useless-super-delegation
self,
params: PyTree,
batch: Mapping[str, jnp.ndarray],
rng: Optional[jax.random.KeyArray] = None,
rng: Optional[jax.Array] = None,
*,
return_all_decodes: bool = False,
num_decodes: int = 1,
Expand Down
26 changes: 13 additions & 13 deletions t5x/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def __call__(
tokens_to_logits: TokensIdsToLogitsCallable,
eos_id: int,
num_decodes: int,
decode_rng: Optional[jax.random.KeyArray],
decode_rng: Optional[jax.Array],
cache_offset: int,
**kwargs,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
Expand Down Expand Up @@ -134,7 +134,7 @@ def loss_fn(
self,
params: PyTree,
batch: Mapping[str, jnp.ndarray],
dropout_rng: Optional[jax.random.KeyArray],
dropout_rng: Optional[jax.Array],
) -> Tuple[jnp.ndarray, MetricsMap]:
"""Computes loss and metrics.
Expand Down Expand Up @@ -178,7 +178,7 @@ def predict_batch(
self,
params: PyTree,
batch: Mapping[str, jnp.ndarray],
rng: Optional[jax.random.KeyArray] = None,
rng: Optional[jax.Array] = None,
) -> jnp.ndarray:
"""Predicts a batch of outputs from the model.
Expand All @@ -197,7 +197,7 @@ def predict_batch_with_aux(
self,
params: PyTree,
batch: Mapping[str, jnp.ndarray],
rng: Optional[jax.random.KeyArray] = None,
rng: Optional[jax.Array] = None,
) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]:
"""Predict a batch from the model with auxiliary outputs.
Expand Down Expand Up @@ -225,7 +225,7 @@ def score_batch(
@abc.abstractmethod
def get_initial_variables(
self,
rng: jax.random.KeyArray,
rng: jax.Array,
input_shapes: Mapping[str, Array],
input_types: Optional[Mapping[str, jnp.dtype]] = None,
) -> flax_scope.FrozenVariableDict:
Expand Down Expand Up @@ -280,7 +280,7 @@ def _compute_logits(
self,
params: PyTree,
batch: Mapping[str, jnp.ndarray],
dropout_rng: Optional[jax.random.KeyArray] = None,
dropout_rng: Optional[jax.Array] = None,
) -> jnp.ndarray:
"""Computes logits via a forward pass of the model."""
pass
Expand All @@ -289,7 +289,7 @@ def loss_fn(
self,
params: PyTree,
batch: Mapping[str, jnp.ndarray],
dropout_rng: Optional[jax.random.KeyArray],
dropout_rng: Optional[jax.Array],
) -> Tuple[jnp.ndarray, MetricsMap]:
"""Loss function used for training with a cross-entropy loss."""
logits = self._compute_logits(params, batch, dropout_rng)
Expand Down Expand Up @@ -401,7 +401,7 @@ def __init__(

def get_initial_variables(
self,
rng: jax.random.KeyArray,
rng: jax.Array,
input_shapes: Mapping[str, Array],
input_types: Optional[Mapping[str, jnp.dtype]] = None,
) -> flax_scope.FrozenVariableDict:
Expand Down Expand Up @@ -459,7 +459,7 @@ def _compute_logits( # pytype: disable=signature-mismatch # jax-ndarray
self,
params: PyTree,
batch: Mapping[str, jnp.ndarray],
dropout_rng: Optional[jax.random.KeyArray] = None,
dropout_rng: Optional[jax.Array] = None,
mutable: flax_scope.CollectionFilter = False,
other_variables: Optional[PyTree] = None,
) -> Union[jnp.ndarray, Tuple[jnp.ndarray, flax_scope.FrozenVariableDict]]:
Expand Down Expand Up @@ -587,7 +587,7 @@ def predict_batch_with_aux(
self,
params: PyTree,
batch: Mapping[str, jnp.ndarray],
rng: Optional[jax.random.KeyArray] = None,
rng: Optional[jax.Array] = None,
decoder_params: Optional[MutableMapping[str, Any]] = None,
return_all_decodes: bool = None,
num_decodes: int = None, # pytype:disable=annotation-type-mismatch
Expand Down Expand Up @@ -864,7 +864,7 @@ def __init__(

def get_initial_variables(
self,
rng: jax.random.KeyArray,
rng: jax.Array,
input_shapes: Mapping[str, Array],
input_types: Optional[Mapping[str, jnp.dtype]] = None,
) -> flax_scope.FrozenVariableDict:
Expand Down Expand Up @@ -898,7 +898,7 @@ def _compute_logits(
self,
params: PyTree,
batch: Mapping[str, jnp.ndarray],
dropout_rng: Optional[jax.random.KeyArray] = None,
dropout_rng: Optional[jax.Array] = None,
mutable: flax_scope.CollectionFilter = False,
other_variables: Optional[PyTree] = None,
) -> jnp.ndarray:
Expand Down Expand Up @@ -1090,7 +1090,7 @@ def predict_batch_with_aux(
self,
params: PyTree,
batch: Mapping[str, jnp.ndarray],
rng: Optional[jax.random.KeyArray] = None,
rng: Optional[jax.Array] = None,
*,
return_all_decodes: bool = False,
num_decodes: int = 1,
Expand Down

0 comments on commit 5107638

Please sign in to comment.