Skip to content

Commit

Permalink
[JAX] Replace uses of jnp.array in types with jnp.ndarray.
Browse files Browse the repository at this point in the history
`jnp.array` is a function, not a type:
https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.array.html
so it never makes sense to use `jnp.array` in a type annotation.

Presumably the intent was to write `jnp.ndarray` aka `jax.Array`. Change uses of `jnp.array` to `jnp.ndarray`.

PiperOrigin-RevId: 556827776
  • Loading branch information
hawkinsp authored and OptaxDev committed Aug 14, 2023
1 parent 302edff commit 1b23e56
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
8 changes: 5 additions & 3 deletions optax/_src/privacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@
# ==============================================================================
"""Differential Privacy utilities."""

from typing import NamedTuple
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp

from optax._src import base
from optax._src import clipping
Expand All @@ -26,7 +25,10 @@
# pylint:disable=no-value-for-parameter
class DifferentiallyPrivateAggregateState(NamedTuple):
"""State containing PRNGKey for `differentially_private_aggregate`."""
rng_key: jnp.array
# TODO(optax-dev): rng_key used to be annotated as `jnp.array` but that is
# not a valid annotation (it's a function and secretely resolved to `Any`).
# We should add back typing.
rng_key: Any


def differentially_private_aggregate(
Expand Down
9 changes: 6 additions & 3 deletions optax/_src/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,12 @@ class ApplyIfFiniteState(NamedTuple):
a NaN since this optimizer was initialised. This number is never reset.
inner_state: The state of the inner `GradientTransformation`.
"""
notfinite_count: jnp.array
last_finite: jnp.array
total_notfinite: jnp.array
# TODO(optax-dev): notfinite_count, last_finite and inner_state used to be
# annotated as `jnp.array` but that is not a valid annotation (it's a function
# and secretely resolved to `Any`. We should add back typing.
notfinite_count: Any
last_finite: Any
total_notfinite: Any
inner_state: Any


Expand Down

0 comments on commit 1b23e56

Please sign in to comment.