DoG (https://arxiv.org/pdf/2302.12022.pdf) learning rate-free optimization algorithm.

PiperOrigin-RevId: 557271064
OptaxDev committed Aug 15, 2023
1 parent 1b23e56 commit e5261c8
Showing 1 changed file with 76 additions and 1 deletion.
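
For reference, the per-block step size implemented by the new transformation in the diff below can be written as follows (notation follows the DoG paper; x_0 is a parameter block's initial value, x_s its value at step s, g_s its gradient, and eps the loading term):

  \eta_t \;=\; \mathrm{global\_scale} \cdot \frac{\max\!\big(\varepsilon,\ \max_{s \le t} \lVert x_s - x_0 \rVert\big)}{\max\!\big(\varepsilon,\ \sqrt{\sum_{s \le t} \lVert g_s \rVert^2}\big)}

and the emitted update for the block is \eta_t \, g_t.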
77 changes: 76 additions & 1 deletion optax/_src/transform.py
@@ -1172,6 +1172,82 @@ def update_fn(updates, state, params=None):
return base.GradientTransformation(init_fn, update_fn)


class ScaleByBlockwiseDogState(NamedTuple):
  """State for scale_by_blockwise_dog."""

  max_dist: base.OptState
  g_sos: base.OptState
  init_params: base.OptState


def scale_by_blockwise_dog(eps=1e-4, param_dtype=jnp.float32, global_scale=1.0):
  """Blockwise implementation of the DoG hyperparameter-free optimizer.

  By "blockwise", we mean applying the DoG learning rate estimator to each
  parameter array independently ("L-DoG" in the paper). The extra DoG state is
  a single copy of the model parameters, plus two scalars per parameter block.

  References:
    ["DoG is SGD’s Best Friend: A Parameter-Free Dynamic Step Size
    Schedule"](https://arxiv.org/pdf/2302.12022.pdf)

  Args:
    eps: Small loading term to avoid zero learning rates and divide-by-zero
      errors.
    param_dtype: dtype for storing initial parameters.
    global_scale: Global scale factor, typically 1.0 or -1.0.

  Returns:
    A `GradientTransformation` object.
  """

  def _l2(x, y):
    return jnp.sqrt(jnp.square(x - y).sum())

  def _scalar_copy(tree):
    """Make a shape-(1,) accumulator for each leaf in the tree."""
    return jax.tree_util.tree_map(lambda x: jnp.zeros(1), tree)

  def init_fn(params):
    return ScaleByBlockwiseDogState(
        _scalar_copy(params),
        _scalar_copy(params),
        jax.tree_map(lambda x: x.astype(param_dtype), params),
    )

  def update_fn(updates, state, params):
    # Update the maximum distance traveled from the initial parameters,
    # tracked per parameter block.
    max_dist = jax.tree_map(
        lambda d, x, y: jnp.maximum(d, _l2(x, y)),
        state.max_dist,
        params,
        state.init_params,
    )

    # Update the per-block gradient sum-of-squares.
    g_sos = jax.tree_map(
        lambda x, y: x + jnp.square(y).sum(), state.g_sos, updates
    )

    def _clip(x):
      return jnp.maximum(eps, x)

    def _tx(g, d, g_sos):
      """Scale the gradient block by its DoG learning rate."""
      eta = global_scale * (_clip(d) / _clip(jnp.sqrt(g_sos)))[0]
      return eta * g

    updates = jax.tree_map(_tx, updates, max_dist, g_sos)

    # New state.
    state = ScaleByBlockwiseDogState(max_dist, g_sos, state.init_params)

    return updates, state

  return base.GradientTransformation(init_fn, update_fn)


# TODO(b/183800387): remove legacy aliases.
# These legacy aliases are here for checkpoint compatibility
# To be removed once checkpoints have updated.
@@ -1181,4 +1257,3 @@ def update_fn(updates, state, params=None):
additive_weight_decay = add_decayed_weights
ClipState = clipping.ClipState
ClipByGlobalNormState = clipping.ClipByGlobalNormState
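
A minimal usage sketch of the new transformation, assuming it is imported from optax._src.transform as in this diff (a public optax alias is not shown here); the toy params, loss_fn, and batch are illustrative only:

import jax
import jax.numpy as jnp
import optax
from optax._src import transform

def loss_fn(params, batch):
  # Toy quadratic loss, purely for illustration.
  return jnp.sum((params['w'] * batch['x'] - batch['y']) ** 2)

params = {'w': jnp.zeros(3)}
# global_scale=-1.0 so that applying the updates descends the loss.
tx = transform.scale_by_blockwise_dog(eps=1e-4, global_scale=-1.0)
opt_state = tx.init(params)

batch = {'x': jnp.ones(3), 'y': jnp.full(3, 2.0)}
for _ in range(10):
  grads = jax.grad(loss_fn)(params, batch)
  # DoG needs the current params to track the distance traveled from init.
  updates, opt_state = tx.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)

Equivalently, one could keep global_scale=1.0 and chain the transform with optax.scale(-1.0).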
