From 92aad90bb5c150e6d82755c095fff37a54f035a0 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 11 Aug 2023 12:13:52 -0700 Subject: [PATCH] [JAX] Fix incorrect type annotations. An upcoming change to JAX will teach pytype more accurate types for functions in the jax.numpy module. This reveals a number of type errors in downstream users of JAX. In particular, pytype is able to infer `jax.Array` accurately as a type in many more cases. PiperOrigin-RevId: 556042678 --- optax/_src/control_variates.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optax/_src/control_variates.py b/optax/_src/control_variates.py index bf7ac2b4d..9c4485c40 100644 --- a/optax/_src/control_variates.py +++ b/optax/_src/control_variates.py @@ -111,7 +111,7 @@ def delta( return control_variate def expected_value_delta( - params: base.Params, state: CvState) -> float: + params: base.Params, state: CvState) -> jax.Array: """"Expected value of second order expansion of `function` at dist mean.""" del state mean_dist = params[0]