Skip to content

Commit

Permalink
Simplification for transforms with empty state
Browse files Browse the repository at this point in the history
  • Loading branch information
fabianp authored Mar 28, 2024
1 parent bb5d0c2 commit a62fd74
Showing 1 changed file with 4 additions and 28 deletions.
32 changes: 4 additions & 28 deletions optax/_src/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,10 +885,6 @@ def update_fn(updates, state, params=None):

return base.GradientTransformation(init_fn, update_fn)


AddDecayedWeightsState = base.EmptyState


def add_decayed_weights(
weight_decay: Union[float, jax.Array] = 0.0,
mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None
Expand All @@ -906,10 +902,6 @@ def add_decayed_weights(
A `GradientTransformation` object.
"""

def init_fn(params):
del params
return AddDecayedWeightsState()

def update_fn(updates, state, params):
if params is None:
raise ValueError(base.NO_PARAMS_MSG)
Expand All @@ -921,8 +913,8 @@ def update_fn(updates, state, params):
# E.g. it is common to skip weight decay on bias units and batch stats.
if mask is not None:
return wrappers.masked(
base.GradientTransformation(init_fn, update_fn), mask)
return base.GradientTransformation(init_fn, update_fn)
base.GradientTransformation(_init_empty_state, update_fn), mask)
return base.GradientTransformation(_init_empty_state, update_fn)


class ScaleByScheduleState(NamedTuple):
Expand Down Expand Up @@ -982,10 +974,6 @@ def update_fn(updates, state, params=None):
return base.GradientTransformation(init_fn, update_fn)


class ScaleByTrustRatioState(NamedTuple):
"""The scale and decay trust ratio transformation is stateless."""


def scale_by_trust_ratio(
min_norm: float = 0.0,
trust_coefficient: float = 1.,
Expand All @@ -1005,10 +993,6 @@ def scale_by_trust_ratio(
A `GradientTransformation` object.
"""

def init_fn(params):
del params
return ScaleByTrustRatioState()

def update_fn(updates, state, params):
if params is None:
raise ValueError(base.NO_PARAMS_MSG)
Expand All @@ -1031,7 +1015,7 @@ def _scale_update(update, param):
updates = jax.tree_util.tree_map(_scale_update, updates, params)
return updates, state

return base.GradientTransformation(init_fn, update_fn)
return base.GradientTransformation(_init_empty_state, update_fn)


class AddNoiseState(NamedTuple):
Expand Down Expand Up @@ -1131,10 +1115,6 @@ def _subtract_mean(g):
else:
return g


CentralState = base.EmptyState


def centralize() -> base.GradientTransformation:
"""Centralize gradients.
Expand All @@ -1145,16 +1125,12 @@ def centralize() -> base.GradientTransformation:
A `GradientTransformation` object.
"""

def init_fn(params):
del params
return CentralState()

def update_fn(updates, state, params=None):
del params
updates = jax.tree_util.tree_map(_subtract_mean, updates)
return updates, state

return base.GradientTransformation(init_fn, update_fn)
return base.GradientTransformation(_init_empty_state, update_fn)


class ScaleBySM3State(NamedTuple):
Expand Down

0 comments on commit a62fd74

Please sign in to comment.