From a62fd7448500a34a6c3a2ff2e5e1d9cccf31ef1a Mon Sep 17 00:00:00 2001
From: Fabian Pedregosa
Date: Thu, 28 Mar 2024 09:24:07 +0100
Subject: [PATCH] Simplification for transforms with empty state

---
 optax/_src/transform.py | 32 ++++----------------------------
 1 file changed, 4 insertions(+), 28 deletions(-)

diff --git a/optax/_src/transform.py b/optax/_src/transform.py
index c8f92422..a4a78890 100644
--- a/optax/_src/transform.py
+++ b/optax/_src/transform.py
@@ -885,10 +885,6 @@ def update_fn(updates, state, params=None):
 
   return base.GradientTransformation(init_fn, update_fn)
 
-
-AddDecayedWeightsState = base.EmptyState
-
-
 def add_decayed_weights(
     weight_decay: Union[float, jax.Array] = 0.0,
     mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None
@@ -906,10 +902,6 @@ def add_decayed_weights(
     A `GradientTransformation` object.
   """
 
-  def init_fn(params):
-    del params
-    return AddDecayedWeightsState()
-
   def update_fn(updates, state, params):
     if params is None:
       raise ValueError(base.NO_PARAMS_MSG)
@@ -921,8 +913,8 @@ def update_fn(updates, state, params):
   # E.g. it is common to skip weight decay on bias units and batch stats.
   if mask is not None:
     return wrappers.masked(
-        base.GradientTransformation(init_fn, update_fn), mask)
-  return base.GradientTransformation(init_fn, update_fn)
+        base.GradientTransformation(_init_empty_state, update_fn), mask)
+  return base.GradientTransformation(_init_empty_state, update_fn)
 
 
 class ScaleByScheduleState(NamedTuple):
@@ -982,10 +974,6 @@ def update_fn(updates, state, params=None):
   return base.GradientTransformation(init_fn, update_fn)
 
 
-class ScaleByTrustRatioState(NamedTuple):
-  """The scale and decay trust ratio transformation is stateless."""
-
-
 def scale_by_trust_ratio(
     min_norm: float = 0.0,
     trust_coefficient: float = 1.,
@@ -1005,10 +993,6 @@ def scale_by_trust_ratio(
     A `GradientTransformation` object.
   """
 
-  def init_fn(params):
-    del params
-    return ScaleByTrustRatioState()
-
   def update_fn(updates, state, params):
     if params is None:
       raise ValueError(base.NO_PARAMS_MSG)
@@ -1031,7 +1015,7 @@ def _scale_update(update, param):
     updates = jax.tree_util.tree_map(_scale_update, updates, params)
     return updates, state
 
-  return base.GradientTransformation(init_fn, update_fn)
+  return base.GradientTransformation(_init_empty_state, update_fn)
 
 
 class AddNoiseState(NamedTuple):
@@ -1131,10 +1115,6 @@ def _subtract_mean(g):
   else:
     return g
 
-
-CentralState = base.EmptyState
-
-
 def centralize() -> base.GradientTransformation:
   """Centralize gradients.
 
@@ -1145,16 +1125,12 @@ def centralize() -> base.GradientTransformation:
     A `GradientTransformation` object.
   """
 
-  def init_fn(params):
-    del params
-    return CentralState()
-
   def update_fn(updates, state, params=None):
     del params
     updates = jax.tree_util.tree_map(_subtract_mean, updates)
     return updates, state
 
-  return base.GradientTransformation(init_fn, update_fn)
+  return base.GradientTransformation(_init_empty_state, update_fn)
 
 
 class ScaleBySM3State(NamedTuple):
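
Note: every hunk above swaps a per-transform init_fn for a shared `_init_empty_state` helper that is assumed to already exist at module level in optax/_src/transform.py; the helper itself is not part of this patch. A minimal sketch of what such a helper presumably looks like, following the same pattern as the deleted init_fn bodies:

    from optax._src import base

    def _init_empty_state(params: base.Params) -> base.EmptyState:
      """Init function shared by transformations that keep no state."""
      # Parameters matter only to stateful transformations; a stateless one
      # discards them and returns an empty state.
      del params
      return base.EmptyState()

With one shared init function, stateless transformations no longer need their own aliases of base.EmptyState (AddDecayedWeightsState, ScaleByTrustRatioState, CentralState), which is exactly what the deletions above remove.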