Skip to content

Commit

Permalink
Move stochastic gradient estimation functions to separate monte_carlo…
Browse files Browse the repository at this point in the history
… subpackage.

E.g. this means that instead of:

```
import optax
optax.score_function_jacobians(...)
```

You should now use:

```
from optax import monte_carlo`
monte_carlo.score_function_jacobians
```

This is part of a broader effort to restructure the current fully flat API surface, since optax has outgrown the flat structure due to broader scope and complexity.

PiperOrigin-RevId: 570042042
  • Loading branch information
mtthss authored and OptaxDev committed Oct 3, 2023
1 parent e43f9f5 commit 51cc766
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 9 deletions.
11 changes: 2 additions & 9 deletions optax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from optax import contrib
from optax import experimental
from optax import monte_carlo
from optax._src.alias import adabelief
from optax._src.alias import adafactor
from optax._src.alias import adagrad
Expand Down Expand Up @@ -69,6 +70,7 @@
from optax._src.constrain import NonNegativeParamsState
from optax._src.constrain import zero_nans
from optax._src.constrain import ZeroNansState
# TODO(mtthss): remove flat reference to fns in `cv` from `optax` namespace.
from optax._src.control_variates import control_delta_method
from optax._src.control_variates import control_variates_jacobians
from optax._src.control_variates import moving_avg_baseline
Expand Down Expand Up @@ -119,9 +121,6 @@
from optax._src.second_order import hessian_diag
from optax._src.second_order import hvp
from optax._src.state_utils import tree_map_params
from optax._src.stochastic_gradient_estimators import measure_valued_jacobians
from optax._src.stochastic_gradient_estimators import pathwise_jacobians
from optax._src.stochastic_gradient_estimators import score_function_jacobians
from optax._src.transform import add_decayed_weights
from optax._src.transform import add_noise
from optax._src.transform import AddDecayedWeightsState
Expand Down Expand Up @@ -219,8 +218,6 @@
"constant_schedule",
"ctc_loss",
"ctc_loss_with_forward_probs",
"control_delta_method",
"control_variates_jacobians",
"convex_kl_divergence",
"cosine_decay_schedule",
"cosine_distance",
Expand Down Expand Up @@ -267,8 +264,6 @@
"matrix_inverse_pth_root",
"maybe_update",
"MaybeUpdateState",
"measure_valued_jacobians",
"moving_avg_baseline",
"multi_normal",
"multi_transform",
"MultiSteps",
Expand All @@ -279,7 +274,6 @@
"NonNegativeParamsState",
"OptState",
"Params",
"pathwise_jacobians",
"periodic_update",
"per_example_global_norm_clip",
"piecewise_constant_schedule",
Expand Down Expand Up @@ -324,7 +318,6 @@
"ScaleByTrustRatioState",
"ScaleState",
"Schedule",
"score_function_jacobians",
"set_to_zero",
"sgd",
"sgdr_schedule",
Expand Down
22 changes: 22 additions & 0 deletions optax/monte_carlo/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for efficient monte carlo gradient estimation."""

from optax._src.monte_carlo.control_variates import control_delta_method
from optax._src.monte_carlo.control_variates import control_variates_jacobians
from optax._src.monte_carlo.control_variates import moving_avg_baseline
from optax._src.monte_carlo.stochastic_gradient_estimator import measure_valued_jacobians
from optax._src.monte_carlo.stochastic_gradient_estimator import pathwise_jacobians
from optax._src.monte_carlo.stochastic_gradient_estimator import score_function_jacobians
File renamed without changes.
File renamed without changes.

0 comments on commit 51cc766

Please sign in to comment.