Skip to content

Commit

Permalink
[JAX] Fix incorrect type annotations.
Browse files Browse the repository at this point in the history
An upcoming change to JAX will teach pytype more accurate types for functions in the jax.numpy module. This reveals a number of type errors in downstream users of JAX. In particular, pytype is able to infer `jax.Array` accurately as a type in many more cases.

PiperOrigin-RevId: 556042678
  • Loading branch information
hawkinsp authored and OptaxDev committed Aug 14, 2023
1 parent cebdeff commit 92aad90
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion optax/_src/control_variates.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def delta(
return control_variate

def expected_value_delta(
params: base.Params, state: CvState) -> float:
params: base.Params, state: CvState) -> jax.Array:
""""Expected value of second order expansion of `function` at dist mean."""
del state
mean_dist = params[0]
Expand Down

0 comments on commit 92aad90

Please sign in to comment.