From 443eef39620806ae82072b934bedec82cd2dc8e5 Mon Sep 17 00:00:00 2001 From: Robert Dyro Date: Mon, 9 Dec 2024 11:28:26 -0800 Subject: [PATCH] Exposing named_chain in docs --- docs/api/combining_optimizers.rst | 2 ++ optax/transforms/_combining.py | 51 ++++++++++++++++++++----------- 2 files changed, 36 insertions(+), 17 deletions(-) diff --git a/docs/api/combining_optimizers.rst b/docs/api/combining_optimizers.rst index 48ff3aa49..1e9405308 100644 --- a/docs/api/combining_optimizers.rst +++ b/docs/api/combining_optimizers.rst @@ -5,11 +5,13 @@ Combining Optimizers .. autosummary:: chain + named_chain multi_transform Chain ~~~~~ .. autofunction:: chain +.. autofunction:: named_chain Multi-transform ~~~~~~~~~~~~~~~ diff --git a/optax/transforms/_combining.py b/optax/transforms/_combining.py index 364031b9a..b2839e732 100644 --- a/optax/transforms/_combining.py +++ b/optax/transforms/_combining.py @@ -34,7 +34,9 @@ def chain( updates in the given order. Args: - *args: a sequence of chainable (init_fn, update_fn) tuples. + *args: an arbitrary number of ``transform``-s of + :class:`GradientTransformation` or + :class:`GradientTransformationExtraArgs`. Returns: A :class:`GradientTransformationExtraArgs`, created by chaining the input @@ -55,6 +57,18 @@ def chain( >>> state = chained_transform.init(params) >>> updates = {'a': -0.5} >>> updates, new_state = chained_transform.update(updates, state, params) + + An optimizer in the chain might require extra args: + + >>> import optax + >>> opt1 = optax.scale(0.1) # scale incoming gradients + >>> opt2 = optax.polyak_sgd() # requires a `value` extra arg for `update` + >>> chained_transform = optax.chain(opt1, opt2) + >>> state = chained_transform.init(0.5) + >>> extra_args = {"value": 1.0} + >>> updates, new_state = chained_transform.update( + ... 0.7, state, 0.7, **extra_args # extra args for all transforms + ... ) """ transforms = [base.with_extra_args_support(t) for t in args] @@ -85,13 +99,13 @@ def update_fn(updates, state, params=None, **extra_args): def named_chain( - *transforms: tuple[str, base.GradientTransformation] + *args: tuple[str, base.GradientTransformation] ) -> base.GradientTransformationExtraArgs: - """Chains optax gradient transformations. + """Applies a list of named chainable update transformations. A variant of :func:`optax.chain` that allows to name each transformation. - Here the ``transforms`` are ``(name, transformation)`` pairs, constituted of a + Here the ``args`` are ``(name, transformation)`` pairs, constituted of a string ``name`` and an associated transformation ``transformation``. The gradient transformation must be an instance of :class:`GradientTransformation` or :class:`GradientTransformationExtraArgs`. @@ -101,26 +115,29 @@ def named_chain( with a given ``name`` can be easily retrieved as ``opt_state[name]``. Args: - *transforms: an arbitrary number of ``(name, tx)`` pairs, constituted of a - string ``name`` and an associated transformation ``tx``. The latter is a - :class:`GradientTransformation` or + *args: an arbitrary number of ``(name, transform)`` pairs, constituted of a + string ``name`` and an associated transformation ``transform``. The latter + is a :class:`GradientTransformation` or :class:`GradientTransformationExtraArgs`. Returns: A single (init_fn, update_fn) tuple. Examples: - - >>> # tx1 is a GradientTransformation with no extra_args. - >>> # tx2 is a GradientTransformationExtraArgs that requires `loss`. - >>> # tx3 is a GradientTransformationExtraArgs that requires `temperature`. - >>> tx = named_chain(('one', tx1), ('two', tx2), ('three', tx3)) - >>> extra_args={'loss': 0.3, 'temperature': 0.01} - >>> tx.init(params) - >>> tx.update(grads, state, params, **extra_args) + >>> import optax + >>> opt1 = optax.scale(0.1) # scale incoming gradients + >>> opt2 = optax.polyak_sgd() # requires a `value` extra arg for `update` + >>> chained_transform = optax.named_chain(("scale", opt1), ("sgd", opt2)) + >>> state = chained_transform.init(0.5) + >>> extra_args = {"value": 1.0} + >>> updates, new_state = chained_transform.update( + ... 0.7, state, 0.7, **extra_args # extra args for all transforms + ... ) + >>> tuple(new_state.keys()) == ("scale", "sgd") + True """ - names = [name for name, _ in transforms] + names = [name for name, _ in args] if len(names) != len(set(names)): raise ValueError( @@ -128,7 +145,7 @@ def named_chain( ) transforms = [ - (name, base.with_extra_args_support(t)) for name, t in transforms + (name, base.with_extra_args_support(t)) for name, t in args ] def init_fn(params):