diff --git a/optax/_src/privacy.py b/optax/_src/privacy.py index 4af02623..d115812c 100644 --- a/optax/_src/privacy.py +++ b/optax/_src/privacy.py @@ -14,10 +14,9 @@ # ============================================================================== """Differential Privacy utilities.""" -from typing import NamedTuple +from typing import Any, NamedTuple import jax -import jax.numpy as jnp from optax._src import base from optax._src import clipping @@ -26,7 +25,10 @@ # pylint:disable=no-value-for-parameter class DifferentiallyPrivateAggregateState(NamedTuple): """State containing PRNGKey for `differentially_private_aggregate`.""" - rng_key: jnp.array + # TODO(optax-dev): rng_key used to be annotated as `jnp.array` but that is + # not a valid annotation (it's a function and secretely resolved to `Any`). + # We should add back typing. + rng_key: Any def differentially_private_aggregate( diff --git a/optax/_src/wrappers.py b/optax/_src/wrappers.py index a2474005..f395ed32 100644 --- a/optax/_src/wrappers.py +++ b/optax/_src/wrappers.py @@ -102,9 +102,12 @@ class ApplyIfFiniteState(NamedTuple): a NaN since this optimizer was initialised. This number is never reset. inner_state: The state of the inner `GradientTransformation`. """ - notfinite_count: jnp.array - last_finite: jnp.array - total_notfinite: jnp.array + # TODO(optax-dev): notfinite_count, last_finite and inner_state used to be + # annotated as `jnp.array` but that is not a valid annotation (it's a function + # and secretely resolved to `Any`. We should add back typing. + notfinite_count: Any + last_finite: Any + total_notfinite: Any inner_state: Any