diff --git a/optax/_src/control_variates.py b/optax/_src/control_variates.py index 9c4485c40..2ea129061 100644 --- a/optax/_src/control_variates.py +++ b/optax/_src/control_variates.py @@ -210,17 +210,18 @@ def _map(cv, params, samples, state): def control_variates_jacobians( function: Callable[[chex.Array], float], - control_variate_from_function: Callable[[Callable[[chex.Array], float]], - ControlVariate], - grad_estimator: Callable[..., jnp.array], + control_variate_from_function: Callable[ + [Callable[[chex.Array], float]], ControlVariate + ], + grad_estimator: Callable[..., jnp.ndarray], params: base.Params, dist_builder: Callable[..., Any], rng: chex.PRNGKey, num_samples: int, control_variate_state: CvState = None, estimate_cv_coeffs: bool = False, - estimate_cv_coeffs_num_samples: int = 20) -> Tuple[ - Sequence[chex.Array], CvState]: + estimate_cv_coeffs_num_samples: int = 20, +) -> Tuple[Sequence[chex.Array], CvState]: r"""Obtain jacobians using control variates. We will compute each term individually. The first term will use stochastic @@ -338,15 +339,17 @@ def param_fn(x): def estimate_control_variate_coefficients( function: Callable[[chex.Array], float], - control_variate_from_function: Callable[[Callable[[chex.Array], float]], - ControlVariate], - grad_estimator: Callable[..., jnp.array], + control_variate_from_function: Callable[ + [Callable[[chex.Array], float]], ControlVariate + ], + grad_estimator: Callable[..., jnp.ndarray], params: base.Params, dist_builder: Callable[..., Any], rng: chex.PRNGKey, num_samples: int, control_variate_state: CvState = None, - eps: float = 1e-3) -> Sequence[float]: + eps: float = 1e-3, +) -> Sequence[float]: r"""Estimates the control variate coefficients for the given parameters. For each variable `var_k`, the coefficient is given by: