diff --git a/optax/__init__.py b/optax/__init__.py index 6e7ff09d6..f80121f5f 100644 --- a/optax/__init__.py +++ b/optax/__init__.py @@ -14,8 +14,8 @@ # ============================================================================== """Optax: composable gradient processing and optimization, in JAX.""" -from optax import contrib from optax import experimental +from optax._src import contrib from optax._src.alias import adabelief from optax._src.alias import adafactor from optax._src.alias import adagrad diff --git a/optax/contrib/__init__.py b/optax/_src/contrib/__init__.py similarity index 100% rename from optax/contrib/__init__.py rename to optax/_src/contrib/__init__.py